原文链接:http://zaumreit.me/blog/2013/12/15/mahout-chinese-classification/
===================
mahout 中文分类 里只写了如何部署 Mahout 以及如何训练模型以及测试,并没有写如何对新的数据进行分类。这篇文章讲如何对新的数据进行分类。
Mahout 好像没有提供命令行工具来对新数据进行向量化。好在 上文 中,从训练数据生成向量(vector)的过程中生成了 dictionary
,df-count
,labelindex
等文件(这些文件在 hdfs 上),Mahout API 也提供了读取这些文件的相关方法,所以可以自己写代码对新文档进行分类。
构造 TF-IDF 向量
上文 中,训练模型的训练数据我用了 tfidf-vector
,测试的时候也是 tfidf-vector
,所以为了应用这个模型,需要把新文档表示成 tfidf-vector
。
Class TFIDF
提供了一个方法 calculate(int tf,int df,int length,int numDocs)
,用来计算一个词的 tf-idf
值,4 个参数分别代表:
tf
: 单词在新文档中出现的次数
df
: 训练集中包含单词的文档个数
length
: 新文档包含的所有单词个数
numDocs
: 训练集所有文档个数
Mahout 的源码里计算 tf-idf 值的函数,length
参数没有被用到:
1
2
3
4
5
6
7
8
| public class TFIDF implements Weight {
private final DefaultSimilarity sim = new DefaultSimilarity();
@Override
public double calculate(int tf, int df, int length, int numDocs) {
// ignore length
return sim.tf(tf) * sim.idf(df, numDocs);
}
}
|
tf
和 length
从新文档里可以统计出来,df
和 numDocs
需要从 df-count
文件里取到。df-count
包含所有词的 key/value
,其中 key
是 dictionary
文件里对应的 value
。df-count
文件里 key = 1
代表训练集文档个数。
dictionary
文件,key 是单词,Value 是对应的 ID(我没有去停用词和标点符号):
1
2
3
4
5
6
7
| Key: !: Value: 0
Key: ": Value: 1
Key: #: Value: 2
Key: $: Value: 3
Key: %: Value: 4
Key: ': Value: 5
....
|
df-count
文件,其实是 df-count
目录下的 part-r-00000
的文件,key 是 ID,value 是包含单词的文档个数:
1
2
3
4
5
6
7
| Key: -1: Value: 8361
Key: 0: Value: 154
Key: 1: Value: 94
Key: 2: Value: 6
Key: 3: Value: 3
Key: 4: Value: 10
....
|
有了这些信息就可以计算 tf-idf
,并且构造 vector 了。这里用到的 vector 是 Class Vector
,每一个元素是 id : tf-idf
。
1
2
| Vector vector = new RandomAccessSparseVector(10000);
vector.setQuick(wordId, tfIdfValue); //设定一个词的tf-idf值
|
分类
利用 Class StandardNaiveBayesClassifier
的 classifyFull()
函数进行分类。
1
2
3
4
5
6
| //从模型文件读取模型
NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
//用模型初始化分类器
StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
//返回 vector 在所有类别下的得分,得分最高的就是最后的分类
Vector resultVector = classifier.classifyFull(vector);
|
也可以把计算好的 tf-idf vector
输出到 sequence
类型的文件里,然后用命令行工具 mahout testnb
来看朴素贝叶斯分类器对新文档的分类效果。输出到 sequence
文件的代码,sequence
文件也是由 key/value
对组成。
1
2
3
4
5
6
7
| Writer writer = new SequenceFile.Writer(fs, configuration, new Path(outputFileName), Text.class, VectorWritable.class);
Text key = new Text();
VectorWritable value = new VectorWritable();
key.set("/" + label + "/" + inputFileName); // label 是预期的类别标签,inputFileName 作为向量的标识
value.set(vector); // value 是 输入文档的 `tf-idf` 向量
writer.append(key, value);
writer.close();
|
这样就把向量输出到一个 sequence
文件里了。
完整代码
这是我用来分类的完整代码,我是先把 model
,labelindex
,df-count
,dictionary
文件从 hdfs 上弄下来之后放到工程目录下使用的,也可以直接连接 hdfs 来读取这些文件:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
| import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;
import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;
public class Classifier {
public static Map<String, Integer> readDictionnary(Configuration conf, Path dictionnaryPath) {
Map<String, Integer> dictionnary = new HashMap<String, Integer>();
for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(dictionnaryPath, true, conf)) {
dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
}
return dictionnary;
}
public static Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyPath) {
Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();
for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(documentFrequencyPath, true, conf)) {
documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
}
return documentFrequency;
}
public static void main(String[] args) throws Exception {
//上述几个文件路径
String modelPath = "./mahout/model";
String labelIndexPath = "./mahout/labelindex";
String dictionaryPath = "./mahout/vectors/dictionary.file-0";
String documentFrequencyPath = "./mahout/vectors/df-count/part-r-00000";
Configuration configuration = new Configuration();
//hdfs 配置
//configuration.set("fs.default.name", "hdfs://172.21.1.129:9000");
//configuration.set("mapred.job.tracker", "172.21.1.129:9001");
//读取模型文件
NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
//初始化训练器
StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
//读取 labelindex、dictionary、df-count
Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration, new Path(documentFrequencyPath));
//文本分析的 analyzer,我之前是用 fudannlp 对文件进行了分词
//输入是以空格分割的文件
//所以用 WhitespaceAnalyzer,也可以换成其他 analyzer
//lucene 版本是 4.3.0
Analyzer analyzer = new WhitespaceAnalyzer(Version.LUCENE_43);
//读取训练集包含的文档个数
int documentCount = documentFrequency.get(-1).intValue();
//待分类文本
String content = "";
Multiset<String> words = ConcurrentHashMultiset.create();
TokenStream ts = analyzer.tokenStream("text", new StringReader(content));
CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
ts.reset();
int wordCount = 0;
//统计在 dictionary 里出现的待分类的新文档的词
while (ts.incrementToken()) {
if (termAtt.length() > 0) {
String word = ts.getAttribute(CharTermAttribute.class).toString();
Integer wordId = dictionary.get(word);
if (wordId != null) {
words.add(word);
wordCount++;
}
}
}
//计算 TF-IDF,并构造 Vector
Vector vector = new RandomAccessSparseVector(10000);
TFIDF tfidf = new TFIDF();
for (Multiset.Entry<String> entry : words.entrySet()) {
String word = entry.getElement();
int count = entry.getCount();
Integer wordId = dictionary.get(word);
Long freq = documentFrequency.get(wordId);
double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
vector.setQuick(wordId, tfIdfValue);
}
//分类
Vector resultVector = classifier.classifyFull(vector);
double bestScore = -Double.MAX_VALUE;
int bestCategoryId = -1;
for(Element element : resultVector.all()) {
int categoryId = element.index();
double score = element.get();
if (score > bestScore) {
bestScore = score;
bestCategoryId = categoryId;
}
System.out.print(" " + labels.get(categoryId) + ": " + score);
}
System.out.println(" => " + labels.get(bestCategoryId));
analyzer.close();
}
}
|
利用 Mahout 朴素贝叶斯分类大概就这样了。