Continuing the seq2sparse series from the previous post, this article analyzes the source code of TFPartialVectorReducer.
Open the class and look at the logic flow first. TFPartialVectorReducer has two methods, setup and reduce. The setup method reads the basic job parameters and then loads one fairly important variable from a file, as the following code shows:
// key is word value is id
for (Pair<Writable, IntWritable> record
: new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
dictionary.put(record.getFirst().toString(), record.getSecond().get());
}
This reads the word-to-id mapping into a local dictionary variable, which is defined as follows:
private final OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<String>();
OpenObjectIntHashMap is another class written by the Mahout developers: a hash map from objects to primitive int values. Because it stores the int ids directly instead of boxed Integer objects, it is faster and uses less memory than java.util.HashMap<String, Integer> for this kind of lookup-heavy use.
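A minimal sketch of how this map is used, relying only on the three methods the reducer itself calls (put, containsKey, get); the two ids are taken from the trace shown near the end of this post:
import org.apache.mahout.math.map.OpenObjectIntHashMap;

public class DictionaryDemo {
  public static void main(String[] args) {
    OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<String>();
    dictionary.put("from", 39560);       // word -> int id, no Integer boxing
    dictionary.put("mathew", 56411);
    if (dictionary.containsKey("from")) {
      int id = dictionary.get("from");   // returns the primitive int directly
      System.out.println("from -> " + id);
    }
  }
}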
Now the main part, the reduce method. On entry, reduce first checks whether there is any value at all for the key; if there is none it returns immediately, otherwise it continues. It then takes that value, which is of type StringTuple (the token list of one document), and creates a vector that will hold the term counts.
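In outline (a paraphrase of that part of the method, not a verbatim quote), the beginning of reduce looks roughly like this:
Iterator<StringTuple> it = values.iterator();
if (!it.hasNext()) {
  return;                    // no value for this key, nothing to emit
}
StringTuple value = it.next();
// dimension = size of the dictionary, value.length() = number of tokens in the document
Vector vector = new RandomAccessSparseVector(dimension, value.length());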
Next comes an if statement on maxNGramSize, which sets the largest size of n-gram to generate; its default value is 1, meaning only single words (unigrams) are counted. The parameter help describes it as follows, and a small illustrative sketch comes right after the description:
--maxNGramSize (-ng) ngramSize    (Optional) The maximum size of ngrams to
                                  create (2 = bigrams, 3 = trigrams, etc.)
                                  Default Value: 1
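To make the option concrete, here is a tiny, purely illustrative sketch of what "2 = bigrams" means for a token list; it is a hypothetical example, not the way Mahout itself builds n-grams:
import java.util.Arrays;
import java.util.List;

public class BigramSketch {
  public static void main(String[] args) {
    // the first few tokens of the sample document shown near the end of this post
    List<String> tokens = Arrays.asList("from", "mathew", "mathew", "mantis.co.uk", "subject");
    for (int i = 0; i + 1 < tokens.size(); i++) {
      System.out.println(tokens.get(i) + " " + tokens.get(i + 1)); // adjacent pairs = bigrams
    }
  }
}
With the default value of 1 only unigrams are produced, and the branch that actually runs is the loop shown next: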
for (String term : value.getEntries()) {
if (!term.isEmpty() && dictionary.containsKey(term)) { // unigram
int termId = dictionary.get(term);
vector.setQuick(termId, vector.getQuick(termId) + 1);
}
}
value.getEntries() is the list of all tokens in one document. The loop iterates over these tokens; for every token that is not empty and is present in dictionary (the word-to-id map), the token's id is looked up and used as the index into vector, and the value stored at that index is incremented by 1. So the id determines the position and the stored value is the term frequency: if the same word has already been seen, its entry simply goes from 1 to 2, and so on. The rest of the method just applies the output options given in the job parameters before writing the result out.
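For completeness, the tail of reduce roughly does the following with the finished vector, depending on the sequentialAccess and namedVector job parameters (again a paraphrase, so treat the exact lines as an approximation):
if (sequentialAccess) {
  vector = new SequentialAccessSparseVector(vector);  // better for sequential iteration downstream
}
if (namedVector) {
  vector = new NamedVector(vector, key.toString());   // attach the document name to the vector
}
VectorWritable vectorWritable = new VectorWritable(vector);
context.write(key, vectorWritable);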
The following imitation code can be written to test this logic:
package mahout.fansy.test.bayes;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
/**
 * A standalone imitation of TFPartialVectorReducer's reduce logic
 */
public class TFParitialVectorReducerFollow {
private static Configuration conf=new Configuration();
private static int dimension=93563; // total number of words in the dictionary
static{
conf.set("mapred.job.tracker", "ubuntu:9001");
}
public static void main(String[] args) {
reduce(getValue(),getDictionary());
}
public static void reduce(StringTuple value,OpenObjectIntHashMap<String> dictionary){
Vector vector = new RandomAccessSparseVector(dimension, value.length());
for (String term : value.getEntries()) {
if (!term.isEmpty() && dictionary.containsKey(term)) { // unigram
int termId = dictionary.get(term);
vector.setQuick(termId, vector.getQuick(termId) + 1);
}
}
Text key=new Text("49960"); // the key, i.e. the document (file) name
vector = new NamedVector(vector, key.toString());
VectorWritable vectorWritable = new VectorWritable(vector);
System.out.println(key+" : "+vectorWritable);
}
/**
 * Get the value by reading the HDFS file
 * /home/mahout/mahout-work-mahout/20news-vectors/tokenized-documents/part-m-00000;
 * for this test only the first value is taken.
 * Supplies the input that reduce would receive.
 */
public static StringTuple getValue(){
StringTuple st=new StringTuple();
String path="hdfs://ubuntu:9000/home/mahout/mahout-work-mahout/20news-vectors/tokenized-documents/part-m-00000";
Path hdfsPath=new Path(path);
for (Writable value : new SequenceFileDirValueIterable<Writable>(hdfsPath, PathType.LIST,
PathFilters.partFilter(), conf)) {
Class<? extends Writable> valueClass = value.getClass();
if (valueClass.equals(StringTuple.class)) {
st = (StringTuple) value;
break;
} else {
throw new IllegalStateException("Bad value class: " + valueClass);
}
}
return st;
}
/**
 * Build the dictionary (word-to-id mapping) from the HDFS file
 * /home/mahout/mahout-work-mahout/20news-vectors/dictionary.file-0.
 * Imitates what the setup function does.
 */
public static OpenObjectIntHashMap<String> getDictionary(){
OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<String>();
String path="hdfs://ubuntu:9000/home/mahout/mahout-work-mahout/20news-vectors/dictionary.file-0";
Path dictionaryFile = new Path(path);
// key is word value is id
for (Pair<Writable, IntWritable> record
: new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
dictionary.put(record.getFirst().toString(), record.getSecond().get());
}
return dictionary;
}
}
You can set breakpoints inside getValue() and getDictionary() (for example at their return statements) to inspect the values that were read. For instance, the values I saw were:
st (the value passed to reduce; only the first few entries are shown):
[from, mathew, mathew, mantis.co.uk, subject, alt.atheism, faq, atheist, resources, summary, books, addresses, music, anything, related, atheism, keywords
dictionary:
[sophomore->78643, moon's->59206, flamewar->38539, indiscriminately->47036,
Now trace the for loop inside reduce. The first token is from, whose id in dictionary is 39560, so after one iteration the vector is {39560:1.0}; after two iterations it is {56411:1.0,39560:1.0}; and after three iterations it is {56411:2.0,39560:1.0}. The second and third tokens are the same word, so they map to the same id, which produces a single entry in the vector whose value is 2.
Sharing, happiness, growth.
Please credit the source when reposting: http://blog.youkuaiyun.com/fansy1990