说明:下文所使用的训练数据集ex2data1.txt来自Andrew Ng的机器学习公开课,数据集中包含有学生两次测试的得分和学生的录取情况。
代码实现:
%matplotlib notebook
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# 读入数据:
data = pd.read_csv('ex2data1.txt',names=['Exam1','Exam2','Admitted'])
data.head() # 查看data中的前五条数据
# 查看学生的录取情况与两次测试的分之间的关系:
fig,axes = plt.subplots()
sns.scatterplot(x='Exam1',y='Exam2',hue='Admitted',s=100,style='Admitted',data=data,ax=axes)
axes.set_title("Student's admission situation")
fig.savefig('Admitting.png')
# 定义sigmoid函数:
def sigmoid(z):
retur