Java 机器学习库Smile实战(一)SVM

转自:微信公众号:燕哥带你学算法(Jeemy110

原文链接:点击打开链接


要使用Java机器学习库Smile,需首先在项目的Maven配置文件pom.xml中添加如下的maven依赖项:

<dependency>
   <groupId>com.github.haifengl</groupId>
   <artifactId>smile-core</artifactId>
   <version>1.4.0</version>
</dependency>


    Smile中的SVM是一个泛型类,他可以支持二分类和多分类两种使用方法,而且这两种使用方法差异较大,所以分开介绍。

1. 二分类

       Smile 库的SVM类是一个泛型类型,默认情况下进行二分类,选择参数为核函数类型和惩罚项参数。

import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

public class Demo {
public static void main(String[]args){

double gamma = 1.0;
       double C = 1.0;

       //通过某种方式获取训练数据及其类标
       double[][] data = ...
int[] label = ...

SVM<double[]> svm = new SVM<double[]>(
new GaussianKernel(gamma), C);
       
       svm.learn(data, label); //训练模型
       svm.finish();

       //获取测试数据
       double[][] testData = ...
int[] result = new int[testData.length];
       for(int i=0; i < testData.length; i++){
result[i] = svm.predict(testData[i]);
       }
}
}


2. 多分类

       接下来是我利用SVM对iris数据集进行分类的程序。首先我们将iris数据保存iris.txt文件,如下结构:

5.1 3.5 1.4 0.2 0

4.9 3  1.4 0.2 0

...


   每一行代表一个测试数据项,前4列是属性向量,最后一列是类标(在Smile中类标不能为负数,并且只能是从0开始的正整数,所以上述类标为:0、1、2)。检测的完整的源代码如下:


import smile.classification.SVM;
import smile.math.kernel.GaussianKernel;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
* Created by zhanghuayan on 2017/1/16.
*/
public class ClassificationTest {

public static void main(String[] args) throws Exception {

List<List<Double>> datas =
new ArrayList<List<Double>>();
       List<Double> data = new ArrayList<Double>();
       List<Integer> labels = new ArrayList<Integer>();

       String line;
       List<String> lines;
       File file = new File("iris.txt");
       BufferedReader reader =
new BufferedReader(new FileReader(file));
       
       while
((line = reader.readLine()) != null) {
lines = Arrays.asList(line.trim().split("\t"));
           for (int i = 0; i < lines.size() - 1; i++) {
data.add(Double.parseDouble(lines.get(i)));
           }
labels.add(Integer.parseInt(
lines.get(lines.size() - 1)));

           datas.add(data);
           data = new ArrayList<Double>();

       }

//转换label
       int[] label = new int[labels.size()];
       for (int i = 0; i < label.length; i++) {
label[i] = labels.get(i);
       }

//转换属性
       int rows = datas.size();
       int cols = datas.get(0).size();
       double[][] srcData = new double[rows][cols];
       for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
srcData[i][j] = datas.get(i).get(j);
           }
}

SVM<double[]> svm = new SVM<double[]>(
new GaussianKernel(1.0), 1.0, 3,
SVM.Multiclass.ONE_VS_ALL);

       svm.learn(srcData, label);
       svm.finish();

       double right = 0;
       for (int i = 0; i < srcData.length; i++) {
int tag = svm.predict(srcData[i]);
           if (tag == label[i]) {
right += 1;
           }
}
right = right / srcData.length;

       System.out.println(
"Accrurate: " + right * 100 + "%");
   }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值