pytorch基础入门一: 训练原始的线性模型

pytorch基础入门一: 训练原始的线性模型

Backgroud 背景说明

由于公司业务需要(基于单张全景图的三维场景重建),从五一开始拾起了神经网络的学习;嗯!这个月的flag就是五月份完成业务模型的技术验证;
搜索到的几篇文章,都是基于PyTorch搭建的模型,那理所当然也就踏入PyTorch的世界;按照以往的学习习惯,一切从官网文档开始入手;
官网入门指南 第一节讲tensor的主要操作,这一部分理解起来问题不大,就是一些细节需要在具体使用中熟悉;
第二节 AUTOGRAD 自动求导,也能明白个大概 ;到第三节的时候,直接扑面而来的新对象

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 6 * 6, 120)  # 6*6 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

简直 懵逼树上懵逼果,懵逼树下你和我

那只能暂时放下tutorials,我胡汉三还会再回来的

然后,就找官方出版的开源书籍 Deep-Learning-with-PyTorch
第一章 介绍PyTorch概况,一眼带过
第二章 讲述 tensor 的底层实现,果然比 tutorials 详尽充实多了,看来这本书适合概念建立;
第三章 介绍输入数据的处理,也暂时跳过
第四章 神经网络学习机制的介绍,看到这就知道这章就是适合我的

所以本篇文章主要记录, Deep-Learning-with-PyTorch 书中第四章内容,包括以下几点:

  • 定义一个线性模型
  • 定义线性模型的损失函数(代价函数)
  • 实现线性模型的梯度回归
  • 学习速率的调试
  • 特征值的缩放

定义线性模型

假设的问题如下:

现在有两种类型的温度计,一个是已知单位,另一个是未知单位的温度计,需要拟合两种温度计的换算公式

作最简单的假设,换算公式为线性的:

# 对应的线性模型为
# t_c = w * t_u + b

对应的代码如下:

"""
PyTorch 基础入门
线性回归参数估计
问题: 华氏温度转换
"""
import torch

# 定义输入数据
t_c = [0.5,  14.0, 15.0, 28.0, 11.0,  8.0,  3.0, -4.0,  6.0, 13.0, 21.0]
t_u = [35.7, 55.9, 58.2, 81.9, 56.3, 48.9, 33.9, 21.8, 48.4, 60.4, 68.4]

t_c = torch.tensor(t_c)
t_u = torch.tensor(t_u)

# 对应的线性模型为
# t_c = w * t_u + b

def model(t_u, w, b):
    return w*t_u + b

定义损失函数

损失函数是用来判断模型参数拟合实验数据优劣的一个方法,所以损失函数也会存在多种多样,比如;
针对我们的问题,可以列出两种损失函数:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fvdrfn70-1589179469322)(https://cdn.jsdelivr.net/gh/austinxishou/pics/feimao/20200511134915.png)]

左边为 预估数据与实验数据的距离
右边为 预估数据与实验数据的距离平方

如果不绕弯子,直接给出结论的话,右边损失函数会优于左

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值