torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
-
in_features – 输入样本的大小
-
out_features – 输出样本的大小
-
bias – 如果设置为False,该层将不会学习偏移量。默认值:True
nn.Linear()的作用是对传入的数据进行线性变换。
![]()
传入的数据是x,经过变换后得到y,A是权重,b是偏移量。
查看官方文档,最开始生成的权重矩阵A的shape为(out_features, in_features)
![]()
m = nn.Linear(20,30)
# 查看生成的权重矩阵的shape,会发现权重的shape为30*20
print(m.weight.size())
# 生成一个128*20的矩阵
input = torch.randn(128, 20)
# 查看生成矩阵的大小
print(input.size())
# 传入input(shape为128*20)后,input和权重矩阵A转置后的矩阵
# 相乘(转置后A的shape变为20*30),最终得到的结果shape为128*30
print(m(input).size())
程序运行生成的结果

本文详细介绍了PyTorch中nn.Linear函数的工作原理,包括参数in_features、out_features及bias的作用,并通过实例演示了如何创建和使用线性变换层。重点展示了权重矩阵的生成和输入数据的处理过程。
13万+

被折叠的 条评论
为什么被折叠?



