Mahout 中文分类 (2)

原文链接:http://zaumreit.me/blog/2013/12/15/mahout-chinese-classification/

===================

mahout 中文分类 里只写了如何部署 Mahout 以及如何训练模型以及测试,并没有写如何对新的数据进行分类。这篇文章讲如何对新的数据进行分类。

Mahout 好像没有提供命令行工具来对新数据进行向量化。好在 上文 中,从训练数据生成向量(vector)的过程中生成了 dictionarydf-countlabelindex 等文件(这些文件在 hdfs 上),Mahout API 也提供了读取这些文件的相关方法,所以可以自己写代码对新文档进行分类。

构造 TF-IDF 向量

上文 中,训练模型的训练数据我用了 tfidf-vector ,测试的时候也是 tfidf-vector,所以为了应用这个模型,需要把新文档表示成 tfidf-vector

Class TFIDF 提供了一个方法 calculate(int tf,int df,int length,int numDocs),用来计算一个词的 tf-idf 值,4 个参数分别代表:
tf: 单词在新文档中出现的次数 
df: 训练集中包含单词的文档个数 
length: 新文档包含的所有单词个数 
numDocs: 训练集所有文档个数

Mahout 的源码里计算 tf-idf 值的函数,length 参数没有被用到:

1
2
3
4
5
6
7
8
public class TFIDF implements Weight {
  private final DefaultSimilarity sim = new DefaultSimilarity();
  @Override
  public double calculate(int tf, int df, int length, int numDocs) {
    // ignore length    
    return sim.tf(tf) * sim.idf(df, numDocs);
  }
}

tf 和 length 从新文档里可以统计出来,df 和 numDocs 需要从 df-count 文件里取到。df-count 包含所有词的 key/value,其中 key 是 dictionary 文件里对应的 valuedf-count文件里 key = 1 代表训练集文档个数。

dictionary 文件,key 是单词,Value 是对应的 ID(我没有去停用词和标点符号):

1
2
3
4
5
6
7
Key: !: Value: 0
Key: ": Value: 1
Key: #: Value: 2
Key: $: Value: 3
Key: %: Value: 4
Key: ': Value: 5
....

df-count 文件,其实是 df-count 目录下的 part-r-00000 的文件,key 是 ID,value 是包含单词的文档个数:

1
2
3
4
5
6
7
Key: -1: Value: 8361
Key: 0: Value: 154
Key: 1: Value: 94
Key: 2: Value: 6
Key: 3: Value: 3
Key: 4: Value: 10
....

有了这些信息就可以计算 tf-idf,并且构造 vector 了。这里用到的 vector 是 Class Vector,每一个元素是 id : tf-idf

1
2
Vector vector = new RandomAccessSparseVector(10000);
vector.setQuick(wordId, tfIdfValue); //设定一个词的tf-idf值

分类

利用 Class StandardNaiveBayesClassifier 的 classifyFull() 函数进行分类。

1
2
3
4
5
6
//从模型文件读取模型
NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
//用模型初始化分类器
StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
//返回 vector 在所有类别下的得分,得分最高的就是最后的分类
Vector resultVector = classifier.classifyFull(vector);

也可以把计算好的 tf-idf vector 输出到 sequence 类型的文件里,然后用命令行工具 mahout testnb 来看朴素贝叶斯分类器对新文档的分类效果。输出到 sequence 文件的代码,sequence文件也是由 key/value 对组成。

1
2
3
4
5
6
7
Writer writer = new SequenceFile.Writer(fs, configuration, new Path(outputFileName), Text.class, VectorWritable.class);
Text key = new Text();
VectorWritable value = new VectorWritable();
key.set("/" + label + "/" + inputFileName); // label 是预期的类别标签,inputFileName 作为向量的标识
value.set(vector); // value 是 输入文档的 `tf-idf` 向量
writer.append(key, value);
writer.close();

这样就把向量输出到一个 sequence 文件里了。

完整代码

这是我用来分类的完整代码,我是先把 modellabelindexdf-countdictionary 文件从 hdfs 上弄下来之后放到工程目录下使用的,也可以直接连接 hdfs 来读取这些文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import java.io.StringReader;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.core.WhitespaceAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;

public class Classifier {

    public static Map<String, Integer> readDictionnary(Configuration conf, Path dictionnaryPath) {
        Map<String, Integer> dictionnary = new HashMap<String, Integer>();
        for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(dictionnaryPath, true, conf)) {
            dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
        }
        return dictionnary;
    }

    public static Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyPath) {
        Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();
        for (Pair<IntWritable, LongWritable> pair : new SequenceFileIterable<IntWritable, LongWritable>(documentFrequencyPath, true, conf)) {
            documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
        }
        return documentFrequency;
    }

    public static void main(String[] args) throws Exception {
        //上述几个文件路径
        String modelPath = "./mahout/model";
        String labelIndexPath = "./mahout/labelindex";
        String dictionaryPath = "./mahout/vectors/dictionary.file-0";
        String documentFrequencyPath = "./mahout/vectors/df-count/part-r-00000";

        Configuration configuration = new Configuration();
        //hdfs 配置
        //configuration.set("fs.default.name", "hdfs://172.21.1.129:9000");
        //configuration.set("mapred.job.tracker", "172.21.1.129:9001");

        //读取模型文件
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), configuration);
        //初始化训练器
        StandardNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);

        //读取 labelindex、dictionary、df-count
        Map<Integer, String> labels = BayesUtils.readLabelIndex(configuration, new Path(labelIndexPath));
        Map<String, Integer> dictionary = readDictionnary(configuration, new Path(dictionaryPath));
        Map<Integer, Long> documentFrequency = readDocumentFrequency(configuration, new Path(documentFrequencyPath));

        //文本分析的 analyzer,我之前是用 fudannlp 对文件进行了分词
        //输入是以空格分割的文件
        //所以用 WhitespaceAnalyzer,也可以换成其他 analyzer
        //lucene 版本是 4.3.0
        Analyzer analyzer = new WhitespaceAnalyzer(Version.LUCENE_43);
        //读取训练集包含的文档个数
        int documentCount = documentFrequency.get(-1).intValue();

        //待分类文本
        String content = "";

        Multiset<String> words = ConcurrentHashMultiset.create();

        TokenStream ts = analyzer.tokenStream("text", new StringReader(content));
        CharTermAttribute termAtt = ts.addAttribute(CharTermAttribute.class);
        ts.reset();
        int wordCount = 0;
        //统计在 dictionary 里出现的待分类的新文档的词
        while (ts.incrementToken()) {
            if (termAtt.length() > 0) {
                String word = ts.getAttribute(CharTermAttribute.class).toString();
                Integer wordId = dictionary.get(word);

                if (wordId != null) {
                    words.add(word);
                    wordCount++;
                }
            }
        }
        //计算 TF-IDF,并构造 Vector 
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word);
            Long freq = documentFrequency.get(wordId);
            double tfIdfValue = tfidf.calculate(count, freq.intValue(), wordCount, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }
        //分类
        Vector resultVector = classifier.classifyFull(vector);
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for(Element element : resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
            System.out.print("  " + labels.get(categoryId) + ": " + score);
        }
        System.out.println(" => " + labels.get(bestCategoryId));
        analyzer.close();
    }
}

利用 Mahout 朴素贝叶斯分类大概就这样了。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值