模型转换时候的detach()函数详细说明

模型不同框架之间进行转换加载的时,采用detach方式进行分离权重;

是否必须:非必须选项

是否有价值:非常有价值,相当于减小了debug成本,对于应用场景比较小的任务中非常具有价值,不需要进行debug了

 



# .detach() 相关参数的问题,其中主要是由于模型权重导出的时候是否保留梯度信息,一般来说还是不需要加载梯度信息的
import torch

# 模拟一个带有梯度跟踪的 PyTorch 卷积权重
conv_weight = torch.randn(4, 3, 3, 3, requires_grad=True)  # [out_ch=4, in_ch=3, kH=3, kW=3]

# 错误操作:未使用 .detach() 直接转换为 NumPy
try:
    kernel_np = conv_weight.permute(2, 3, 1, 0).numpy()
except RuntimeError as e:
    print(f"错误信息:{e}")

 以上如果不适用detach会有如下问题

错误信息:Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

以下的Pytorch模型,导出会出错,因为不使用detach

# 举个例子

# PyTorch 模型中的卷积层(假设处于训练模式)
class PyTorchConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, groups=1)
        self.bn = torch.nn.BatchNorm2d(4)

# 初始化模型并进入训练模式
model = PyTorchConv().train()

# 提取权重对象(此时 conv.weight 的 requires_grad=True)
w = type('Weights', (), {'conv': model.conv, 'bn': model.bn})



# 错误代码:未使用 detach()
try:
    kernel = w.conv.weight.permute(2, 3, 1, 0).numpy()
except RuntimeError as e:
    print(f"错误信息:{e}")

修改成为以下会成功



# 正确代码:使用 detach()
kernel = w.conv.weight.detach().permute(2, 3, 1, 0).numpy()
bias = w.conv.bias.detach().numpy() if hasattr(w.conv, "bias") else None
print("权重转换成功,形状:", kernel.shape)


model.eval()  # 关闭梯度跟踪
w.conv.weight.requires_grad_(False)  # 确保权重无梯度

# 安全提取权重的通用写法
def get_weights(module):
    weights = {}
    for name, param in module.named_parameters():
        weights[name] = param.detach().cpu().numpy()
    return weights

# 导出整个模型的权重
weights_dict = get_weights(model)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值