Simple Web Page Classification in Java
This post documents my attempt at implementing a simple web page classifier in Java. The third-party libraries involved are HanLP for word segmentation, joinery for data manipulation, LibSVM for SVM classification, Jsoup for handling page structure, and Tomcat.
If anything in this article is wrong, or the approach I use is itself flawed, please do point it out!
This pile of spaghetti code really doesn't bear a second look; one more glance and it might blow up.
Feature Selection for Web Page Classification
To classify a web page, we first have to decide which of its features to classify on. The most common choice is the page's text; by now there are also methods that combine the text with the page's structure, render the page as an image and classify on that, and other ways of picking different page features.
(One thing to settle first: web page classification is not the same as website classification. How exactly they differ depends on how you define a website versus a web page; the papers I read describe a website as being composed of multiple web pages. Here I distinguish a page only by its content, so page structure, hyperlinks, and other page elements are not considered.)
This post only uses the page's text content as the feature.
Module Design
Generally, building a classifier takes three steps: acquiring the data, processing the data, and training on it.
Data Acquisition
The data comes from crawling the industry rankings in the website ranking section of Chinaz (站长之家):
As can be seen there, Chinaz's industry ranking covers 12 industries: shopping (购物网站), transportation and travel (交通旅游), education and culture (教育文化), local services (生活服务), sports and fitness (体育健身), internet and technology (网络科技), news media (新闻媒体), industry and enterprise (行业企业), leisure and entertainment (休闲娱乐), healthcare (医疗健康), government organizations (政府组织), and miscellaneous (综合其他).
For each industry ranking I crawl the content of the listed web pages, use the industry name as the label, and save everything to a CSV file, giving a Chinese dataset.
First, crawl the URLs that appear under each industry's ranking:
package com.example.webproject;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import org.jsoup.nodes.Element;
import org.jsoup.select.Elements;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import static java.lang.Thread.sleep;
/**
* Crawls the URLs listed in Chinaz's industry website rankings.
*/
public class WebCrawler {
private List<String> urlLables;
public static void main(String[] args) throws IOException, InterruptedException {
int pageNum = 0;
HashMap<String, String> firstUrlMap = FirstSpider("https://top.chinaz.com/hangye/");
FileWriter urlLableWriter = new FileWriter("urlLableSet.txt");
// FirstSpider collects the first-level (industry category) URLs
for (String firstUrlLable : firstUrlMap.keySet()) {
HashMap<String, String> urls = SecondSpider(firstUrlMap.get(firstUrlLable));
// SecondSpider collects the second-level (sub-category) URLs
for (String secondUrlLable : urls.keySet()) {
sleep(8000);
// the URL of this second-level ranking page
String secondUrl = urls.get(secondUrlLable);
// create the CSV file for this category
FileWriter csvWriter = new FileWriter( firstUrlLable + "." + secondUrlLable + ".csv");
urlLableWriter.append(firstUrlLable + "." + secondUrlLable + "\n");
pageNum = 2;
System.out.println("开始保存" + secondUrlLable + ".csv" + "\n----------------------------");
// ThirdSpider extracts the listed site URLs from the first page
for (String str : ThirdSpider(secondUrl)) {
csvWriter.append(str + "\n");
}
while (true) {
// advance to the next page of this second-level ranking
String nextUrl = secondUrl.replace(".html", "_" + pageNum++ + ".html");
// fetch the HTML of that page
Document doc = Jsoup.connect(nextUrl).get();
// look for the "no results" element
Elements targetElement = doc.select("div.emptyCss");
// if div.emptyCss ("抱歉,未找到相关网站!") is absent, keep crawling; if it is present we are past the last page, so stop
if (targetElement.isEmpty()) {
// ThirdSpider extracts the listed site URLs from this page
for (String str : ThirdSpider(nextUrl)) {
csvWriter.append(str + "\n");
}
} else {
urlLableWriter.flush();
csvWriter.flush();
csvWriter.close();
System.out.println("数据已保存到" + secondUrlLable + ".csv 文件中");
break;
}
}
}
}
urlLableWriter.flush();
urlLableWriter.close();
return;
}
/**
* First-level crawler: extract the industry category links from the rankings home page.
*
* @return a map from category label to category URL
*/
private static HashMap<String, String> FirstSpider(String url) {
HashMap<String, String> urls = null;
try {
// 1. define the CSS selector for the category links
String selector = "div.TopListCent-Head a[href]";
// 2. connect with Jsoup and parse the page
Document doc = Jsoup.connect(url).get();
// 3. select the matching elements
Elements targetElement = doc.select(selector);
urls = new HashMap<>();
for (Element link : targetElement) {
String href = link.absUrl("href");
String lable = link.text();
// 4. store the label and its absolute URL
urls.put(lable, href);
}
// 5. (debug) print the extracted content
// System.out.println(targetContent);
} catch (IOException e) {
System.out.println("Error occurred while crawling: " + e.getMessage());
}
return urls;
}
/**
* Second-level crawler: extract the ranking page URL of each sub-category.
*
* @return a map from sub-category label to ranking page URL
*/
private static HashMap<String, String> SecondSpider(String url) {
HashMap<String, String> urls = null;
try {
// define the CSS selector for the sub-category links
String selector = "div.TopListCent-Head a[href]";
// connect with Jsoup and parse the page
Document doc = Jsoup.connect(url).get();
// select the matching elements
Elements targetElement = doc.select(selector);
// extract the element content
// String targetContent = targetElement.text();
urls = new HashMap<>();
for (Element link : targetElement) {
String href = link.absUrl("href");
String lable = link.text();
// store the label and its absolute URL
urls.put(lable, href);
// System.out.println(href);
}
} catch (IOException e) {
System.out.println("Error occurred while crawling: " + e.getMessage());
}
return urls;
}
/**
* Third-level crawler: extract the ranked site URLs shown on a ranking page.
*
* @return the list of site URLs displayed on the page
*/
private static List<String> ThirdSpider(String url) {
List<String> urls = null;
try {
// define the CSS selector for the URL text of each ranked site
String firstSelector = "div.TopListCent-listWrap span.col-gray";
// connect with Jsoup and parse the page
Document doc = Jsoup.connect(url).get();
// select the matching elements
Elements targetElement = doc.select(firstSelector);
// extract the element text (not used below)
String targetContent = targetElement.text();
urls = new ArrayList<>();
for (Element link : targetElement) {
String href = link.text();
// System.out.println(href);
// collect the displayed URL text
urls.add(href);
}
} catch (IOException e) {
System.out.println("Error occurred while crawling: " + e.getMessage());
}
return urls;
}
}
This yields 12 CSV files, each containing many URLs. These are my targets; now I need to crawl the text content behind each of them:
package com.example.webproject;
import java.io.*;
import java.net.InetSocketAddress;
import java.net.Proxy;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import static java.lang.Thread.sleep;
/**
* Takes the Chinaz ranking URLs crawled by WebCrawler.java,
* fetches each page and saves its text to local CSV files.
*/
public class Spider {
public void run() throws Exception {
// test();
// List url = getUrl("http://17k.com/", "filename");
List<String> lableNames = getLableName(); // read the label names
// map each label to the list of URLs stored in its CSV file
HashMap<String, List<String>> tagUrlMap = new HashMap<>();
for (String lableName:lableNames){
tagUrlMap.put(lableName, getUrlStringList("UrlData/" + lableName + ".csv"));
}
for (String lable:tagUrlMap.keySet()){
System.out.printf("正在爬" + lable + "\n");
FileWriter csvWriter = new FileWriter( lable + ".csv");
for (String url:tagUrlMap.get(lable)){
String urlText = requestParse(url);
if (urlText != null){
csvWriter.append(url + "," + solveText(urlText) + "\n");
// System.out.println(solveText(urlText));
}
}
csvWriter.flush();
csvWriter.close();
}
}
private List<String> getUrlStringList(String fileName) throws IOException {
List<String> validUrl = new ArrayList<>();
try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
String line;
while ((line = br.readLine()) != null) {
if (line.startsWith("(更新")){
continue;
}
validUrl.add(line);
}
} catch (IOException e) {
System.err.format("IOException: %s%n", e);
}
return validUrl;
}
private List<String> getLableName() throws FileNotFoundException {
String fileName = "urlLableSet.txt";
List<String> lableName = new ArrayList<>();
try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
String line;
while ((line = br.readLine()) != null) {
lableName.add(line);
}
} catch (IOException e) {
System.err.format("IOException: %s%n", e);
}
return lableName;
}
private String tryRequestUrl(String url) throws Exception {
Document doc;
TrustAllCertificates.trustAllCertificates();
// System.out.println("tryRequestUrl: " + url);
try {
Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress("127.0.0.1", 8080));
doc = Jsoup.connect(url)
// .proxy(proxy)
.header("User-Agent","Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36")
// .header("User-Agent", "curl/7.71.1")
.header("Accept", "*/*")
// .header("Accept-Encoding", "gzip, deflate, br")
// .header("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
// .header("Cache-Control", "max-age=0")
.header("Connection", "close")
.get();
} catch (IOException e){
return null;
}
return doc.text();
}
private String requestParse(String url) throws Exception {
// first try the URL with a "www." prefix, then fall back to the bare host
String htmlString = tryRequestUrl("http://www." + url);
if (htmlString == null){
htmlString = tryRequestUrl("https://www." + url);
}
if (htmlString == null){
// the prefixed attempts failed (or the host already starts with "www."), so try the URL as given
htmlString = tryRequestUrl("http://" + url);
if (htmlString == null){
htmlString = tryRequestUrl("https://" + url);
}
}
return htmlString;
}
private String solveText(String text) {
// strip punctuation (including Chinese quotation marks); letters, digits and whitespace are kept
text = text.replaceAll("[\\p{P}‘’“”]", "");
return text;
}
public static void main(String[] args) throws Exception {
Spider spider = new Spider();
System.out.println("开始爬取");
spider.run();
System.out.println("结束爬取");
}
}
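The Spider above calls TrustAllCertificates.trustAllCertificates(), a helper that the post does not show; its only job is to keep HTTPS requests from failing on sites with self-signed or broken certificates. A minimal sketch of such a helper is below (the class and method names come from the call above; everything else is my assumption):
package com.example.webproject;
import java.security.SecureRandom;
import java.security.cert.X509Certificate;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;
/**
* Assumed implementation of the TrustAllCertificates helper referenced in Spider.
* It installs an SSLContext that accepts every certificate and a hostname verifier
* that accepts every host, so crawling is not interrupted by bad HTTPS configurations.
* (Only suitable for a throwaway crawler.)
*/
public class TrustAllCertificates {
public static void trustAllCertificates() throws Exception {
TrustManager[] trustAll = new TrustManager[]{
new X509TrustManager() {
public X509Certificate[] getAcceptedIssuers() { return new X509Certificate[0]; }
public void checkClientTrusted(X509Certificate[] certs, String authType) { /* accept everything */ }
public void checkServerTrusted(X509Certificate[] certs, String authType) { /* accept everything */ }
}
};
SSLContext sslContext = SSLContext.getInstance("TLS");
sslContext.init(null, trustAll, new SecureRandom());
HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory());
HttpsURLConnection.setDefaultHostnameVerifier((hostname, session) -> true);
}
}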
Data Processing
- Segment the text of each crawled web page into words
- Filter stopwords out of each page's term list
// segment each page's text and then filter out stopwords
private static HashMap<List<Term>, String> getTrainText() {
String csvFile = "训练.csv"; // path to the training CSV file
HanLP.Config.ShowTermNature = false; // do not show part-of-speech tags in the segmentation result
DoubleArrayTrieSegment segment = new DoubleArrayTrieSegment(); // by default loads the CoreDictionaryPath specified in the configuration file
segment.enablePartOfSpeechTagging(true); // enable part-of-speech tagging (which also activates number and English word recognition)
List<List<Term>> trainData = new ArrayList<>(); // trainData: stores the segmented training data
List<String> lables = new ArrayList<>(); // lables: stores the labels, in preparation for the chi-square test
HashMap<List<Term>, String> lablesHashMap = new HashMap<>();
try (BufferedReader br = new BufferedReader(new FileReader(csvFile))) {
List<String> stopWords = loadStopword(); // load the stopword list once instead of re-reading it for every field
String line;
while ((line = br.readLine()) != null) {
// split the CSV line into fields
String[] fields = line.split(",");
// make sure there are at least three fields
if (fields.length >= 3) {
String lable = fields[0].trim();
lables.add(lable);
// process the third column and everything after it
for (int i = 2; i < fields.length; i++) {
String data = fields[i].trim(); // trim() strips whitespace around the field value
List<Term> list = segment.seg(data); // segmentation result
System.out.println(list);
// remove stopwords; the original removeAll(loadStopword()) was a no-op because a Term never equals a String
list.removeIf(term -> stopWords.contains(term.word));
trainData.add(list);
lablesHashMap.put(list, lable);
}
// System.out.println();
}
}
} catch (IOException e) {
e.printStackTrace();
}
return lablesHashMap;
}
- Compute the TF-IDF value and chi-square value of each page's resulting terms
- For each page, take the TF-IDF values of the words with the top 1,000 chi-square scores
// compute the chi-square values and take each class's top 1,000 terms as the feature dictionary
package com.example.webproject;
import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.seg.Other.DoubleArrayTrieSegment;
import com.hankcs.hanlp.seg.common.Term;
import java.io.*;
import java.util.*;
import java.util.regex.Pattern;
public class FeatureDictionary {
public static void main(String[] args) {
HashMap<List<Term>, String> trainTextHashMap = getTrainText();
// generate the feature dictionary via the chi-square test
ChiSquareTest(trainTextHashMap);
}
private static HashMap<List<Term>, String> getTrainText() {
String csvFile = "训练.csv"; // path to the training CSV file
HanLP.Config.ShowTermNature = false; // do not show part-of-speech tags in the segmentation result
DoubleArrayTrieSegment segment = new DoubleArrayTrieSegment(); // by default loads the CoreDictionaryPath specified in the configuration file
segment.enablePartOfSpeechTagging(true); // enable part-of-speech tagging (which also activates number and English word recognition)
List<List<Term>> trainData = new ArrayList<>(); // trainData: stores the segmented training data
List<String> lables = new ArrayList<>(); // lables: stores the labels, in preparation for the chi-square test
HashMap<List<Term>, String> lablesHashMap = new HashMap<>();
try (BufferedReader br = new BufferedReader(new FileReader(csvFile))) {
List<String> stopWords = loadStopword(); // load the stopword list once instead of re-reading it for every field
String line;
while ((line = br.readLine()) != null) {
// split the CSV line into fields
String[] fields = line.split(",");
// make sure there are at least three fields
if (fields.length >= 3) {
String lable = fields[0].trim();
lables.add(lable);
// process the third column and everything after it
for (int i = 2; i < fields.length; i++) {
String data = fields[i].trim(); // trim() strips whitespace around the field value
List<Term> list = segment.seg(data); // segmentation result
// remove stopwords; the original removeAll(loadStopword()) was a no-op because a Term never equals a String
list.removeIf(term -> stopWords.contains(term.word));
trainData.add(list);
lablesHashMap.put(list, lable);
}
// System.out.println();
}
}
} catch (IOException e) {
e.printStackTrace();
}
return lablesHashMap;
}
private static List<String> loadStopword() throws IOException {
List<String> stopWords = new ArrayList<String>();
// read the stopword file line by line (try-with-resources so the reader is closed)
try (BufferedReader bufferedReader = new BufferedReader(new FileReader(TrainingSetInformation.StopWordPath))) {
String temp;
while ((temp = bufferedReader.readLine()) != null) {
stopWords.add(temp.trim());
}
}
return stopWords;
}
/**
* Where:
* N is the total number of documents,
* A is the number of documents in this class that contain the term,
* B is the number of documents outside this class that contain the term,
* C is the number of documents in this class that do not contain the term,
* D is the number of documents outside this class that do not contain the term.
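* The standard chi-square statistic is then: chi²(term, class) = N * (A*D - B*C)² / ((A+C) * (B+D) * (A+B) * (C+D)).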
*
* @param trainTextHashMap map from each document's term list to its label
*/
private static void ChiSquareTest(HashMap<List<Term>, String> trainTextHashMap) {
// build, for every class, the per-document term counts
HashMap<String, List<Map<String, Double>>> trainTextTFHashMap = CalculateTermFrequency(trainTextHashMap);
// map each class to the chi-square score of every term that occurs in it
Map<String, Map<String, Double>> calculateChiSquareTestMapListMap = new HashMap<>();
for (String lable : TrainingSetInformation.fileNames) {// iterate over the classes
// compute the chi-square value of every term in this class
Map<String, Double> calculateChiSquareTestMap = new HashMap<>();
for(Map<String, Double> map:trainTextTFHashMap.get(lable)){
for (String term:map.keySet()){
calculateChiSquareTestMap.put(term, CalculateChiSquareTest(trainTextTFHashMap, term, lable));
}
}
calculateChiSquareTestMapListMap.put(lable, calculateChiSquareTestMap);
}
Map<String, List<String>> featureDictionaryMap = getFeatureDictionary(calculateChiSquareTestMapListMap);
SaveFile(featureDictionaryMap);
}
private static HashMap<String, List<Map<String, Double>>> CalculateTermFrequency(HashMap<List<Term>, String> trainTextHashMap){
HashMap<String, List<Map<String, Double>>> trainTextTFHashMap = new HashMap<>();
for (String lable : TrainingSetInformation.fileNames) {// iterate over the classes
List<Map<String, Double>> trainTextTFList= new ArrayList<>();// per-class list of per-document term-count maps
for (List<Term> termList : trainTextHashMap.keySet()) { // iterate over every document's segmentation result
Map<String, Double> inverseDocumentFrequency = new HashMap<>(); // term counts within this document
if (trainTextHashMap.get(termList).equals(lable + ".csv")) { // does this document belong to the current class?
for (Term term : termList) { // look at every term
if (isChineseString(String.valueOf(term))){
inverseDocumentFrequency.put(String.valueOf(term), inverseDocumentFrequency.getOrDefault(String.valueOf(term), 0.0) + 1);
// getOrDefault returns the term's current count (or 0.0 if unseen), and we add one occurrence
}
}
}else {
// this document belongs to a different class, skip it
continue;
}
trainTextTFList.add(inverseDocumentFrequency);
}
trainTextTFHashMap.put(lable, trainTextTFList);
}
return trainTextTFHashMap;
}
public static boolean isChineseString(String str) {
String pattern = "^[\u4E00-\u9FFF]+$";
return Pattern.matches(pattern, str);
}
private static double CalculateChiSquareTest(HashMap<String, List<Map<String, Double>>> trainTextTFHashMap,
String term,
String lable) {
// A and C for this class
int[] ac = CalculateAC(trainTextTFHashMap, term, lable);
int A = ac[0];
int C = ac[1];
// B and D for the other classes
int[] bd = CalculateBD(trainTextTFHashMap, term, lable);
int B = bd[0];
int D = bd[1];
// N: total number of documents across all classes (trainTextTFHashMap.size() would only count the classes)
int N = trainTextTFHashMap.values().stream().mapToInt(List::size).sum();
// chi-square value; note that '^' is XOR in Java, so the square must be written explicitly
double x;
double numerator = N * Math.pow((double) A * D - (double) B * C, 2);
double denominator = (double) (A + C) * (A + B) * (B + D) * (C + D);
if (denominator != 0) {
x = numerator / denominator;
} else {
x = numerator;
}
return x;
}
private static int[] CalculateAC(HashMap<String, List<Map<String, Double>>> trainTextTFHashMap, String term, String lable){
int A = 0;
int C = 0;
int[] AC = new int[2];
for (Map<String, Double> map : trainTextTFHashMap.get(lable)) {
if(map.containsKey(term)){
A++;
}else {
C++;
}
}
AC[0] = A;
AC[1] = C;
return AC;
}
private static int[] CalculateBD(HashMap<String, List<Map<String, Double>>> trainTextTFHashMap, String term, String lable){
int B = 0;
int D = 0;
int[] BD = new int[2];
for (String otherLable : TrainingSetInformation.fileNames)
if (!otherLable.equals(lable)){ // only the classes other than the current one
for (Map<String, Double> map : trainTextTFHashMap.get(otherLable)) {
if(map.containsKey(term)){
B++;
}else {
D++;
}
}
}
BD[0] = B;
BD[1] = D;
return BD;
}
private static Map<String, List<String>> getFeatureDictionary(Map<String, Map<String, Double>> calculateChiSquareTestMapListMap){
Map<String, List<String>> featureDictionaryMap = new HashMap<>();
for (String lable:TrainingSetInformation.fileNames){
// copy the map entries into a list
List<Map.Entry<String, Double>> entryList = new ArrayList<>(calculateChiSquareTestMapListMap.get(lable).entrySet());
// sort the list by chi-square value, in descending order
Collections.sort(entryList, (entry1, entry2) -> entry2.getValue().compareTo(entry1.getValue()));
// take the keys of the top 1,000 entries
List<String> topKeys = new ArrayList<>();
int limit = Math.min(1000, entryList.size());
for (int i = 0; i < limit; i++) {
topKeys.add(entryList.get(i).getKey());
}
System.out.println(lable + "共有" + topKeys.size() + "个词" + ":" + topKeys);
featureDictionaryMap.put(lable, topKeys);
}
return featureDictionaryMap;
}
public static void SaveFile(Map<String, List<String>> featureDictionaryMap) {
List<String> mergedList = new ArrayList<>();
for (String lable:TrainingSetInformation.fileNames){
String mergedString = String.join(",", featureDictionaryMap.get(lable));
mergedList.add(mergedString);
}
String filePath = "FeatureDictionary.txt"; // output file path
try (BufferedWriter writer = new BufferedWriter(new FileWriter(filePath))) {
for (String item : mergedList) {
writer.write(item);
writer.newLine();
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
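The FeatureDictionary class (and the earlier getTrainText snippet) refer to a TrainingSetInformation class that never appears in the post. It only has to expose the stopword file path and the list of class names, so a minimal sketch could look like this (the field names StopWordPath and fileNames come from the code above; the concrete values are assumptions):
package com.example.webproject;
/**
* Assumed constants referenced by FeatureDictionary and related classes.
*/
public class TrainingSetInformation {
// path of the stopword list (assumed file name)
public static final String StopWordPath = "stopwords.txt";
// the 12 industry categories used as class names, matching the crawled CSV files
public static final String[] fileNames = {
"购物网站", "交通旅游", "教育文化", "生活服务", "体育健身", "网络科技",
"新闻媒体", "行业企业", "休闲娱乐", "医疗健康", "政府组织", "综合其他"
};
}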
// compute the TF-IDF value of every term of every web page
private static Map<HashMap<String, Double>, String> getTfIdf(HashMap<List<Term>, String> documents) {
// idf statistics for the training set
Map<String, Double> idf = new HashMap<>();
idf = getInverseDocumentFrequency(documents); // counts how often each term appears across all documents
int i = 1;
HashMap<String, Integer> allWords = new HashMap<>();
for (String str : idf.keySet()) {
allWords.put(str, i++);
}
// tf-idf values for the training set:
// List<HashMap<Integer, Double>> tfidf = new ArrayList<>();
Map<HashMap<String, Double>, String> tfidfString = new HashMap<>();
int totalDocuments = documents.size(); // total number of documents
for (List<Term> termList : documents.keySet()) {
Map<String, Integer> tf = getTermFrequency(termList);
int docWordCount = termList.size(); // total number of words in this document
HashMap<Integer, Double> test = new HashMap<>();
HashMap<String, Double> stringTest = new HashMap<>();
for (String term : tf.keySet()) {
double result = computeTFIDF(totalDocuments, term, tf, idf, docWordCount);
test.put(allWords.get(term), result);
stringTest.put(term, result);
}
// tfidf.add(test);
tfidfString.put(stringTest, documents.get(termList));
}
// getKeyWord(tfidf, allWords); // extract keywords (not used here)
return tfidfString;
}
/**
* Build the corpus-wide term counts used for the idf values.
*
* @param documents segmented and stopword-filtered term lists, mapped to their labels
* @return inverseDocumentFrequency serves two purposes: as a dictionary of all terms and as the basis for idf
*/
private static Map<String, Double> getInverseDocumentFrequency(HashMap<List<Term>, String> documents) {
Map<String, Double> inverseDocumentFrequency = new HashMap<>();
for (List<Term> termList : documents.keySet()) {
for (Term term : termList) {
// regular expression for special characters and digits
String regex = "[^\\p{L}]+|\\d+"; // matches runs of non-letter characters or runs of digits
if (String.valueOf(term).matches(".*" + regex + ".*")) {
// the term contains special characters or digits, skip it
continue;
} else {
inverseDocumentFrequency.put(String.valueOf(term), inverseDocumentFrequency.getOrDefault(String.valueOf(term), 0.0) + 1);
// getOrDefault returns the term's current count (or 0.0 if unseen), and we add one occurrence
}
}
}
return inverseDocumentFrequency;
}
/**
* Count how often each term occurs in one document (raw term frequency).
*
* @param terms the document's segmented term list
* @return a map from each term to its count in this document
*/
private static Map<String, Integer> getTermFrequency(List<Term> terms) {
Map<String, Integer> termFrequency = new HashMap<>();
for (Term term : terms) {
// look the term up by its string form; passing the Term object to getOrDefault never matched the String key
termFrequency.put(String.valueOf(term), termFrequency.getOrDefault(String.valueOf(term), 0) + 1);
}
return termFrequency;
}
/**
* Compute the tf-idf value of one term within one document:
* tf-idf(t, d) = (count(t, d) / |d|) * log(N / df(t)),
* where N is the total number of documents and df(t) comes from getInverseDocumentFrequency.
*
* @param totalDocuments total number of documents in the corpus
* @param term the term to score
* @param tfMap term counts within the current document
* @param dfMap per-term counts across all documents
* @param docWordCount total number of words in the current document
* @return the tf-idf score of the term
*/
private static double computeTFIDF(double totalDocuments, String term, Map<String, Integer> tfMap, Map<String, Double> dfMap, int docWordCount) {
// tf: the term's count in this document, normalised by the document's length
double tf = (double) tfMap.getOrDefault(term, 0) / docWordCount;
// idf: log of (total documents / the term's corpus-wide count); the original divided by the wrong quantities
double idf = Math.log(totalDocuments / dfMap.getOrDefault(term, 0.0));
return tf * idf;
}
- Obtain the feature vector of every web page (a sketch of this glue step follows below)
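The post never shows the glue step that combines the feature dictionary with the TF-IDF maps to produce the Map<Map<Integer, Double>, String> that createProblem below expects. A possible sketch is the hypothetical helper below; it assumes featureWords holds the merged feature dictionary (for example, every line of FeatureDictionary.txt split on ",") and tfidf is the output of getTfIdf:
// Hypothetical helper: build index-based feature vectors from the per-page TF-IDF maps.
private static Map<Map<Integer, Double>, String> buildFeatureVectors(List<String> featureWords, Map<HashMap<String, Double>, String> tfidf) {
// give every dictionary word a fixed feature index (LibSVM indices start at 1)
Map<String, Integer> wordIndex = new HashMap<>();
int index = 1;
for (String word : featureWords) {
if (!wordIndex.containsKey(word)) {
wordIndex.put(word, index++);
}
}
Map<Map<Integer, Double>, String> featureVectors = new HashMap<>();
for (Map.Entry<HashMap<String, Double>, String> entry : tfidf.entrySet()) {
Map<Integer, Double> vector = new HashMap<>();
// keep only the terms that appear in the feature dictionary, keyed by their feature index
for (Map.Entry<String, Double> termScore : entry.getKey().entrySet()) {
Integer featureIndex = wordIndex.get(termScore.getKey());
if (featureIndex != null) {
vector.put(featureIndex, termScore.getValue());
}
}
featureVectors.put(vector, entry.getValue());
}
return featureVectors;
}
Keying a map by the whole feature vector mirrors the style of the original code, but two pages with identical vectors would collide; a small class holding a (vector, label) pair would be more robust.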
Classification Training
- Convert the feature vectors of all web pages into the data format that LibSVM accepts
// convert the feature vectors into LibSVM's svm_problem format
private static svm_problem createProblem(Map<Map<Integer, Double>, String> trainData) {
Map<String, Integer> lablesMap = new HashMap<>();
int labelsMapNum = 1;
for (String lable:TrainingSetInformation.fileNames){
lablesMap.put(lable+".csv", labelsMapNum++);
}
svm_problem problem = new svm_problem();
problem.l = trainData.size(); // number of training samples
problem.x = new svm_node[problem.l][]; // each entry points to one training vector, which is an array of svm_node
problem.y = new double[problem.l]; // the label of each training vector
int i = 0;
for (Map<Integer, Double> featureVector:trainData.keySet()) {
problem.x[i] = new svm_node[featureVector.size()];
int y = 0;
Set<Integer> x = featureVector.keySet();
int[] sortIntArray = sortAndConvertToIntArray(x);
for (int j : sortIntArray) {
svm_node node = new svm_node();
node.index = j;
// System.out.println("index:" + node.index);
node.value = featureVector.get(j);
// System.out.println("value:" + node.value);
problem.x[i][y] = node;
y++;
}
problem.y[i] = lablesMap.get(trainData.get(featureVector));
i++;
}
return problem;
}
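createProblem also calls sortAndConvertToIntArray, which is not shown in the post. LibSVM requires the svm_node indices inside each training vector to be in ascending order, so a minimal version (my assumption, needing java.util.Arrays and java.util.Set) only has to sort the key set:
// Assumed helper: turn the feature-index key set into an ascending int array,
// because LibSVM expects node indices in increasing order within each vector.
private static int[] sortAndConvertToIntArray(Set<Integer> keys) {
int[] indices = new int[keys.size()];
int pos = 0;
for (Integer key : keys) {
indices[pos++] = key;
}
Arrays.sort(indices); // ascending order
return indices;
}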
// set the training parameters
public static svm_parameter createParameter(double C, double gamma) {
svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.RBF;
param.C = C;
param.gamma = gamma;
param.eps = 0.001;
return param;
}
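Before training, LibSVM can also validate the problem/parameter combination; this optional check is not in the original post but uses the standard libsvm API:
// optional sanity check: svm_check_parameter returns an error message, or null if everything is fine
String error = svm.svm_check_parameter(problem, param);
if (error != null) {
throw new IllegalArgumentException("Bad SVM parameters: " + error);
}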
- Call the third-party LibSVM package to train
svm_model model = svm.svm_train(problem, param);
- Obtain the classifier
svm.svm_save_model("model.model", model);
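The post stops at saving the model, so for completeness here is a sketch of how the saved model would be used to classify a new page. It assumes vector is the new page's Map<Integer, Double> feature vector, built exactly as during training, and reuses the sortAndConvertToIntArray helper:
// load the saved model (svm_load_model throws IOException)
svm_model model = svm.svm_load_model("model.model");
// convert the page's feature vector into an ascending array of svm_node
svm_node[] nodes = new svm_node[vector.size()];
int k = 0;
for (int featureIndex : sortAndConvertToIntArray(vector.keySet())) {
svm_node node = new svm_node();
node.index = featureIndex;
node.value = vector.get(featureIndex);
nodes[k++] = node;
}
// svm_predict returns the numeric label assigned in lablesMap during training
double predictedLabel = svm.svm_predict(model, nodes);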
A mountain of spaghetti code, pure spaghetti /(ㄒoㄒ)/~~