nn.linear()函数

本文介绍了如何使用PyTorch实现一个简单的线性模型LinearFC,探讨了nn.Linear的作用,并解释了训练时调用Net.train()的原因。重点讲解了模型参数初始化及dropout在训练过程中的运用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFC(nn.Module):

    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, input):
        out = self.fc(input)
        return out

Net = LinearFC()
x = torch.randint(10, (2, 3)).float()  # 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据
Net.train()
output = Net(x)
print(output)

# train the Net

创建了一个最简单的LinearFC模型,里面有一个线性函数nn.Linear(3, 2),线性变换公式为:y=xWT+by=x W^T + by=xWT+b

通过Debug,一步一步查看运行情况:

在这里插入图片描述

当前这一步可以看到模型给我们随机初始化了权重W2×3W_{2 \times 3}W2×3和偏置b2×3b_{2 \times 3}b2×3,为什么权重WWW的shape是2×32\times32×3,因为公式里需要转置。

xxx随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据。
在这里插入图片描述
可以看出使用模型算出来的output,与手动使用公式算出来的结果一致。
在这里插入图片描述

Net.train()的作用

当网络中有 dropout,Batch Normalization 的时候。训练的要记得 Net.train(), 测试 要记得 Net.eval()。

在训练模型时会在前面加上:

Net.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值