李宏毅《机器学习》| 回归-Gradient Descent代码展示

 上一篇笔记写了LeeML的回归部分,此处为gradient descent的代码展示,并写了较为详细的注释

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt

x_data = [338., 333., 328., 207., 226., 25., 179., 60., 208., 606.]
y_data = [640., 633., 619., 393., 428., 27., 193., 66., 226., 1591.]
#y_data = b+w*x_data

x = np.arange(-200, -100, 1)   #bias,返回一个有终点和起点的固定步长的排列
y = np.arange(-5, 5, 0.1)      #weight
Z = np.zeros((len(x),len(y)))
# array([[0., 0., 0., ..., 0., 0., 0.],
#        [0., 0., 0., ..., 0., 0., 0.],
#        [0., 0., 0., ..., 0., 0., 0.],
#        ...,
#        [0., 0., 0., ..., 0., 0., 0.],
#        [0., 0., 0., ..., 0., 0., 0.],
#        [0., 0., 0., ..., 0., 0., 0.]])
X, Y = np.meshgrid(x, y)       #生成网格点坐标矩阵
# X = array([[-200, -199, -198, ..., -103, -102, -101],
#          [-200, -199, -198, ..., -103, -102, -101],
#          [-200, -199, -198, ..., -103, -102, -101],
#          ...,
#          [-200, -199, -198, ..., -103, -102, -101],
#          [-200, -199, -198, ..., -103, -102, -101],
#          [-200, -199, -198, ..., -103, -102, -101]])
for i in range(len(x)):        #i,j都在(0, 100)里
    for j in range(len(y)):
        b = x[i]
        w = y[j]
        Z[j][i] = 0
        for n in range(len(x_data)):           #n在(0, 10)里
            Z[j][i] = Z[j][i] + (y_data[n]-b-w*x_data[n])**2
            #每一个b和w求出的是10个误差的和
        Z[j][i] = Z[j][i]/len(x_data)          #求每组b和w的平均误差

b = -120        #initial b
w = -4          #initial w
lr = 0.0000001  #learning rate
iteration = 100000
#store initial values for plotting
b_history = [b]
w_history = [w]

for i in range(iteration):
    b_grad = 0.0
    w_grad = 0.0
    for n in range(len(x_data)):   #计算每一个b和w对loss的偏微分
        b_grad = b_grad-2.0*(y_data[n]-b-w*x_data[n])*1.0
        w_grad = w_grad-2.0*(y_data[n]-b-w*x_data[n])*x_data[n]
    #update parameters
    b = b-lr*b_grad
    w = w-lr*w_grad
    #store parameters for plotting
    b_history.append(b)            #在被选元素的结尾插入指定内容
    w_history.append(w)

#绘制等高线
plt.contourf(x, y, Z,   #x,y数据需为mesh网格状数据,等高线的显示是在网格的基础上添加高度值
             50,        #绘制轮廓的高度值
             alpha=0.5, #混合值,介于0(透明)和1(不透明)之间
             cmap=plt.get_cmap('jet'))
plt.plot([-188.4], [2.67], 'x',
         ms=12, markeredgewidth=3,     #分别代表指定点的长度和宽度
         color='orange')
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$', fontsize=16)
plt.ylabel(r'$w$', fontsize=16)
plt.show()
#不同颜色代表不同参数下的loss,橙色x所处区域附近是loss最低的位置
#当前迭代的情况显示,从初始值(黑线最下方)一直向上迭代到某处向左转
#但迭代100000次后依然离最佳点有很大距离,说明learning rate太小


lr = 0.000001  #将learning rate扩大10倍
for i in range(iteration):
    b_grad = 0.0
    w_grad = 0.0
    for n in range(len(x_data)):
        b_grad = b_grad-2.0*(y_data[n]-b-w*x_data[n])*1.0
        w_grad = w_grad-2.0*(y_data[n]-b-w*x_data[n])*x_data[n]
    b = b-lr*b_grad
    w = w-lr*w_grad
    b_history.append(b)
    w_history.append(w)

plt.contourf(x, y, Z, 50, alpha=0.5, cmap=plt.get_cmap('jet'))
plt.plot([-188.4], [2.67], 'x', ms=12, markeredgewidth=3, color='orange')
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$', fontsize=16)
plt.ylabel(r'$w$', fontsize=16)
plt.show()
#此时效果比之前好了很多,但发生了剧烈的震荡
#将lr再扩大就会发现震荡过于剧烈,loss已无法在图中显示


#考虑将lr变成特制化的learning rate
lr = 1      #调节不同的lr参数可以获得不同的曲线长度
lr_b = 0
lr_w = 0
for i in range(iteration):
    b_grad = 0.0
    w_grad = 0.0
    for n in range(len(x_data)):
        b_grad = b_grad-2.0*(y_data[n]-b-w*x_data[n])*1.0
        w_grad = w_grad-2.0*(y_data[n]-b-w*x_data[n])*x_data[n]
    
    lr_b = lr_b + b_grad**2
    lr_w = lr_w + w_grad**2
    b = b-lr/np.sqrt(lr_b)*b_grad    #对b、w定制化学习率lr,采用Adagard
    w = w-lr/np.sqrt(lr_w)*w_grad
    b_history.append(b)
    w_history.append(w)

plt.contourf(x, y, Z, 50, alpha=0.5, cmap=plt.get_cmap('jet'))
plt.plot([-188.4], [2.67], 'x', ms=12, markeredgewidth=3, color='orange')
plt.plot(b_history, w_history, 'o-', ms=3, lw=1.5, color='black')
plt.xlim(-200, -100)
plt.ylim(-5, 5)
plt.xlabel(r'$b$', fontsize=16)
plt.ylabel(r'$w$', fontsize=16)
plt.show()

但是其中关于坐标矩阵这一部分查了相关资料,还是有一点不太理解。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值