【凸优化】无约束优化的PyTorch实现(梯度下降法、牛顿下降法)+ torch.optim的使用例

【凸优化】无约束优化的PyTorch实现(梯度下降法、牛顿下降法)

本实验的机器学习算法使用PyTorch的应用程序接口(API,Application Program Interface)实现。具体地,实验使用的torch.optim是一个实现了多种优化算法的包,包含了几种最常用的机器学习优化方法。

目标函数
目标函数(一个典型的凸函数):
  $f_0(x) = \frac{1}{2}x^{T}Ax + b^{T}x$

定义基本方法

import torch

def f0(x):  # replace this f0 with your own objective function
    """Evaluate the example quadratic objective 1/2 * x^T A x + b^T x.

    ``x`` is multiplied as ``x.t() @ A @ x`` against a 2x2 ``A``, so it is
    expected to be a 2x1 column tensor; the result is then a 1x1 tensor.
    """
    quad_matrix = torch.tensor([[3., -1], [-1, 1]])
    linear_coeff = torch.tensor([-2., 0]).unsqueeze(1)
    quadratic_term = 1 / 2 * x.t() @ quad_matrix @ x
    linear_term = linear_coeff.t() @ x
    return quadratic_term + linear_term  # example

def f0_grad(x):  # compute the gradient
    """Return the gradient of ``f0`` with respect to ``x`` via autograd.

    ``create_graph=True`` keeps the gradient itself differentiable so that
    second derivatives can be taken from it afterwards.
    """
    objective = f0(x)
    (gradient,) = torch.autograd.grad(
        objective, x, retain_graph=True, create_graph=True
    )
    return gradient

def f0_Hessian(x):  # compute the Hessian matrix, used by Newton's method
    """Assemble the Hessian of ``f0`` at ``x`` one column at a time.

    Each entry of the gradient returned by ``f0_grad`` is differentiated
    again with respect to ``x``; the resulting tensors are concatenated
    along dimension 1.
    """
    # The leading empty 1-D tensor is ignored by torch.cat; it mirrors the
    # seed value the concatenation starts from.
    columns = [torch.tensor([])]
    for component in f0_grad(x):
        (second_derivative,) = torch.autograd.grad(component, x, retain_graph=True)
        columns.append(second_derivative)
    return torch.cat(columns, 1)

梯度下降法

  1. 梯度下降法的步骤是:
    在这里插入图片描述
  2. 代码部分
import matplotlib.pyplot as plt
import pandas
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import Variable

def main():
    """Minimise ``f0`` by gradient descent with backtracking line search.

    Starts from the point (-2, 4), prints every iterate, and finally plots
    the visited points (x1, x2, f0(x)) as a 3-D curve.
    """
    alpha = torch.tensor([0.25])  # sufficient-decrease fraction of the line search
    beta = torch.tensor([0.5])  # factor by which the step size is shrunk

    ETA = 0.01  # iterate while the gradient norm is above this threshold
    x = torch.tensor([[-2.], [4.]], requires_grad=True)

    count = 0
    x_hist = []
    y_hist = []
    z_hist = []

    def record(point):
        # Store one iterate so the trajectory can be plotted afterwards.
        x_hist.append(point[0].item())
        y_hist.append(point[1].item())
        z_hist.append(f0(point).item())

    plt.rcParams['legend.fontsize'] = 10
    fig = plt.figure(figsize=(4, 4), dpi=200)
    ax = fig.add_subplot(projection='3d')

    record(x)
    grad = f0_grad(x)
    while torch.norm(grad) > ETA:
        delta_x = -grad
        # Directional derivative and current value are loop invariants of the
        # backtracking search, so evaluate them once per outer iteration.
        slope = grad.t() @ delta_x
        fx = f0(x)
        t = torch.tensor([1.])
        while f0(x + t.mul(delta_x)) > (fx + alpha.mul(t).mul(slope)):
            t = beta.mul(t)
        # Detach so the autograd graph does not keep growing across iterations.
        x = (x + t.mul(delta_x)).detach().requires_grad_(True)
        count = count + 1
        record(x)
        print('iter ' + str(count) + ' : ' + str(x.tolist()))
        grad = f0_grad(x)
    print('The optimal solution is ' + str(x.tolist()))
    print('After ' + str(count) + ' iterations')

    tmp = pandas.DataFrame({'X': x_hist, 'Y': y_hist, 'Z': z_hist})
    ax.plot(tmp.X, tmp.Y, tmp.Z, label='gradient descent curve')
    ax.legend()

    plt.show()


# Call main() only when this snippet is executed as a script, not on import.
if __name__ == '__main__':
    main()
  3. 运行结果
    梯度下降法

牛顿下降法

  1. 牛顿下降法的步骤是:
    在这里插入图片描述

  2. 代码部分

import matplotlib.pyplot as plt
import pandas
from mpl_toolkits.mplot3d import Axes3D
from torch.autograd import Variable


def n_step(x):  # compute the Newton step
    """Return the Newton step ``-H(x)^{-1} g(x)`` for ``f0`` at ``x``.

    The linear system ``H d = g`` is solved directly instead of forming the
    inverse of the Hessian, which is faster and numerically more stable.
    """
    return -torch.linalg.solve(f0_Hessian(x), f0_grad(x))

def n_decrement(x):  # compute the decrement
    """Return ``g(x)^T H(x)^{-1} g(x)`` for ``f0`` at ``x``.

    The gradient is evaluated once and reused, and ``H^{-1} g`` is obtained
    with a linear solve rather than an explicit matrix inverse.
    """
    grad = f0_grad(x)
    return grad.t() @ torch.linalg.solve(f0_Hessian(x), grad)

def main():
    """Minimise ``f0`` by Newton's method with backtracking line search.

    Starts from the point (-2, 4), prints every iterate, and finally plots
    the visited points (x1, x2, f0(x)) as a 3-D curve.
    """
    alpha = torch.tensor([0.25])  # sufficient-decrease fraction of the line search
    beta = torch.tensor([0.5])  # factor by which the step size is shrunk

    eta = torch.tensor([0.01])  # iterate while n_decrement(x) is above this value

    x = torch.tensor([[-2.], [4.]], requires_grad=True)

    count = 0
    x_hist = []
    y_hist = []
    z_hist = []

    def record(point):
        # Store one iterate so the trajectory can be plotted afterwards.
        x_hist.append(point[0].item())
        y_hist.append(point[1].item())
        z_hist.append(f0(point).item())

    plt.rcParams['legend.fontsize'] = 10
    fig = plt.figure(figsize=(4, 4), dpi=200)
    ax = fig.add_subplot(projection='3d')

    record(x)
    newton_decrement = n_decrement(x)

    while newton_decrement > eta:
        newton_step = n_step(x)
        fx = f0(x)  # invariant during the backtracking search below
        t = torch.tensor([1.])
        while f0(x + t.mul(newton_step)) > (fx - alpha.mul(t).mul(newton_decrement)):
            t = beta.mul(t)

        # Detach so the autograd graph does not keep growing across iterations.
        x = (x + t.mul(newton_step)).detach().requires_grad_(True)

        count = count + 1
        record(x)
        print('iter ' + str(count) + ' : ' + str(x.tolist()))

        # Re-evaluate at the new iterate so the stopping test is not made
        # against the previous point's decrement.
        newton_decrement = n_decrement(x)

    print('The optimal solution is ' + str(x.tolist()))
    print('After ' + str(count) + ' iterations')

    tmp = pandas.DataFrame({'X': x_hist, 'Y': y_hist, 'Z': z_hist})

    ax.plot(tmp.X, tmp.Y, tmp.Z, label='Newton descent curve')
    ax.legend()

    plt.show()


# Call main() only when this snippet is executed as a script, not on import.
if __name__ == '__main__':
    main()
  3. 运行结果
    牛顿下降法

附加题:使用torch自带的SGD、RMSProp、Rprop实现无约束优化

  1. 代码
def main():
    """Minimise ``f0`` with a ``torch.optim`` optimizer and print each iterate.

    Runs ``epoch`` update steps starting from the point (-2, 4). Gradients
    are cleared before every backward pass so that each step uses only the
    gradient at the current point.
    """
    epoch = 50

    # Create the parameter as a leaf tensor that already tracks gradients.
    x = torch.tensor([[-2.], [4.]], requires_grad=True)

    optimizer = torch.optim.SGD([x, ], lr=0.1, momentum=0)
    # Alternative optimizers; uncomment the one you want to try:
    # optimizer = torch.optim.RMSprop([x,],lr=0.1)
    # optimizer = torch.optim.Rprop([x,],lr=0.01)

    for count in range(1, epoch + 1):
        optimizer.zero_grad()  # backward() accumulates into x.grad otherwise
        f = f0(x)
        f.backward()
        optimizer.step()
        print('iter ' + str(count) + ' : ' + str(x.tolist()))

# Call main() only when this snippet is executed as a script, not on import.
if __name__ == '__main__':
    main()
  2. 运行结果
    在这里插入图片描述

总结

几种优化法的可视化
几种优化方法的可视化

这里的动图参考了另外一篇博客 stackoverflow
在这里插入图片描述

如果我的博客对你有帮助,欢迎点赞收藏
欢迎转载,转载请注明出处

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值