import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
# from sklearn.cross_validation import train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import matplotlib.pyplot as plt
sales = pd.read_csv('Advertising.csv')
a = sales.head()
b = sales.describe()
# print(a)
# print(b)
#抽样,构造训练集和测试集
#train_test_split:交叉验证中常用的函数,功能:随机划分训练集和测试集
Train,Test = train_test_split(sales,train_size = 0.8,random_state = 1234)
#建模
fit = smf.ols('sales~TV+radio+newspaper',data = Train).fit()
#模型概览
c = fit.summary()
# print(c)
#重新建模(newspapaer这个广告渠道并没有影响到销售量的变动,需要剔除)
#模型是通过显著性检验的,但是模型的显著性通过检验的话,并不代表每个自变量都对因变量是重要的
fit2 = smf.ols('sales~TV+radio',data = Train.drop('newspaper',axis = 1)).fit()
#模型信息反馈
d = fit2.summary()
# print(d)
#第一个模型的预测结果
pred = fit.predict(exog = Test)
#第二个模型的预测结果
pred2 = fit2.predict(exog = Test.drop('newspaper',axis = 1))
#模型效果对比(RMSE为均方根误差,即真实值与预测值的均方根,值越小,模型越优秀)
RMSE = np.sqrt(mean_squared_error(Test.sales,pred))
RMSE2 = np.sqrt(mean_squared_error(Test.sales,pred2))
print('第一个模型的预测结果:RMES=%.4f\n' %RMSE)
print('第二个模型的预测结果:RMES=%.4f\n' %RMSE2)
#真实值与预测值的关系
#设置绘图风格
plt.style.use('ggplot')
#设置中文编码和负号的正常显示
plt.rcParams['font.sans-serif'] = 'Microsoft YaHei'
#散点图
plt.scatter(Test.sales,pred,label = '观测点')
#回归线
plt.plot([Test.sales.min(),Test.sales.max()],[pred.min(),pred.max()],'r--',lw=2,label = '拟合线')
#添加轴标签和标题
plt.title('真实值VS.预测值')
plt.xlabel('真实值')
plt.ylabel('预测值')
#去除图边框的顶部刻度和右边刻度
plt.tick_params(top = 'off',right = 'off')
#添加图例
plt.legend(loc = 'upper left')
#图形展现
plt.show()
python3 线性回归
最新推荐文章于 2024-09-10 11:06:02 发布