局部加权线性回归

本文详细介绍了局部加权线性回归算法,探讨了解决线性回归中欠拟合问题的方法。通过引入高斯核函数对数据点进行加权,实现对附近点的重视,提高了模型的预测精度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 目的

在线性回归中会出现欠拟合的情况,解决办法之一就是局部加权线性回归,
此时的线性回归系数可以表示为:
W^=(xTMX)−1XTMY\hat W = (x^TMX)^{-1}X^TMYW^=(xTMX)1XTMY
M为每个点的权重。使用核函数对附近的点赋予跟高的权重,常用的为高斯核。
M(i,i)=exp(∣∣Xi−X∣∣2−2k2)M(i,i) = exp(\frac {||X^i - X||^2}{-2k^2})M(i,i)=exp(2k2XiX2)

2 算法

def lwlr_function(feature, label, k):
    '''局部加权线性回归
    input:  feature(mat):特征
            label(mat):标签
            k(int):核函数的系数
    output: predict(mat):最终的结果
    '''
    m = np.shape(feature)[0]
    predict = np.zeros(m)
    weights = np.mat(np.eye(m))
    for i in range(m):
        for j in range(m):
            diff = feature[i] - feature[j]
            weights[j,j] = np.exp(diff * diff.T / (-2.0 * k ** 2))
        xTx = feature.T * (weights * feature)
        ws = xTx.I * (feature.T * (weights * label))
        predict[i] = feature[i] * ws
    return predict

3 生成数据集

    feature = np.random.normal(0, 3, 100)
    feature = np.array(feature).reshape(100,1)
   
    label = 2 * np.sin(feature) + np.random.normal(0, 0.07, 100).reshape(100,1)
    predict = lwlr_1(feature, label, 1)
    m = np.shape(predict)[0]

4 完整代码

# coding:UTF-8

import numpy as np
import matplotlib.pyplot as plt

def lwlr_function(feature, label, k):
    '''局部加权线性回归
    input:  feature(mat):特征
            label(mat):标签
            k(int):核函数的系数
    output: predict(mat):最终的结果
    '''
    m = np.shape(feature)[0]
    predict = np.zeros(m)
    weights = np.mat(np.eye(m))
    for i in range(m):
        for j in range(m):
            diff = feature[i] - feature[j]
            weights[j,j] = np.exp(diff * diff.T / (-2.0 * k ** 2))
        xTx = feature.T * (weights * feature)
        ws = xTx.I * (feature.T * (weights * label))
        predict[i] = feature[i] * ws
    return predict
if __name__ == "__main__":
    x = np.linspace(0,20,100).reshape(100,1)
    t = np.sin(x)+ 0.2*x
    y = t + np.random.normal(0, 0.3, 100).reshape(100,1)
    z = lwlr_function(x, y, 0.6)
    plt.plot(x,t,'g')
    plt.plot(x,y,'b.')
    plt.plot(x,z,'r-')
    plt.show()

5 运行结果

在这里插入图片描述蓝色的是加了噪声的数据,绿色为原数据点,红色为拟合后的数据点

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值