Implementing prediction for mahout 0.9 bayes (mahout only ships trainnb and testnb)

1. Overview

mahout 0.9 provides only a training command (trainnb) and a testing command (testnb) for the naive bayes model: you can build a model and evaluate how well it performs, but there is no function for predicting the class of new data. After reading through the mahout source code, I wrote a prediction function for the mahout bayes model myself. For how to use naive bayes in mahout 0.9, see http://blog.youkuaiyun.com/mach_learn/article/details/39667713

2. Why mahout does not support predict

mahout 0.9 serializes and vectorizes the training set and the test set together, and only afterwards splits the vectorized files into a training set and a test set. Vectorization produces the following files:
Here, dictionary.file-0 maps each term to an integer id: the key is the word (or punctuation token, etc.) and the value is its integer id. In frequency.file-0 the key is a term id and the value is the number of files in which that term appears. The df-count directory holds the document-frequency data. tf-vectors holds the term frequencies of each file, and tfidf-vectors holds, for each file, the term ids and their corresponding tf-idf values. tokenized-documents contains the tokenized files, and wordcount holds each term's frequency over the whole document collection.
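For reference, once these SequenceFiles are dumped to text with mahout seqdumper (see section 3), each record has the form below. The values here are made up for illustration, but the parsing code later in this post relies on exactly this "Key: ...: Value: ..." layout:

Key: mahout: Value: 1024      (dictionary.file-0: term -> integer id)
Key: 1024: Value: 57          (df-count: term id -> number of documents containing the term)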

After vectorization, mahout splits the files in tfidf-vectors into a training set and a test set, usually in an 80/20 ratio. trainnb is then run on the training set to produce the naiveBayesModel.bin model, and testnb is used afterwards to evaluate that model on the test set.
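For context, a typical end-to-end run of this pipeline looks roughly like the following, modeled on the classify-20newsgroups example that ships with mahout 0.9 (all paths are illustrative):

mahout seq2sparse -i 20news-seq -o 20news-vectors -lnorm -nv -wt tfidf
mahout split -i 20news-vectors/tfidf-vectors --trainingOutput train-vectors --testOutput test-vectors --randomSelectionPct 20 --overwrite --sequenceFiles -xm sequential
mahout trainnb -i train-vectors -el -o model -li labelindex -ow -c
mahout testnb -i test-vectors -m model -l labelindex -ow -o testresult -c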

Because mahout vectorizes everything in a single pass, there is one shared dictionary file. As a consequence, data that is vectorized separately by another seq2sparse run cannot be scored with a naiveBayesModel.bin trained on other data, because the dictionaries of the two vectorizations differ.

3. Approach to writing the mahout prediction function

To use naiveBayesModel.bin for prediction, the data to be predicted must be processed according to the same vectorization standard as the model (that is, it must be aligned with the dictionary and the other files produced when the training vectors were generated). First, map each term of the input data to its entry in the dictionary; then use the term's id to look up its df-count entry; then compute the term's tf-idf value (computing tf-idf needs only df-count, numDocs, and the term frequencies of the input data, where numDocs is the value stored in df-count under the key -1). Feeding the resulting tf-idf vector into naiveBayesModel.bin yields a likelihood for each class, and the class with the maximum value is the predicted class.
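Before diving into the full program, here is a minimal sketch of the tf-idf step in isolation. The numbers are made up; as far as I can tell, mahout's TFIDF delegates to Lucene's DefaultSimilarity, i.e. sqrt(tf) * (log(numDocs / (df + 1)) + 1), and ignores the length argument:

import org.apache.mahout.vectorizer.TFIDF;

public class TfIdfDemo
{
	public static void main(String[] args)
	{
		int tf = 3;         // the term occurs 3 times in the document to predict
		int df = 40;        // df-count value looked up via the term's dictionary id
		int numDocs = 2000; // df-count value stored under the key -1
		double weight = new TFIDF().calculate(tf, df, 0, numDocs);
		System.out.println(weight); // sqrt(3) * (ln(2000/41) + 1) ≈ 8.46
	}
}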

The programming environment is mahout 0.9; the required jars are the mahout core/math and hadoop libraries referenced by the imports in the code below.
The program needs the model trained by mahout and the files produced by seq2sparse vectorization. Since the seq2sparse outputs are SequenceFiles, they must first be converted to plain text with mahout seqdumper -i inputfile -o outputfile.
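For example, assuming the seq2sparse output directory is 20news-vectors and the text dumps are collected in a local model/ directory (these paths, and the part-file names under df-count and wordcount, are illustrative and may differ in your setup):

mahout seqdumper -i 20news-vectors/dictionary.file-0 -o model/dictionary.txt
mahout seqdumper -i 20news-vectors/df-count/part-r-00000 -o model/df-count.txt
mahout seqdumper -i 20news-vectors/wordcount/part-r-00000 -o model/wordcount.txt
mahout seqdumper -i labelindex -o model/labelindex.txt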
After conversion, the file layout expected by the code (reconstructed from the paths it uses) is:

model/
├── dictionary.txt   (dump of dictionary.file-0)
├── df-count.txt     (dump of df-count)
├── wordcount.txt    (dump of wordcount)
├── labelindex.txt   (dump of labelindex)
├── model/           (naiveBayesModel.bin produced by trainnb)
└── test/            (documents to predict, one tokenized document per file)
The code is as follows:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.vectorizer.TFIDF;

public class BayesPredict extends AbstractJob
{
	// Lookup tables read from the seqdumper text dumps of the seq2sparse output.
	public static HashMap<String, String> dictionaryHashMap = new HashMap<>();  // term -> integer id
	public static HashMap<String, String> dfcountHashMap = new HashMap<>();     // term id -> document frequency; key "-1" holds numDocs
	public static HashMap<String, String> wordcountHashMap = new HashMap<>();   // term -> corpus-wide term count (loaded but unused in scoring)
	public static HashMap<String, String> labelindexHashMap = new HashMap<>();  // label name -> label index (indexes 0/1 are hard-coded in predict)
	public BayesPredict()
	{
		readDfCount("model/df-count.txt");
		readDictionary("model/dictionary.txt");
		readLabelIndex("model/labelindex.txt");
		readWordCount("model/wordcount.txt");
	}
	// Reads the (single-line) tokenized document to classify and splits it into words.
	// Returns null if the file is empty or cannot be read.
	public static String[] readFile(String filename)
	{
		File file = new File(filename);
		BufferedReader reader;
		String tempstring = null;
		try 
		{
			reader = new BufferedReader(new FileReader(file));
			tempstring = reader.readLine();
			reader.close();
			if(tempstring==null)
				return null;
		} 
		catch (IOException e) 
		{
			e.printStackTrace();
			return null; // avoid a NullPointerException on tempstring below
		}
		String[] mess = tempstring.trim().split(" ");
		return mess;
	}
	// Each text dump produced by "mahout seqdumper" consists of lines of the
	// form "Key: <key>: Value: <value>"; parse those lines into the given map.
	private static void readKeyValueDump(String fileName, HashMap<String, String> map)
	{
		File file = new File(fileName);
		BufferedReader reader;
		String tempstring = null;
		try 
		{
			reader = new BufferedReader(new FileReader(file));
			while((tempstring = reader.readLine())!=null)
			{
				if(tempstring.startsWith("Key:"))
				{
					// The key sits between "Key:" and ": Value:"; the value follows the last colon.
					String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
					String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
					map.put(key.trim(), value.trim());
				}
			}
			reader.close();
		} 
		catch (IOException e) 
		{
			e.printStackTrace();
		}
	}
	public static void readDictionary(String fileName)
	{
		readKeyValueDump(fileName, dictionaryHashMap);
	}
	public static void readDfCount(String fileName)
	{
		readKeyValueDump(fileName, dfcountHashMap);
	}
	public static void readWordCount(String fileName)
	{
		readKeyValueDump(fileName, wordcountHashMap);
	}
	public static void readLabelIndex(String fileName)
	{
		readKeyValueDump(fileName, labelindexHashMap);
	}
	// Builds the tf-idf vector (term id -> weight) for one document, using the
	// dictionary and df-count of the training vectorization so that the vector
	// is compatible with the trained model.
	public static HashMap<Integer, Double> calcTfIdf(String filename)
	{
		String[] words = readFile(filename);
		if(words==null)
			return null;
		HashMap<Integer, Double> tfidfHashMap = new HashMap<Integer, Double>(); 
		HashMap<String, Integer> wordHashMap = new HashMap<String, Integer>();
		// Term frequency of each word within this document.
		for(int k=0; k<words.length; k++)
		{
			if(wordHashMap.get(words[k])==null)
			{
				wordHashMap.put(words[k], 1);
			}
			else
			{
				wordHashMap.put(words[k], wordHashMap.get(words[k])+1);
			}
		}
		// numDocs is stored in df-count under the special key -1.
		int numDocs = Integer.parseInt(dfcountHashMap.get("-1"));
		TFIDF tfidf = new TFIDF();
		Iterator<Map.Entry<String, Integer>> iterator = wordHashMap.entrySet().iterator();
		while(iterator.hasNext())
		{
			Map.Entry<String, Integer> entry = iterator.next();
			String key = entry.getKey();
			int tf = entry.getValue();
			// Terms that never occurred in the training corpus have no id and are skipped.
			if(dictionaryHashMap.get(key)!=null)
			{
				String idString = dictionaryHashMap.get(key);
				int df = Integer.parseInt(dfcountHashMap.get(idString));
				double tfidf_value = tfidf.calculate(tf, df, 0, numDocs);
				tfidfHashMap.put(Integer.parseInt(idString), tfidf_value);
			}
		}
		return tfidfHashMap;
	}
	// Scores one document against both labels and returns the name of the
	// higher-scoring one. The two label names are specific to this model.
	public String predict(String filename) throws IOException
	{
		HashMap<Integer, Double> tfidfHashMap = calcTfIdf(filename);
		if(tfidfHashMap==null)
			return "file is empty, unknown class";
		// Note: for batch prediction the model could be materialized once and
		// reused instead of being reloaded on every call. getConf() is
		// inherited from AbstractJob and falls back to a default Configuration.
		NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model/model/"), getConf());
		ComplementaryNaiveBayesClassifier classifier = new ComplementaryNaiveBayesClassifier(model);

		double label_1 = 0;
		double label_2 = 0;
		// Likelihood of each label: sum over terms of tfidf * per-feature score.
		Iterator<Map.Entry<Integer, Double>> iterator = tfidfHashMap.entrySet().iterator();
		while(iterator.hasNext())
		{
			Map.Entry<Integer, Double> entry = iterator.next();
			int key = entry.getKey();
			double value = entry.getValue();
			label_1 += value*classifier.getScoreForLabelFeature(0, key);
			label_2 += value*classifier.getScoreForLabelFeature(1, key);
		}
		if(label_1>label_2)
			return "fraud-female";
		else
			return "norm-female";
	}
	@Override
	public int run(String[] arg0) throws Exception {
		// Required by AbstractJob; not used in this standalone predictor.
		return 0;
	}
	public static void main(String[] args) 
	{
		long startTime = System.currentTimeMillis();
		BayesPredict bPredict = new BayesPredict();
		try {
			// Classify every file under model/test/ and tally the predictions.
			File file = new File("model/test/");
			String[] filenames = file.list();
			int count1 = 0;
			int count2 = 0;
			int count = 0;
			for(int i=0;i<filenames.length;i++)
			{
				String result = bPredict.predict("model/test/"+filenames[i]);
				count++;
				if(result.equals("fraud-female"))
					count1++;
				else if(result.equals("norm-female"))
					count2++;
				System.out.println(filenames[i]+":"+result);
			}
			System.out.println("count:"+count);
			System.out.println("count1:"+count1);
			System.out.println("count2:"+count2);
			System.out.println("time:"+(System.currentTimeMillis()-startTime)/1000.0);
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
}
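One way to compile and run the predictor, as a sketch: it assumes the self-contained mahout-core-0.9-job.jar from the mahout distribution (add the hadoop client jars to the classpath as well if your distribution does not bundle them) and the model/ directory from above in the working directory:

javac -cp mahout-core-0.9-job.jar BayesPredict.java
java -cp mahout-core-0.9-job.jar:. BayesPredict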




