pytorch中参数访问

部署运行你感兴趣的模型镜像

一次性访问所有参数

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
X = torch.rand(2, 4)
out = net(X)

print(*[(name, param.shape) for name, param in net[0].named_parameters()])

去除print(),[  ]内的叫做列表推导式,用于遍历net[0](网络中的第一个子模块)的所有参数,并收集每个参数的名称和形状

  • net[0]:通过索引访问网络的第一个子模块。
  • .named_parameters():(字面意思:名字和参数),生成模块中所有参数的名称和参数本身的元组。
  • for name, param in net[0].named_parameters():for循环,遍历上述元组(参数名称, 参数)name 是参数的名称,param 是参数张量。
  • (name, param.shape) :这是for循环的返回值,是一个元组包含参数名称和参数形状
  • * 操作符在这里用于将列表解包,使得列表中的每个元素作为独立的参数传递给 print 函数
  • [('weight', torch.Size([8, 4])), ('bias', torch.Size([8]))] # 不带*
    ('weight', torch.Size([8, 4])) ('bias', torch.Size([8]))  # 带*

    使用状态字典(state_dict()),通过键值来访问具体某个参数的数据,比如:

print(net.state_dict()['2.bias'].data)

输出:tensor([0.0575])

 通过 state_dict() 函数可以获得 net 的关于参数的键值,通过 [2.bias].data访问键的具体值(在这个例子里,1.relu没有参数, 2.bias 指的时nn.Linear(8,1).bias)

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值