数据集
- 数据集:Iris 鸢尾花数据集。它包含 3 个不同品种的鸢尾花(Setosa、Versicolour、Virginica)的数据,每个样本有 4 个特征:['sepal length', 'sepal width', 'petal length', 'petal width'],一共 150 个样本。由于这里实现的 SVM 是二分类模型,所以只选择前两类数据(即前 100 个样本)进行算法测试,并把标签从 {0, 1} 映射为 {-1, +1}。
代码实现
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
def create_data():
    """Load a two-class subset of the Iris dataset for testing the SVM.

    Only the first 100 samples are kept and their integer class ids are
    remapped with ``label * 2 - 1`` so that class 0 becomes -1 and class 1
    becomes +1.

    Returns:
        tuple: ``(data, label)`` where ``data`` holds the first 100 feature
        rows and ``label`` holds the matching remapped targets.
    """
    # NOTE(review): this relies on load_iris() returning samples ordered by
    # class so that the first 100 rows are exactly classes 0 and 1 -- see the
    # scikit-learn documentation for the dataset layout.
    n_samples = 100

    iris = load_iris()
    data = np.array(iris["data"])[:n_samples]
    label = np.array(iris["target"])[:n_samples]

    # Map class ids {0, 1} -> {-1, +1}.
    label = label * 2 - 1

    # Fixed typo in the diagnostic message ("dara" -> "data").
    print("data shape:", data.shape)
    print("label shape:", label.shape)
    return data, label
class SVM:
def __init__(self, max_iter=100, kernel='linear'):
self.max_iter = max_iter
self._kernel = kernel
def init_args(self, features, labels):
    """Reset all training state for a new ``(features, labels)`` pair.

    Args:
        features: 2-D array; its shape is unpacked into ``self.m`` and
            ``self.n``.
        labels: Stored unchanged as ``self.Y``.

    Side effects:
        Sets ``m``, ``n``, ``X``, ``Y``, ``b``, ``alpha`` and ``C`` on the
        instance and invokes ``computer_product_matrix()`` and
        ``create_E()``.
    """
    n_rows, n_cols = features.shape
    self.m = n_rows
    self.n = n_cols

    self.X = features
    self.Y = labels

    # Starting values: zero bias and one multiplier per sample, all set to 1.
    self.b = 0.0
    self.alpha = np.ones(n_rows)

    # The helpers below are defined elsewhere in the class; their call order
    # relative to the assignment of C is kept exactly as in the original.
    self.computer_product_matrix()
    self.C = 1.0
    self.create_E()
def judge_KKT(self, i):
y_g = self.function_g(i) * self.Y[i]
if self.alpha[i] == 0:
return y_g >= 1
elif 0 < self.alpha[i] < self.C:
return y_g ==