import csv
import numpy as np
import matplotlib.pyplot as plt
def readData(filename, encoding=None):
    """
    Read the CSV dataset and split its samples by the label in column 9.

    The first row is treated as a header and skipped. For every following row:
      * columns 7 and 8 are parsed as floats and form a 2-feature sample;
      * column 10 is parsed as a float and wrapped as a one-element target list;
      * column 9 selects the group: '是' -> first group, '否' -> second group.
    Rows whose column 9 holds any other value are ignored.

    :param filename: path to the CSV dataset.
    :param encoding: text encoding passed to ``open``; ``None`` keeps the
        platform default (the previous behaviour). Pass e.g. ``'utf-8'`` or
        ``'gbk'`` explicitly if the labels fail to match on your platform.
    :return: tuple ``(X1, X2, y1, y2, density)`` where
        X1, X2  -- lists of ``[col7, col8]`` float pairs for the '是' / '否' rows;
        y1, y2  -- lists of ``[col10]`` float lists, aligned with X1 / X2;
        density -- column 7 of every kept row, in file order (both groups).
        (The previous docstring expected 8 and 9 rows in the two groups for
        the dataset it was written for.)
    :raises IndexError: if a data row has fewer than 11 columns.
    :raises ValueError: if a feature/target cell is not a valid float.
    """
    # label value -> (feature list, target list) for that group
    groups = {'是': ([], []), '否': ([], [])}
    density = []
    # newline='' is required by the csv module so quoted fields that contain
    # line breaks are parsed correctly.
    with open(filename, newline='', encoding=encoding) as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        for line in reader:
            bucket = groups.get(line[9])
            if bucket is None:
                continue  # neither '是' nor '否': not part of either class
            features = [float(line[7]), float(line[8])]
            bucket[0].append(features)
            bucket[1].append([float(line[10])])
            density.append(features[0])
    (X1, y1), (X2, y2) = groups['是'], groups['否']
    return X1, X2, y1, y2, density
# 西瓜书 课后习题 3.5：线性判别分析
# 最新推荐文章于 2023-10-18 11:44:51 发布