线性拟合LF,Python代码实现+Qt界面

from LF_Ui import *
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import * 
import sys
import matplotlib.pyplot as plt
import numpy as np
import os
import math
import time

TEMP_PATH = "temp.jpg"          #缓冲图像

class LF(QDialog):
    def __init__(self,ui,parent = None):
        super().__init__(parent)
        self._Qimg = QPixmap()                      #用于显示图片
        self.Data = None                            #数据   {x,y}
        self.function_Data = None                   #       f(x)
        self.__ui = ui                              #ui界面
        self._Learning_rate = 0                     #学习率
        self._k = 0                                 #斜率   theta_1
        self._b = 0                                 #截距   theta_0
        self.done = False                           #线性拟合完成
        self.Iteration = 0                          #迭代次数
        self._threshold = 0                         #阈值
        self.__scaledImg = None                     #防止图片过大,缩放后的图片
        self.__ui.setupUi(self)                                     #给界面上控件
        self.__ui.ImgLabel.setAlignment(Qt.AlignCenter)             #设置对其方式:居中对齐
        """给对应控件关联函数"""
        self.__ui.get_btn.clicked.connect(self.on_getData_btn)
        self.__ui.step_btn.clicked.connect(self.on_stepTostep_btn)
        self.__ui.over_btn.clicked.connect(self.on_stepOver_btn)
        self.__ui.predicate_btn.clicked.connect(self.on_predicate_btn)
    def on_getData_btn(self):
        """功能:得到编辑框中的信息,只能输入数字"""
        try:
            self._k = float(self.__ui.k.text())
            self._b = float(self.__ui.b.text())
            self._Learning_rate = float(self.__ui.learning_rate.text())
            self._threshold = float(self.__ui.threshold.text())
            self.Data = np.random.multivariate_normal([0, 0], [[1, 0.9], [0.9, 1]], 100)           #多元正太分布,生成离散变量
            self.Iteration = 0
            self.done = False
            self._update()
        except ValueError:
            QMessageBox(QMessageBox.Critical,"error type!","Please input with number").exec_()     #提醒需要输入数字
            return
    def on_stepTostep_btn(self):
        """功能:单步操作"""
        self.LF_process()           
    def on_stepOver_btn(self):
        """功能:一步完成"""
        while(not self.done):
            self.LF_process()
    def on_predicate_btn(self):
        """功能:使用迭代后的参数预测数值"""
        if not self.done:
            QMessageBox(QMessageBox.Information,"Warning!","Complete linear fit").exec_()          #提醒需要完成线性拟合
            return
        try:
            x = float(self.__ui.x.text())
            self.__ui.h_x.setText("{:.2f}".format(self._k * x + self._b))                          #预测的数值
        except ValueError:
            QMessageBox(QMessageBox.Critical,"error type!","Please input with number").exec_()     #提醒需要输入数字
            return

    def _update(self):                    #更新界面(刷新图片)
        self.__showGraph()                #对图片的处理
        if os.path.exists(TEMP_PATH):     #缓存图片
            self._Qimg.load(TEMP_PATH)    #加载图片
            if self._Qimg.size().width() > self.__ui.ImgLabel.size().width() and   \
                self._Qimg.size().height() > self.__ui.ImgLabel.size().height():                #图片过大则缩放图片
                self.__scaledImg = self._Qimg.scaled(self.__ui.ImgLabel.size)
            elif self._Qimg.size().width() > self.__ui.ImgLabel.size().width():                 #根据宽缩放
                self.__scaledImg = self._Qimg.scaledToWidth(self.__ui.ImgLabel.size().width())
            elif self._Qimg.size().height() > self.__ui.ImgLabel.size().height():               #根据高缩放
                self.__scaledImg = self._Qimg.scaledToHeight(self.__ui.ImgLabel.size().height())
            else:
                 self.__scaledImg = self._Qimg.copy()               #复制该图片信息
            self.__ui.ImgLabel.setPixmap(self.__scaledImg)          #给Label贴上图片
        super().update()                                            #调用父类的update函数
    
    def __showGraph(self):                                          #显示波形信息图
        plt.clf()                                                   #清除原图像
        if self.Data.shape[0] != 0:
            plt.subplot(1,1,1)                                               #设置图片个数,排版方式
            plt.scatter(self.Data[:,0],self.Data[:,1],color= "black")        #描点
            if self.Iteration != 0:
                plt.plot(self.Data[:,0],self.function_Data,color= "red")     #拟合曲线
            plt.title("Discrete values Set")                                 #图像注释
            plt.grid('on')                                                   #标尺,on:有,off:无。
        plt.savefig(TEMP_PATH)                                               #存储临时文件

    def LF_process(self):
        """功能:进行拟合"""
        if self.done or self.Iteration == self.Data.shape[0]:
            return
        k,b = self.Calu_function()                          #迭代计算,更新参数
        self.__ui.p_k.setText("{:.2f}".format(self._k))     #将参数显示在界面上
        self.__ui.p_b.setText("{:.2f}".format(self._b))
        self.__ui.Iteration.setText(str(self.Iteration))    #显示迭代次数,迭代次数不能超过离散点的个数
        self._update()
    def Calu_function(self):
        """功能:进行迭代计算,更新参数"""
        """
            公式:
                theta_1 = theta_1 - alpha * (1/m) * [Sum(i = 0 till i = m)(h(xi) - yi)] * xi
                theta_0 = theta_0 - alpha * (1/m) * [Sum(i = 0 till i = m)(h(xi) - yi)]
        """
        sum_value = (self.LF_function(self._k,self._b) - self.Data[:,1]).sum()      #记 r = Sum(i = 0 till i = m)(h(xi) - yi) 
        calu_value = self._Learning_rate / self.Data.shape[0] * sum_value           #记 c = alpha * (1/m) * r
        k = self._k - calu_value * self.Data[self.Iteration][0]                     # theta_1 = theta_1 - c * xi
        b = self._b - calu_value                                                    # theta_0 = theta_0 - c
        if self.Judge_function(k,b):                                                # 计算误差
            QMessageBox(QMessageBox.Information,"Done!","Processed over.").exec_()     
            self.done = True
        """更新参数"""
        self._k = k
        self._b = b
        self.Iteration += 1
        self.function_Data = self.LF_function(k,b)
        return k,b
    def Judge_function(self,k,b):
        """"
            功能:计算误差,判断是否收敛,收敛条件为 abs(J(L) - J(L-1)) < threshold,threshold应为极小值,该代码默认为0.000001
            J(theta_0,theta_1) = (1/(2*m)) * {[Sum(i = 0 till i = m)(h(xi) - yi)]**2}
                               = (1/(2*m)) * {[Sum(i = 0 till i = m)(theta_0 + theta_1*xi - yi)]**2}
        """
        j_old = 1.0/(2 * self.Data.shape[0]) * (np.square((self.LF_function(self._k,self._b) - self.Data[:,1])).sum())
        j_new = 1.0/(2 * self.Data.shape[0]) * (np.square((self.LF_function(k,b) - self.Data[:,1])).sum())
        if abs(j_old - j_new) < self._threshold:
            return True
        else:
            return False
    def LF_function(self,k,b):
        """功能:根据k,b的值得到h(x)"""
        return k*self.Data[:,0] + b
    def closeEvent(self, event):                                #当界面关闭的时候响应该函数
        if os.path.exists(TEMP_PATH):                           #判断是否存在缓冲图像
            os.remove(TEMP_PATH)                                #删除该图像
        super().closeEvent(event)                               #响应父类的closeEvent

if __name__ == "__main__":
    app = QApplication(sys.argv)
    ui = Ui_LF()
    p = LF(ui)
    sys.exit(app.exit(p.exec_()))







代码+UI:

链接:https://pan.baidu.com/s/1n5zFc5hoVs7E5FYzi4nWAw 
提取码:c5wz

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值