Training an SGD Model on 20 Newsgroups

This article describes how to train a logistic regression model with the Mahout library in a Hadoop environment to classify the 20 Newsgroups dataset. It walks through environment preparation, data loading, feature extraction, model configuration and the training process, ending with an effective classifier for the newsgroup data.

Introduction:

The "SGD" classifier in Mahout is logistic regression trained with stochastic gradient descent: logistic regression is the model, and SGD is the online optimization method used to fit it.
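For reference, the model being fit here is multinomial logistic regression over the 20 newsgroup labels, and stochastic gradient descent updates the weights one training example at a time. A standard formulation (the notation below is mine, not taken from the Mahout source) is

p(y = c \mid x) = \frac{\exp(w_c \cdot x)}{\sum_{k=1}^{20} \exp(w_k \cdot x)}, \qquad w_c \leftarrow w_c + \eta \left( \mathbf{1}[y = c] - p(y = c \mid x) \right) x

where \eta is the learning rate (annealed via alpha, stepOffset and decayExponent in the code below) and the L1 prior adds a sparsity-inducing regularization term on top of this update.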


1. Environment: a Hadoop 2.2.0 cluster (or pseudo-distributed cluster) and Mahout 0.9. For the compatibility issues between Hadoop 2 and Mahout 0.9, see the separate document.


2. Download the 20 Newsgroups dataset and place it on the Hadoop master node, since that is the node where Mahout is configured.


3. The full code is as follows:

package mahout.SGD;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multiset;

public class TrainNewsGroups {
	private static final int FEATURES = 10000;
	private static Multiset<String> overallCounts;
	
	public static void main(String[] args) throws IOException {
		// The training data path is passed in as the first program argument; when
		// running locally from an IDE a path such as
		// "E:\\data\\20news-bydate\\20news-bydate-train" can be used instead.
		File base = new File(args[0]);
		overallCounts = HashMultiset.create();

		// Set up the feature vector encoders
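		// StaticWordValueEncoder hashes each word of the "body" field into the
		// feature vector; setProbes(2) writes every feature to two hashed positions
		// to limit the damage from hash collisions. The two ConstantValueEncoders
		// contribute an intercept feature and a line-count feature, and Dictionary
		// assigns each newsgroup name a stable integer label.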
		Map<String, Set<Integer>> traceDictionary = new TreeMap<String, Set<Integer>>();
		FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
		encoder.setProbes(2);
		encoder.setTraceDictionary(traceDictionary);
		FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
		bias.setTraceDictionary(traceDictionary);
		FeatureVectorEncoder lines = new ConstantValueEncoder("Lines");
		lines.setTraceDictionary(traceDictionary);
		Dictionary newsGroups = new Dictionary();
		
		// Configure the learning algorithm
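		// 20 output categories, a 10,000-dimensional hashed feature space and an
		// L1 prior for regularization; lambda sets the regularization strength,
		// learningRate the initial step size, and alpha/stepOffset/decayExponent
		// control how the step size is annealed as training progresses.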
		OnlineLogisticRegression learningAlgorithm = 
		    new OnlineLogisticRegression(
		          20, FEATURES, new L1())
		        .alpha(1).stepOffset(1000)
		        .decayExponent(0.9) 
		        .lambda(3.0e-5)
		        .learningRate(20);
		
		
		// Collect the training files
		List<File> files = new ArrayList<File>();
		for (File newsgroup : base.listFiles()) {
		  newsGroups.intern(newsgroup.getName());
		  files.addAll(Arrays.asList(newsgroup.listFiles()));
		}

		Collections.shuffle(files);
		System.out.printf("%d training files\n", files.size());
		
		// Bookkeeping used for the running averages and progress reporting
		double averageLL = 0.0;
		double averageCorrect = 0.0;
		double averageLineCount = 0.0;
		int k = 0;
		double step = 0.0;
		int[] bumps = new int[]{1, 2, 5};
		double lineCount = 0;
		
		
		// Read each training file and tokenize it
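		// Each message header is scanned line by line: only From:, Subject:,
		// Keywords: and Summary: lines contribute words, the Lines: header supplies
		// a line-count feature, and folded header lines (continuations starting with
		// whitespace) are treated as part of the header line above them. Everything
		// after the first blank line is the body and is tokenized in one pass.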
		Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_31);
		for (File file : files) {
			
			BufferedReader reader = new BufferedReader(new FileReader(file));
			String ng = file.getParentFile().getName();
			int actual = newsGroups.intern(ng);
			Multiset<String> words = ConcurrentHashMultiset.create();

			String line = reader.readLine();
			while (line != null && line.length() > 0) {
				if (line.startsWith("Lines:")) {
					// String count = Iterables.get(onColon.split(line), 1);
					String[] lineArr = line.split("Lines:"); // 获得line行数
					String count = lineArr[1];

					try {
						lineCount = Integer.parseInt(count);
						averageLineCount += (lineCount - averageLineCount) / Math.min(k + 1, 1000);
					} catch (NumberFormatException e) {
						lineCount = averageLineCount;
					}
				}
				
				boolean countHeader = (line.startsWith("From:")
						|| line.startsWith("Subject:")
						|| line.startsWith("Keywords:") || line.startsWith("Summary:"));
				do {
					StringReader in = new StringReader(line);
					if (countHeader) {
						countWords(analyzer, words, in);
					}
					
					line = reader.readLine();
				} while (line != null && line.startsWith(" "));
				
			}
			
			countWords(analyzer, words, reader);
			reader.close();
			
			
			// Vectorize the document
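			// Word counts are added with a log(1 + count) weight so frequent terms do
			// not dominate, alongside a constant intercept feature and a log-scaled
			// line-count feature.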
			Vector v = new RandomAccessSparseVector(FEATURES);
			bias.addToVector("", 1, v);
//			lines.addToVector("", lineCount / 30, v);
			lines.addToVector("", Math.log(lineCount + 1), v);
			for (String word : words.elementSet()) {
				encoder.addToVector(word, Math.log(1 + words.count(word)), v);
			}

			
			// Evaluate current progress
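			// The running log-likelihood and accuracy are moving averages over roughly
			// the last 200 examples, computed before the model is trained on the
			// current example, so they approximate held-out performance.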
			double mu = Math.min(k + 1, 200);
			double ll = learningAlgorithm.logLikelihood(actual, v);
			averageLL = averageLL + (ll - averageLL) / mu;

			Vector p = new DenseVector(20);
			learningAlgorithm.classifyFull(p, v);
			int estimated = p.maxValueIndex();

			int correct = (estimated == actual? 1 : 0);
			averageCorrect = averageCorrect + (correct - averageCorrect) / mu;
			
			
			
			// Train the SGD model on the encoded example
			learningAlgorithm.train(actual, v);

			k++;
			int bump = bumps[(int) Math.floor(step) % bumps.length];
			int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
			
			if (k % (bump * scale) == 0) {
				step += 0.25;
				System.out.printf("%10d %10.3f %10.3f %10.2f %s %s\n",
						k, ll, averageLL, averageCorrect * 100, ng, newsGroups.values().get(estimated));
			}
		}
		// Close the learner after all training examples have been processed so that
		// any pending regularization is applied
		learningAlgorithm.close();
//		System.out.println(overallCounts);
//		System.out.println(overallCounts.size());
		
		
		
	}

	private static void countWords(Analyzer analyzer, Collection<String> words,
			Reader in) throws IOException {
		TokenStream ts = analyzer.tokenStream("text", in);
		
		ts.addAttribute(CharTermAttribute.class);
		
		// The reset() call below is required before incrementToken(); for the fix see: http://ask.youkuaiyun.com/questions/57173
		ts.reset();
		while (ts.incrementToken()) {
			String s = ts.getAttribute(CharTermAttribute.class).toString();
			words.add(s);
		}
		ts.end();
		ts.close();
		overallCounts.addAll(words); 
	}
}


4. Package the project and run it on Hadoop. Note that the dependency jars must be placed in a newly created lib folder inside the Java project (so they are bundled into the jar); otherwise Hadoop cannot find them and throws a ClassNotFoundException.

[root@hadoop1 bin]# hadoop jar ../../jar/javaTex2.jar mahout.SGD.TrainNewsGroups /usr/local/mahout/data/20news-bydate/20news-bydate-train/
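The listing above keeps the trained model in memory and never writes it out, so here is a minimal sketch of how the same objects could be used to classify a new piece of text once training has finished. It is meant to be added to the TrainNewsGroups class above so it can reuse the existing imports and the countWords helper; the method name classify and its parameter list are illustrative, not part of the original post.

	// Sketch: classify one document with the trained model, reusing the encoders,
	// Dictionary and Analyzer that were built during training.
	private static String classify(OnlineLogisticRegression model,
			FeatureVectorEncoder encoder, FeatureVectorEncoder bias,
			FeatureVectorEncoder lines, Dictionary newsGroups,
			Analyzer analyzer, String text, double lineCount) throws IOException {
		// Tokenize and count words exactly as during training
		Multiset<String> words = HashMultiset.create();
		countWords(analyzer, words, new StringReader(text));

		// Encode the document into the same hashed feature space
		Vector v = new RandomAccessSparseVector(FEATURES);
		bias.addToVector("", 1, v);
		lines.addToVector("", Math.log(lineCount + 1), v);
		for (String word : words.elementSet()) {
			encoder.addToVector(word, Math.log(1 + words.count(word)), v);
		}

		// classifyFull fills a 20-element vector with per-class probabilities;
		// the index of the largest entry is the predicted newsgroup
		Vector p = new DenseVector(20);
		model.classifyFull(p, v);
		return newsGroups.values().get(p.maxValueIndex());
	}

Called right after the training loop, for example classify(learningAlgorithm, encoder, bias, lines, newsGroups, analyzer, someText, 20), it returns the name of the predicted newsgroup.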


