from LF_Ui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
import sys
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time
TEMP_PATH = "temp.jpg" #缓冲图像
class LF(QDialog):
def __init__(self,ui,parent = None):
super().__init__(parent)
self._Qimg = QPixmap() #用于显示图片
self.Data = None #数据 {x,y}
self.function_Data = None # f(x)
self.__ui = ui #ui界面
self._Learning_rate = 0 #学习率
self._k = 0 #斜率 theta_1
self._b = 0 #截距 theta_0
self.done = False #线性拟合完成
self.Iteration = 0 #迭代次数
self._threshold = 0 #阈值
self.__scaledImg = None #防止图片过大,缩放后的图片
self.__ui.setupUi(self) #给界面上控件
self.__ui.ImgLabel.setAlignment(Qt.AlignCenter) #设置对其方式:居中对齐
"""给对应控件关联函数"""
self.__ui.get_btn.clicked.connect(self.on_getData_btn)
self.__ui.step_btn.clicked.connect(self.on_stepTostep_btn)
self.__ui.over_btn.clicked.connect(self.on_stepOver_btn)
self.__ui.predicate_btn.clicked.connect(self.on_predicate_btn)
def on_getData_btn(self):
"""功能:得到编辑框中的信息,只能输入数字"""
try:
self._k = float(self.__ui.k.text())
self._b = float(self.__ui.b.text())
self._Learning_rate = float(self.__ui.learning_rate.text())
self._threshold = float(self.__ui.threshold.text())
self.Data = np.random.multivariate_normal([0, 0], [[1, 0.9], [0.9, 1]], 100) #多元正太分布,生成离散变量
self.Iteration = 0
self.done = False
self._update()
except ValueError:
QMessageBox(QMessageBox.Critical,"error type!","Please input with number").exec_() #提醒需要输入数字
return
def on_stepTostep_btn(self):
"""功能:单步操作"""
self.LF_process()
def on_stepOver_btn(self):
"""功能:一步完成"""
while(not self.done):
self.LF_process()
def on_predicate_btn(self):
"""功能:使用迭代后的参数预测数值"""
if not self.done:
QMessageBox(QMessageBox.Information,"Warning!","Complete linear fit").exec_() #提醒需要完成线性拟合
return
try:
x = float(self.__ui.x.text())
self.__ui.h_x.setText("{:.2f}".format(self._k * x + self._b)) #预测的数值
except ValueError:
QMessageBox(QMessageBox.Critical,"error type!","Please input with number").exec_() #提醒需要输入数字
return
def _update(self): #更新界面(刷新图片)
self.__showGraph() #对图片的处理
if os.path.exists(TEMP_PATH): #缓存图片
self._Qimg.load(TEMP_PATH) #加载图片
if self._Qimg.size().width() > self.__ui.ImgLabel.size().width() and \
self._Qimg.size().height() > self.__ui.ImgLabel.size().height(): #图片过大则缩放图片
self.__scaledImg = self._Qimg.scaled(self.__ui.ImgLabel.size)
elif self._Qimg.size().width() > self.__ui.ImgLabel.size().width(): #根据宽缩放
self.__scaledImg = self._Qimg.scaledToWidth(self.__ui.ImgLabel.size().width())
elif self._Qimg.size().height() > self.__ui.ImgLabel.size().height(): #根据高缩放
self.__scaledImg = self._Qimg.scaledToHeight(self.__ui.ImgLabel.size().height())
else:
self.__scaledImg = self._Qimg.copy() #复制该图片信息
self.__ui.ImgLabel.setPixmap(self.__scaledImg) #给Label贴上图片
super().update() #调用父类的update函数
def __showGraph(self): #显示波形信息图
plt.clf() #清除原图像
if self.Data.shape[0] != 0:
plt.subplot(1,1,1) #设置图片个数,排版方式
plt.scatter(self.Data[:,0],self.Data[:,1],color= "black") #描点
if self.Iteration != 0:
plt.plot(self.Data[:,0],self.function_Data,color= "red") #拟合曲线
plt.title("Discrete values Set") #图像注释
plt.grid('on') #标尺,on:有,off:无。
plt.savefig(TEMP_PATH) #存储临时文件
def LF_process(self):
"""功能:进行拟合"""
if self.done or self.Iteration == self.Data.shape[0]:
return
k,b = self.Calu_function() #迭代计算,更新参数
self.__ui.p_k.setText("{:.2f}".format(self._k)) #将参数显示在界面上
self.__ui.p_b.setText("{:.2f}".format(self._b))
self.__ui.Iteration.setText(str(self.Iteration)) #显示迭代次数,迭代次数不能超过离散点的个数
self._update()
def Calu_function(self):
"""功能:进行迭代计算,更新参数"""
"""
公式:
theta_1 = theta_1 - alpha * (1/m) * [Sum(i = 0 till i = m)(h(xi) - yi)] * xi
theta_0 = theta_0 - alpha * (1/m) * [Sum(i = 0 till i = m)(h(xi) - yi)]
"""
sum_value = (self.LF_function(self._k,self._b) - self.Data[:,1]).sum() #记 r = Sum(i = 0 till i = m)(h(xi) - yi)
calu_value = self._Learning_rate / self.Data.shape[0] * sum_value #记 c = alpha * (1/m) * r
k = self._k - calu_value * self.Data[self.Iteration][0] # theta_1 = theta_1 - c * xi
b = self._b - calu_value # theta_0 = theta_0 - c
if self.Judge_function(k,b): # 计算误差
QMessageBox(QMessageBox.Information,"Done!","Processed over.").exec_()
self.done = True
"""更新参数"""
self._k = k
self._b = b
self.Iteration += 1
self.function_Data = self.LF_function(k,b)
return k,b
def Judge_function(self,k,b):
""""
功能:计算误差,判断是否收敛,收敛条件为 abs(J(L) - J(L-1)) < threshold,threshold应为极小值,该代码默认为0.000001
J(theta_0,theta_1) = (1/(2*m)) * {[Sum(i = 0 till i = m)(h(xi) - yi)]**2}
= (1/(2*m)) * {[Sum(i = 0 till i = m)(theta_0 + theta_1*xi - yi)]**2}
"""
j_old = 1.0/(2 * self.Data.shape[0]) * (np.square((self.LF_function(self._k,self._b) - self.Data[:,1])).sum())
j_new = 1.0/(2 * self.Data.shape[0]) * (np.square((self.LF_function(k,b) - self.Data[:,1])).sum())
if abs(j_old - j_new) < self._threshold:
return True
else:
return False
def LF_function(self,k,b):
"""功能:根据k,b的值得到h(x)"""
return k*self.Data[:,0] + b
def closeEvent(self, event): #当界面关闭的时候响应该函数
if os.path.exists(TEMP_PATH): #判断是否存在缓冲图像
os.remove(TEMP_PATH) #删除该图像
super().closeEvent(event) #响应父类的closeEvent
if __name__ == "__main__":
app = QApplication(sys.argv)
ui = Ui_LF()
p = LF(ui)
sys.exit(app.exit(p.exec_()))
代码+UI:
链接:https://pan.baidu.com/s/1n5zFc5hoVs7E5FYzi4nWAw
提取码:c5wz