Learning Mahout's Bayes Algorithm (Part 4)

This article shows how to use Mahout to build a classification model that predicts, from weather conditions, whether it is suitable to play badminton: it walks through constructing the classification dataset, training the model, converting the test data into vectors, and finally classifying the test data to obtain a prediction.

Earlier I went through how Mahout's official 20news example is invoked, and I wanted to apply the same workflow to another example. Online I found one about predicting whether the weather is suitable for playing badminton.

Training data:
Day Outlook  Temperature Humidity Wind PlayTennis
D1 Sunny  Hot High Weak No
D2 Sunny  Hot High Strong No
D3 Overcast  Hot High Weak Yes
D4 Rain  Mild High Weak Yes
D5 Rain  Cool Normal Weak Yes
D6 Rain  Cool Normal Strong No
D7 Overcast  Cool Normal Strong Yes
D8 Sunny  Mild High Weak No
D9 Sunny  Cool Normal Weak Yes
D10 Rain  Mild Normal Weak Yes
D11 Sunny  Mild Normal Strong Yes
D12 Overcast  Mild High Strong Yes
D13 Overcast  Hot Normal Weak Yes
D14 Rain  Mild High Strong No


Test data:
sunny,hot,high,weak

Result:
Yes => 0.007039
No => 0.027418
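
Assuming these scores are the textbook unnormalized Naive Bayes products P(class) * Π P(feature | class), estimated by counting the 14 training rows above, they can be reproduced by hand (the small differences come from rounding the probabilities):

P(Yes) * P(sunny|Yes) * P(hot|Yes) * P(high|Yes) * P(weak|Yes)
    = (9/14) * (2/9) * (2/9) * (3/9) * (6/9) ≈ 0.00705
P(No) * P(sunny|No) * P(hot|No) * P(high|No) * P(weak|No)
    = (5/14) * (3/5) * (2/5) * (4/5) * (2/5) ≈ 0.02743

Since the No score is larger, the predicted class for sunny,hot,high,weak is No.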

So I implemented the classification in Java by calling Mahout's utility classes.

Basic idea:

1. Construct the classification data.

2. Use Mahout's utility classes to train and produce a model.

3. Convert the data to be classified into vectors.

4. Run the classifier on the vectors.

Here is my code implementation:

1. Construct the classification data:


On HDFS, create the directory /zhoujianfeng/playtennis/input and upload the data under its two category folders, no and yes.

Each data file holds one sample; for example, file D1 contains: Sunny Hot High Weak
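
This step can also be scripted. Below is a minimal sketch, not part of the original workflow: the class name PrepareTrainData is my own, and the NameNode address and core-site.xml location are copied from the classes further down. It writes one file per sample into the yes and no subdirectories of the input folder, which is the layout SequenceFilesFromDirectory expects.

package myTesting.bayes;

import java.io.OutputStream;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

public class PrepareTrainData {

    private static final String INPUT_DIR = "hdfs://192.168.9.72:9000/zhoujianfeng/playtennis/input";

    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));
        FileSystem fs = FileSystem.get(conf);

        // each entry is "id,features,label", taken from the training table above
        String[] samples = {
                "D1,Sunny Hot High Weak,no",          "D2,Sunny Hot High Strong,no",
                "D3,Overcast Hot High Weak,yes",      "D4,Rain Mild High Weak,yes",
                "D5,Rain Cool Normal Weak,yes",       "D6,Rain Cool Normal Strong,no",
                "D7,Overcast Cool Normal Strong,yes", "D8,Sunny Mild High Weak,no",
                "D9,Sunny Cool Normal Weak,yes",      "D10,Rain Mild Normal Weak,yes",
                "D11,Sunny Mild Normal Strong,yes",   "D12,Overcast Mild High Strong,yes",
                "D13,Overcast Hot Normal Weak,yes",   "D14,Rain Mild High Strong,no"
        };

        for (String sample : samples) {
            String[] parts = sample.split(",");
            // one file per sample, placed under its label directory: input/<label>/<id>
            Path file = new Path(INPUT_DIR + Path.SEPARATOR + parts[2] + Path.SEPARATOR + parts[0]);
            OutputStream os = fs.create(file, true);
            os.write(parts[1].getBytes("UTF-8"));
            os.close();
        }
        fs.close();
        System.out.println("training data written to " + INPUT_DIR);
    }
}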


2. Use Mahout's utility classes to train and produce a model.

3. Convert the data to be classified into vectors.

4. Run the classifier on the vectors.

The code for these three steps is posted all at once below; it consists mainly of two classes, PlayTennis1 and BayesCheckData:

package myTesting.bayes;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.text.SequenceFilesFromDirectory;
import org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles;

public class PlayTennis1 {

    private static final String WORK_DIR = "hdfs://192.168.9.72:9000/zhoujianfeng/playtennis";

    /*
     * Test driver
     */
    public static void main(String[] args) {

        // convert the training data into vectors
        makeTrainVector();
        // produce the training model
        makeModel(false);

        // classify the test data
        BayesCheckData.printResult();
    }

    public static void makeCheckVector(){
        // convert the test data into sequence files
        try {
            Configuration conf = new Configuration();
            conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

            String input = WORK_DIR+Path.SEPARATOR+"testinput";
            String output = WORK_DIR+Path.SEPARATOR+"tennis-test-seq";

            Path in = new Path(input);
            Path out = new Path(output);

            FileSystem fs = FileSystem.get(conf);

            if(fs.exists(in)){
                if(fs.exists(out)){
                    // the boolean argument means "delete recursively"
                    fs.delete(out, true);
                }
                SequenceFilesFromDirectory sffd = new SequenceFilesFromDirectory();
                String[] params = new String[]{"-i",input,"-o",output,"-ow"};
                ToolRunner.run(sffd, params);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("Failed to convert the files into sequence files!");
            System.exit(1);
        }

        // convert the sequence files into vector files
        try {
            Configuration conf = new Configuration();
            conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

            String input = WORK_DIR+Path.SEPARATOR+"tennis-test-seq";
            String output = WORK_DIR+Path.SEPARATOR+"tennis-test-vectors";

            Path in = new Path(input);
            Path out = new Path(output);

            FileSystem fs = FileSystem.get(conf);

            if(fs.exists(in)){
                if(fs.exists(out)){
                    // the boolean argument means "delete recursively"
                    fs.delete(out, true);
                }
                SparseVectorsFromSequenceFiles svfsf = new SparseVectorsFromSequenceFiles();
                String[] params = new String[]{"-i",input,"-o",output,"-lnorm","-nv","-wt","tfidf"};
                ToolRunner.run(svfsf, params);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("Failed to convert the sequence files into vectors!");
            System.exit(2);
        }
    }

    public static void makeTrainVector(){
        // convert the training data into sequence files
        try {
            Configuration conf = new Configuration();
            conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

            String input = WORK_DIR+Path.SEPARATOR+"input";
            String output = WORK_DIR+Path.SEPARATOR+"tennis-seq";

            Path in = new Path(input);
            Path out = new Path(output);

            FileSystem fs = FileSystem.get(conf);

            if(fs.exists(in)){
                if(fs.exists(out)){
                    // the boolean argument means "delete recursively"
                    fs.delete(out, true);
                }
                SequenceFilesFromDirectory sffd = new SequenceFilesFromDirectory();
                String[] params = new String[]{"-i",input,"-o",output,"-ow"};
                ToolRunner.run(sffd, params);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("Failed to convert the files into sequence files!");
            System.exit(1);
        }

        // convert the sequence files into vector files
        try {
            Configuration conf = new Configuration();
            conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

            String input = WORK_DIR+Path.SEPARATOR+"tennis-seq";
            String output = WORK_DIR+Path.SEPARATOR+"tennis-vectors";

            Path in = new Path(input);
            Path out = new Path(output);

            FileSystem fs = FileSystem.get(conf);

            if(fs.exists(in)){
                if(fs.exists(out)){
                    // the boolean argument means "delete recursively"
                    fs.delete(out, true);
                }
                SparseVectorsFromSequenceFiles svfsf = new SparseVectorsFromSequenceFiles();
                String[] params = new String[]{"-i",input,"-o",output,"-lnorm","-nv","-wt","tfidf"};
                ToolRunner.run(svfsf, params);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("Failed to convert the sequence files into vectors!");
            System.exit(2);
        }
    }

    public static void makeModel(boolean completelyNB){
        try {
            Configuration conf = new Configuration();
            conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

            String input = WORK_DIR+Path.SEPARATOR+"tennis-vectors"+Path.SEPARATOR+"tfidf-vectors";
            String model = WORK_DIR+Path.SEPARATOR+"model";
            String labelindex = WORK_DIR+Path.SEPARATOR+"labelindex";

            Path in = new Path(input);
            Path out = new Path(model);
            Path label = new Path(labelindex);

            FileSystem fs = FileSystem.get(conf);

            if(fs.exists(in)){
                if(fs.exists(out)){
                    // the boolean argument means "delete recursively"
                    fs.delete(out, true);
                }
                if(fs.exists(label)){
                    // the boolean argument means "delete recursively"
                    fs.delete(label, true);
                }
                TrainNaiveBayesJob tnbj = new TrainNaiveBayesJob();
                String[] params = null;
                if(completelyNB){
                    params = new String[]{"-i",input,"-el","-o",model,"-li",labelindex,"-ow","-c"};
                }else{
                    params = new String[]{"-i",input,"-el","-o",model,"-li",labelindex,"-ow"};
                }
                ToolRunner.run(tnbj, params);
            }
        } catch (Exception e) {
            e.printStackTrace();
            System.out.println("Failed to build the training model!");
            System.exit(3);
        }
    }


}




package myTesting.bayes;


import java.io.IOException;
import java.util.HashMap;
import java.util.Map;


import org.apache.commons.lang.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.PathFilter;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.vectorizer.TFIDF;


import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;


public class BayesCheckData {

    private static StandardNaiveBayesClassifier classifier;
    private static Map<String, Integer> dictionary;
    private static Map<Integer, Long> documentFrequency;
    private static Map<Integer, String> labelIndex;

    public void init(Configuration conf){
        try {
            String modelPath = "/zhoujianfeng/playtennis/model";
            String dictionaryPath = "/zhoujianfeng/playtennis/tennis-vectors/dictionary.file-0";
            String documentFrequencyPath = "/zhoujianfeng/playtennis/tennis-vectors/df-count";
            String labelIndexPath = "/zhoujianfeng/playtennis/labelindex";

            dictionary = readDictionnary(conf, new Path(dictionaryPath));
            documentFrequency = readDocumentFrequency(conf, new Path(documentFrequencyPath));
            labelIndex = BayesUtils.readLabelIndex(conf, new Path(labelIndexPath));

            NaiveBayesModel model = NaiveBayesModel.materialize(new Path(modelPath), conf);
            classifier = new StandardNaiveBayesClassifier(model);

        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("Error while initializing the vectorization of the test data...");
            System.exit(4);
        }
    }

    /**
     * Load the dictionary file. Key: term value; Value: term ID
     * @param conf
     * @param dictionnaryDir
     * @return
     */
    private static Map<String, Integer> readDictionnary(Configuration conf, Path dictionnaryDir) {
        Map<String, Integer> dictionnary = new HashMap<String, Integer>();

        PathFilter filter = new PathFilter() {
            @Override
            public boolean accept(Path path) {
                String name = path.getName();
                return name.startsWith("dictionary.file");
            }
        };
        for (Pair<Text, IntWritable> pair : new SequenceFileDirIterable<Text, IntWritable>(dictionnaryDir, PathType.LIST, filter, conf)) {
            dictionnary.put(pair.getFirst().toString(), pair.getSecond().get());
        }
        return dictionnary;
    }

    /**
     * Load the term document frequencies from the df-count directory. Key: term ID; Value: document frequency
     * @param conf
     * @param documentFrequencyDir
     * @return
     */
    private static Map<Integer, Long> readDocumentFrequency(Configuration conf, Path documentFrequencyDir) {
        Map<Integer, Long> documentFrequency = new HashMap<Integer, Long>();

        PathFilter filter = new PathFilter() {
            @Override
            public boolean accept(Path path) {
                return path.getName().startsWith("part-r");
            }
        };
        for (Pair<IntWritable, LongWritable> pair : new SequenceFileDirIterable<IntWritable, LongWritable>(documentFrequencyDir, PathType.LIST, filter, conf)) {
            documentFrequency.put(pair.getFirst().get(), pair.getSecond().get());
        }
        return documentFrequency;
    }


    public static String getCheckResult(){
        Configuration conf = new Configuration();
        conf.addResource(new Path("/usr/local/hadoop/conf/core-site.xml"));

        String classify = "NaN";
        BayesCheckData cdv = new BayesCheckData();
        cdv.init(conf);

        System.out.println("init done...............");

        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();

        // test sample: sunny,hot,high,weak
        Multiset<String> words = ConcurrentHashMultiset.create();
        words.add("sunny",1);
        words.add("hot",1);
        words.add("high",1);
        words.add("weak",1);

        int documentCount = documentFrequency.get(-1).intValue(); // key -1 stores the total number of documents

        for (Multiset.Entry<String> entry : words.entrySet()) {
            String word = entry.getElement();
            int count = entry.getCount();
            Integer wordId = dictionary.get(word); // the word ID comes from the dictionary.file-0 file (tf vectors)
            if (wordId == null){
                continue;
            }
            if (documentFrequency.get(wordId) == null){
                continue;
            }
            Long freq = documentFrequency.get(wordId);

            double tfIdfValue = tfidf.calculate(count, freq.intValue(), 1, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }

        // run the Bayes classifier and pick the label with the best score
        Vector resultVector = classifier.classifyFull(vector);
        double bestScore = -Double.MAX_VALUE;
        int bestCategoryId = -1;
        for(Element element: resultVector.all()) {
            int categoryId = element.index();
            double score = element.get();
            System.out.println("categoryId:"+categoryId+"  score:"+score);
            if (score > bestScore) {
                bestScore = score;
                bestCategoryId = categoryId;
            }
        }
        classify = labelIndex.get(bestCategoryId)+"(categoryId="+bestCategoryId+")";

        return classify;
    }

    public static void printResult(){
        System.out.println("The predicted category is: "+getCheckResult());
    }
}
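
The test sample is hard-coded inside getCheckResult(). If you want to classify other feature combinations, a helper along the following lines could be added to BayesCheckData. This is only a sketch under my own assumptions: the method name classifyLine and the 10000-dimension vector size are my choices, and init(conf) must have been called first so that the model, dictionary, and frequencies are loaded.

    // Sketch: classify any comma-separated feature line, e.g. classifyLine("rain,mild,high,strong").
    // Assumes init(conf) has run, so classifier, dictionary, documentFrequency and labelIndex are populated.
    public static String classifyLine(String line) {
        Vector vector = new RandomAccessSparseVector(10000);
        TFIDF tfidf = new TFIDF();
        int documentCount = documentFrequency.get(-1).intValue(); // key -1 stores the total document count

        for (String word : line.toLowerCase().split(",")) {
            Integer wordId = dictionary.get(word.trim());
            if (wordId == null || documentFrequency.get(wordId) == null) {
                continue; // skip terms never seen during training
            }
            double tfIdfValue = tfidf.calculate(1, documentFrequency.get(wordId).intValue(), 1, documentCount);
            vector.setQuick(wordId, tfIdfValue);
        }

        // classifyFull returns one score per label; pick the label with the largest score
        Vector resultVector = classifier.classifyFull(vector);
        int bestCategoryId = resultVector.maxValueIndex();
        return labelIndex.get(bestCategoryId);
    }

Calling classifyLine("sunny,hot,high,weak") after init(conf) should then return the same label that printResult() reports for the hard-coded sample.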
