Counting Word Frequencies with Hadoop and Storing the Results in HBase

For every term appearing in a set of TXT files, compute its average frequency (total number of occurrences divided by the number of TXT documents it appears in) and write the result into HBase. For example, if a term occurs 3 times in document A and 5 times in document B, its average frequency is (3 + 5) / 2 = 4.00.

In total, four MapReduce stages are used:

The Mapper splits the original input into many key-value pairs. For this task, every token is emitted as a pair of the form <Term#Doc, 1>.

public class TokenizerMapper extends Mapper<LongWritable, Text, Text, IntWritable> {

    private static final IntWritable one = new IntWritable(1);
    private Text word = new Text();

    @Override
    public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        StringTokenizer itr = new StringTokenizer(value.toString());
        while (itr.hasMoreTokens()) {
            // Emit one <term#docId, 1> pair per token.
            word.set(itr.nextToken() + '#' + getSource((FileSplit) context.getInputSplit()));
            context.write(word, one);
        }
    }

    // The document id is the input file name with its extension stripped.
    private static String getSource(FileSplit split) {
        String fileName = split.getPath().getName();
        return fileName.split("\\.", 2)[0];
    }
}
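
To make the <Term#Doc, 1> format concrete, here is a tiny standalone sketch (not part of the job) that mimics the mapper's tokenization and key construction on one made-up line; the document id and the line content are purely illustrative:

import java.util.StringTokenizer;

public class MapperKeySketch {
    public static void main(String[] args) {
        // Hypothetical document id and line; the real job derives the id from the input file name.
        String docId = "shendiaoxialv";   // e.g. from a file named "shendiaoxialv.txt"
        String line = "杨过 小龙女 杨过";

        StringTokenizer itr = new StringTokenizer(line);
        while (itr.hasMoreTokens()) {
            // Prints <杨过#shendiaoxialv, 1> twice and <小龙女#shendiaoxialv, 1> once.
            System.out.println("<" + itr.nextToken() + '#' + docId + ", 1>");
        }
    }
}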

The Combiner is optional: it performs a local merge right after the Mapper, which cuts the amount of data transferred from Mapper nodes to Reducer nodes. For the pairs above, entries sharing the same Term#Doc key (in other words, the same word appearing several times in the same document) can be collapsed into a single <Term#Doc, n>, where n is the occurrence count.

public class LocalSumCombiner extends Reducer<Text, IntWritable, Text, IntWritable> {
    private IntWritable result = new IntWritable();

    @Override
    public void reduce(Text key, Iterable<IntWritable> values, Context context) throws IOException, InterruptedException {
        int sum = 0;
        for (IntWritable val : values) {
            sum += val.get();
        }
        result.set(sum);
        context.write(key, result);
    }
}

The Partitioner decides which Reducer each key-value pair is shuffled to. The default implementation hashes the key and takes the result modulo the number of reducers. Partitioning serves two purposes: it balances the load across Reducers, and it lets us make the assignment meaningful, for example sending all pairs for the same term to the same Reducer. (A sketch of the default HashPartitioner follows the custom partitioner below for comparison.)

public class TermPartitioner extends Partitioner<Text, IntWritable> {
    @Override
    public int getPartition(Text key, IntWritable value, int numReduceTasks) {
        // Partition on the term part only, so every <term#doc> key for a given term
        // reaches the same reducer. Mask the sign bit because hashCode() can be negative.
        String term = key.toString().split("#")[0];
        return (term.hashCode() & Integer.MAX_VALUE) % numReduceTasks;
    }
}
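
For comparison, Hadoop's built-in HashPartitioner (the default mentioned above) applies the same idea to the whole key; the sketch below paraphrases it from memory of the Hadoop 2.x class rather than quoting it verbatim. The & Integer.MAX_VALUE mask is what keeps the partition index non-negative, and TermPartitioner above applies it for the same reason.

// Roughly what org.apache.hadoop.mapreduce.lib.partition.HashPartitioner does:
public class HashPartitionerSketch<K, V> extends Partitioner<K, V> {
    @Override
    public int getPartition(K key, V value, int numReduceTasks) {
        // Mask off the sign bit so the modulo result is never negative.
        return (key.hashCode() & Integer.MAX_VALUE) % numReduceTasks;
    }
}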

The Reducer aggregates the data and produces the final result. Note that the MapReduce framework guarantees that the keys shuffled to a reducer arrive sorted, so all Term#Doc keys for the same term are consecutive; the reducer below relies on this by flushing the accumulated postings whenever the term changes (and one final time when the framework calls cleanup() at the end).

public class WordFreqReducer extends TableReducer<Text, IntWritable, ImmutableBytesWritable> {

    private String prevTerm = "";

    private List<Posting> postings = new ArrayList<>();

    @Override
    public void reduce(Text key, Iterable<IntWritable> values, Context context)
            throws IOException, InterruptedException {
        String[] keySplits = key.toString().split("#");
        String term = keySplits[0];
        String document = keySplits[1];

        // Keys arrive sorted, so all <term#doc> keys for one term are consecutive.
        // When the term changes, flush the postings collected for the previous term.
        if (!term.equals(prevTerm)) {
            cleanup(context);
        }

        int count = 0;
        for (IntWritable val : values) {
            count += val.get();
        }

        postings.add(new Posting(document, count));
        prevTerm = term;
    }

    @Override
    public void cleanup(Context context) throws IOException, InterruptedException {
        if (!"".equals(prevTerm)) {
            // Average frequency = total occurrences / number of documents containing the term.
            int total = postings.stream().mapToInt(Posting::getFrequency).sum();
            double average = (double) total / postings.size();

            Put put = new Put(Bytes.toBytes(prevTerm));
            put.addColumn(Bytes.toBytes(InvertedIndexer.COLUMN_FAMILY_NAME),
                    Bytes.toBytes(InvertedIndexer.COLUMN_NAME), Bytes.toBytes(String.format("%.2f", average)));
            context.write(null, put);
            postings.clear();
        }
    }

    private static class Posting {
        private final String document;
        private final int frequency;

        Posting(String document, int frequency) {
            this.document = document;
            this.frequency = frequency;
        }

        public String getDocument() {
            return document;
        }

        public int getFrequency() {
            return frequency;
        }

        @Override
        public String toString() {
            return String.format("%s:%d", document, frequency);
        }
    }
}

Finally, the driver code with the main function:

public class InvertedIndexer {
    private static Logger logger = Logger.getLogger(InvertedIndexer.class);

    private static final Configuration HBASE_CONFIG = HBaseConfiguration.create();

    public static final String TABLE_NAME = "wuxia";
    public static final String COLUMN_FAMILY_NAME = "cf";
    public static final String COLUMN_NAME = "freq";

    public static void main(String[] args) throws Exception {
        createTableIfNotExist(TABLE_NAME, COLUMN_FAMILY_NAME);

        Job job = Job.getInstance(new Configuration(), "word_freq");
        job.setJarByClass(InvertedIndexer.class);

        // Wires up TableOutputFormat so the reducer's Puts go straight into the HBase table.
        TableMapReduceUtil.initTableReducerJob(TABLE_NAME, WordFreqReducer.class, job, TermPartitioner.class);
        job.setMapperClass(TokenizerMapper.class);
        job.setPartitionerClass(TermPartitioner.class);
        job.setCombinerClass(LocalSumCombiner.class);
        job.setReducerClass(WordFreqReducer.class);

        job.setOutputKeyClass(Text.class);
        job.setOutputValueClass(IntWritable.class);
        FileInputFormat.addInputPath(job, new Path(args[0]));

        System.exit(job.waitForCompletion(true) ? 0 : 1);
    }

    private static void createTableIfNotExist(String tableName, String... families)
            throws Exception {
        // Close the connection and admin handles when done instead of leaking them.
        try (Connection connection = ConnectionFactory.createConnection(HBASE_CONFIG);
             Admin admin = connection.getAdmin()) {

            HTableDescriptor desc = new HTableDescriptor(TableName.valueOf(tableName));
            for (String family : families) {
                desc.addFamily(new HColumnDescriptor(family));
            }

            if (admin.tableExists(TableName.valueOf(tableName))) {
                logger.info(String.format("Table '%s' already created", tableName));
            } else {
                admin.createTable(desc);
                logger.info(String.format("Created table '%s'", tableName));
            }
        }
    }
}

Note that when HBase is involved, only these two dependencies are needed:

<dependency>
    <groupId>org.apache.hbase</groupId>
    <artifactId>hbase-client</artifactId>
    <version>1.2.3</version>
</dependency>
<dependency>
    <groupId>org.apache.hbase</groupId>
    <artifactId>hbase-server</artifactId>
    <version>1.2.3</version>
</dependency>

Both artifacts are available from Maven Central, so an extra repository entry is usually unnecessary. If your build does need one, note that http://maven.apache.org is the Maven project website, not a repository URL; an Apache releases repository entry would look like this:

<repository>
    <id>apache</id>
    <url>https://repository.apache.org/content/repositories/releases/</url>
</repository>
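
Once the job has finished, the result can be spot-checked from the HBase shell (scan 'wuxia') or with a short client program such as the sketch below; the looked-up term is only an example, and the table, column family, and qualifier names mirror the constants defined in InvertedIndexer above.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hbase.HBaseConfiguration;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.util.Bytes;

public class FreqLookup {
    public static void main(String[] args) throws Exception {
        Configuration conf = HBaseConfiguration.create();
        try (Connection connection = ConnectionFactory.createConnection(conf);
             Table table = connection.getTable(TableName.valueOf("wuxia"))) {
            // "郭靖" is an arbitrary example term; pass any term you expect in the corpus.
            Get get = new Get(Bytes.toBytes(args.length > 0 ? args[0] : "郭靖"));
            Result result = table.get(get);
            byte[] value = result.getValue(Bytes.toBytes("cf"), Bytes.toBytes("freq"));
            System.out.println(value == null ? "term not found" : Bytes.toString(value));
        }
    }
}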