版权归属:
- https://blog.youkuaiyun.com/halchan
- chanhal@outlook.com
更多关注:
- https://github.com/chanhal
- https://www.zhihu.com/people/chanhal
Introduction
在写pytorch代码时,发现并不是所有的写法都能将模块和参数注册到网络中,不信,看下面的代码:
class net(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.linears = [nn.Linear(10,10) for i in range(2)]
def forward(self, x):
for m in self.linears:
x = m(x)
return x
mynet = net()
print(mynet)
返回:
net()
表示并没有注册将模块注册至网络中
而下面的两种写法可以将模块注册至网络中。
# 写法一
class net(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])
def forward(self, x):
for m in self.linears:
x = m(x)
return x
mynet = net()
print(mynet)
返回:
net(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
)
)
<

本文探讨了PyTorch中模块和参数注册到网络中的正确方法。通过示例代码展示,只有直接赋值给Module对象的模块才会被注册。nn.ModuleList作为nn.Module的子类,同样可以注册。模块的定义顺序决定了注册顺序,但在forward方法中执行的顺序由代码安排。
最低0.47元/天 解锁文章
736

被折叠的 条评论
为什么被折叠?



