一次性访问所有参数
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(2, 4)
out = net(X)
print(*[(name, param.shape) for name, param in net[0].named_parameters()])
去除print(),[ ]内的叫做列表推导式,用于遍历net[0](网络中的第一个子模块)的所有参数,并收集每个参数的名称和形状
net[0]:通过索引访问网络的第一个子模块。.named_parameters():(字面意思:名字和参数),生成模块中所有参数的名称和参数本身的元组。for name, param in net[0].named_parameters():for循环,遍历上述元组(参数名称, 参数)name是参数的名称,param是参数张量。- (name, param.shape) :这是for循环的返回值,是一个元组包含参数名称和参数形状
*操作符在这里用于将列表解包,使得列表中的每个元素作为独立的参数传递给print函数-
[('weight', torch.Size([8, 4])), ('bias', torch.Size([8]))] # 不带* ('weight', torch.Size([8, 4])) ('bias', torch.Size([8])) # 带*使用状态字典(state_dict()),通过键值来访问具体某个参数的数据,比如:
print(net.state_dict()['2.bias'].data)
输出:tensor([0.0575])
通过 state_dict() 函数可以获得 net 的关于参数的键值,通过 [2.bias].data访问键的具体值(在这个例子里,1.relu没有参数, 2.bias 指的时nn.Linear(8,1).bias)
658





