1.简概
上一篇简单介绍了doc2vec的实现以及原理,这一篇看看doc2vec用于文本分类的情况。
2.数据格式
跟cnn、lstm输入格式一样
1 看头像 加微信
1 专业 办理 二手房交易 公积金贷款 商业贷款 租房 需要 请来 店咨询
1 奥森 健身 要 砸金蛋 捂脸 砸金蛋 不是 李咏 专利 活动 当天 人人 都是 李咏 奥森 健身 小季 特邀 您 6月 18号 来 咱们 店 砸金蛋 百分百 中奖 百分百 有礼 只有 你 想不到 没有 我们 做 不到 呲 牙 优惠力度 之 大 我们 老总 知道 害怕 偷偷 咱 告诉 他 捂脸 咨询电话 手机号码
1 建议 下次 麻花 另外 拿 袋子 装 起来 放 盒子 里 太软 想 只 吃 长肉 美女 加 座机号码
0 出 两张 情侣卡 还有 10个 月 3 k 出
0 专业 甲醛检测 甲醛 治理 清除 装修 异味 给 您 健康 呼吸 保驾护航
1 不错 儿子 喜欢吃 宝妈 们 想 照顾好 家的 同时 能 月入 上万 加我 其他数字
1 可以 啊 微信同号
3.代码实现
实现方式是实现LabelAwareIterator接口,看看具体的实现情况
package com.dianping.deeplearning.paragraphvectors;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import org.datavec.api.util.RandomUtils;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.documentiterator.LabelsSource;
import org.nd4j.linalg.collection.CompactHeapStringList;
/**
 * A {@link LabelAwareIterator} over a tab-separated text file of the form
 * {@code <label>\t<tokenized sentence>}, where the label is "0" or "1".
 * Sentences are optionally shuffled (when a {@link Random} is supplied)
 * and served as {@link LabelledDocument}s for ParagraphVectors training.
 */
public class TxtLabelAwareIterator implements LabelAwareIterator {
    private int totalCount;                         // number of accepted (label 0/1) sentences
    private Map<String, List<String>> filesByLabel; // label -> sentences with that label
    private List<String> normList;                  // sentences labelled "1"
    private List<String> negList;                   // sentences labelled "0"
    private final List<String> sentenslist;         // all sentences, in per-label order
    private final int[] labelIndexes;               // parallel to sentenslist: index into allLabels
    private final Random rng;                       // null -> sequential iteration, no shuffle
    private final int[] order;                      // shuffled permutation of [0, totalCount)
    private final List<String> allLabels;           // sorted distinct labels
    private LabelsSource source;
    private int cursor = 0;                         // position within the current epoch

    /** Reads {@code path} and iterates in a randomly shuffled order. */
    public TxtLabelAwareIterator(String path) {
        this(path, new Random());
    }

    /**
     * Reads {@code path}; {@code rng == null} means sequential order,
     * otherwise the sentence order is shuffled (and re-shuffled on reset()).
     */
    public TxtLabelAwareIterator(String path, Random rng) {
        totalCount = 0;
        filesByLabel = new HashMap<String, List<String>>();
        normList = new ArrayList<String>();
        negList = new ArrayList<>();
        // try-with-resources: the original version leaked the reader when an
        // exception was thrown before close(). UTF-8 is specified explicitly
        // because the corpus is Chinese text and the platform default charset
        // is not reliable.
        try (BufferedReader buffered = new BufferedReader(new InputStreamReader(
                new FileInputStream(path), StandardCharsets.UTF_8))) {
            String line;
            while ((line = buffered.readLine()) != null) {
                String[] parts = line.split("\t");
                if (parts.length < 2) {
                    // Skip malformed lines instead of aborting the whole load
                    // with an ArrayIndexOutOfBoundsException.
                    continue;
                }
                String label = parts[0];
                String content = parts[1];
                // BUG FIX: only count lines whose label is recognized. The
                // original incremented totalCount unconditionally, which left
                // stale zero entries in labelIndexes and produced indexes in
                // order[] pointing past the end of sentenslist whenever a line
                // had an unexpected label or was malformed.
                if ("1".equalsIgnoreCase(label)) {
                    normList.add(content);
                    totalCount++;
                } else if ("0".equalsIgnoreCase(label)) {
                    negList.add(content);
                    totalCount++;
                }
            }
        } catch (Exception e) {
            // Best-effort load, matching the original behavior: a read failure
            // yields an empty iterator rather than a thrown exception.
            e.printStackTrace();
        }
        filesByLabel.put("1", normList);
        filesByLabel.put("0", negList);
        this.rng = rng;
        if (rng == null) {
            order = null;
        } else {
            order = new int[totalCount];
            for (int i = 0; i < totalCount; i++) {
                order[i] = i;
            }
            RandomUtils.shuffleInPlace(order, rng);
        }
        // BUG FIX: sort BEFORE handing the list to LabelsSource, so the label
        // order exposed to consumers matches the sorted order used to build
        // labelIndexes below. The original sorted after construction.
        allLabels = new ArrayList<>(filesByLabel.keySet());
        Collections.sort(allLabels);
        source = new LabelsSource(allLabels);
        Map<String, Integer> labelsToIdx = new HashMap<>();
        for (int i = 0; i < allLabels.size(); i++) {
            labelsToIdx.put(allLabels.get(i), i);
        }
        // CompactHeapStringList stores the sentences off the regular object
        // heap layout to reduce GC pressure for large corpora.
        sentenslist = new CompactHeapStringList();
        labelIndexes = new int[totalCount];
        int position = 0;
        for (Map.Entry<String, List<String>> entry : filesByLabel.entrySet()) {
            int labelIdx = labelsToIdx.get(entry.getKey());
            for (String sentence : entry.getValue()) {
                sentenslist.add(sentence);
                labelIndexes[position] = labelIdx;
                position++;
            }
        }
    }

    @Override
    public boolean hasNext() {
        return cursor < totalCount;
    }

    @Override
    public LabelledDocument next() {
        return nextDocument();
    }

    @Override
    public boolean hasNextDocument() {
        // BUG FIX: the original returned hasNextDocument() — unconditional
        // self-recursion that threw StackOverflowError on the first call.
        return hasNext();
    }

    /** Returns the next sentence wrapped as a LabelledDocument. */
    @Override
    public LabelledDocument nextDocument() {
        LabelledDocument document = new LabelledDocument();
        // With no RNG we iterate sequentially; otherwise follow the shuffled
        // permutation so each epoch sees a different order (see reset()).
        int idx = (rng == null) ? cursor++ : order[cursor++];
        String label = allLabels.get(labelIndexes[idx]);
        String sentence = sentenslist.get(idx);
        document.setContent(sentence);
        document.addLabel(label);
        return document;
    }

    /** Rewinds to the beginning; re-shuffles when an RNG is present. */
    @Override
    public void reset() {
        cursor = 0;
        if (rng != null) {
            RandomUtils.shuffleInPlace(order, rng);
        }
    }

    @Override
    public LabelsSource getLabelsSource() {
        return source;
    }

    @Override
    public void shutdown() {
        // No resources held open after construction; nothing to release.
    }
}
分类器的实现:
package com.dianping.deeplearning.paragraphvectors;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.text.documentiterator.LabelAwareIterator;
import org.deeplearning4j.text.documentiterator.LabelledDocument;
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Trains a ParagraphVectors (doc2vec) model on the labelled corpus produced by
 * {@link TxtLabelAwareIterator} and classifies an unseen sentence by cosine
 * similarity between the sentence centroid and each label vector.
 */
public class Doc2VecAdxClassify {
    // Tab-separated corpus: <label>\t<whitespace-tokenized sentence>.
    private String path = "adx/rnnsenec.txt";
    ParagraphVectors paragraphVectors;
    LabelAwareIterator iterator;
    TokenizerFactory tokenizerFactory;

    public static void main(String[] args) {
        Doc2VecAdxClassify doc2vec = new Doc2VecAdxClassify();
        doc2vec.makeParagraphVectors();
        // 预测分类: built-in prediction returns the single best label.
        System.out.println(doc2vec.paragraphVectors
                .predict("专业 甲醛检测 甲醛 治理 清除 装修 异味 给 您 健康 呼吸"));
        // Manual scoring path: average the word vectors of the document into a
        // centroid, then score it against every label vector.
        MeansBuilder meansBuilder = new MeansBuilder(
                (InMemoryLookupTable<VocabWord>) doc2vec.paragraphVectors
                        .getLookupTable(),
                doc2vec.tokenizerFactory);
        LabelSeeker seeker = new LabelSeeker(doc2vec.iterator.getLabelsSource()
                .getLabels(),
                (InMemoryLookupTable<VocabWord>) doc2vec.paragraphVectors
                        .getLookupTable());
        LabelledDocument document = new LabelledDocument();
        document.setContent("专业 甲醛检测 甲醛 治理 清除 装修 异味 给 您 健康 呼吸");
        document.addLabel("0");
        // BUG FIX: the original called documentAsVector(document) twice and
        // discarded the first result, computing the centroid twice for nothing.
        INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
        List<Pair<String, Double>> scores = seeker
                .getScores(documentAsCentroid);
        // Each pair is (label, cosine similarity to that label's vector).
        for (Pair<String, Double> score : scores) {
            System.out.println(" " + score.getFirst() + ": " + score.getSecond());
        }
    }

    /** Builds the iterator/tokenizer and trains the doc2vec model. */
    public void makeParagraphVectors() {
        System.out.println("path is :" + path);
        iterator = new TxtLabelAwareIterator(path);
        System.out.println(iterator.getLabelsSource().getLabels());
        tokenizerFactory = new DefaultTokenizerFactory();
        tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
        paragraphVectors = new ParagraphVectors.Builder()
                .learningRate(0.025)
                .minLearningRate(0.001)
                .batchSize(1000)
                .epochs(20)
                .iterate(iterator)
                // Also train word vectors so MeansBuilder can average them
                // into a document centroid for manual scoring.
                .trainWordVectors(true)
                .tokenizerFactory(tokenizerFactory)
                .build();
        // Start model training
        paragraphVectors.fit();
    }
}
4.结果
1
0: -0.2978013753890991
1: 0.17002613842487335
最后输出的是输入文本与各标签之间的余弦相似度,分类还是较为准确的。