目录
2.PyTorch实现 线性回归的封装写法(实际项目中的常用写法)
1. 基础写法
1.1导包
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
2.2加载读取数据¶
data = pd.read_csv('./dataset/Income1.csv')
data
#读取数据类型为dataframe类型
输出结果截图所示(部分数据)
data.head() #查看dataframe数据的前五条数据
data.tail() #后五条数据
data.Education.head() #查看数据的Education列的前五条数据 #是一个Series
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
data.Education[:5] #查看数据的Education列的前五条数据
0 10.000000 1 10.401338 2 10.842809 3 11.244147 4 11.645485 Name: Education, dtype: float64
2.3原始数据可视化(画图显示)
#画散点图,观察数据Education 与 Income 是否具有线性关系
plt.scatter(data.Education, data.Income)
plt.xlabel('Education')
plt.ylabel('Income')