Pytorch读取模型文件及视频的维度膨胀

本文介绍如何在PyTorch中加载并检查预训练模型,演示了从模型文件读取内容、获取权重并转换为NumPy数组的方法。特别地,文章详细展示了如何对ResNet50预训练模型进行维度膨胀,以适用于视频行为识别任务。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在深度学习中,我们载入预训练模型时,经常要查看预训练模型的内容,以便更好使用,我以pytorch为例总结一下简单的操作

模型文件的读取:

pretrained_dict = torch.load(path)#读取下来的pretrained_dict是一个字典

获取模型保存的键值对

for k, v in pretrained_dict.items():

这里的k一般是层名,v是获取的对应层名的内容(有些模型可能保存的形式不同)

获取每一层的权重,并转化成numpy数组

得到的v是parameter格式,包含权重和是否反向传播

我们可以通过v.data得到权重数值,类型为tensor,然后通过v.data.numpy()转化成array,使用v.requires_grad获取是否反向传播的True/False

将numpy转化为parameter:

new_v =nn.Parameter(torch.tensor(array))

保存模型:

torch.save(dict,new_path)

其中dict是字典名,包含键为层名,值为parameter

在视频行为识别中,常常缺少预训练模型,需要图像的预训练模型进行维度膨胀,因此我参考i3d的处理方式,将resnet50的预训练模型进行维度膨胀,仅供参考:

import torch
from torchvision import models
import numpy as np
import torch.nn as nn

pretrained_dict =torch.load(r'/data/zhengrui/dataset/pretrain/resnet50-19c8e357.pth',map_location='cpu')
new_dict = {}
for k, v in pretrained_dict.items():
	new_dict = {}
	print('layer:', k,'\n')
	print('content:',v.data.numpy().shape,'\n')
	print(v.requires_grad)

   	conv_weight = v.data.numpy()
	if 'conv' in k:
       try:
			n = conv_weight.shape[3]
			print(n)
            v1 = np.tile(np.expand_dims(conv_weight / n, 2), [1, 1, n, 1, 1])
			print(v1.shape)
            new_v = nn.Parameter(torch.tensor(v1))
            new_v.requires_grad = v.requires_grad
       except IndexError:
       print('ERROR:',k)   
	elif k == 'layer1.0.downsample.0.weight' or k ==
'layer2.0.downsample.0.weight' or k == 'layer3.0.downsample.0.weight' or k ==
'layer4.0.downsample.0.weight':

       v1 = v.data.unsqueeze(4)
       new_v = nn.Parameter(v1)
       new_v.requires_grad = v.requires_grad

   else:
        new_v = v

  new_dict[k] = new_v

#print(new_dict)
torch.save(new_dict,r'/data/zhengrui/dataset/pretrain/resnet50-2dto3d.pth')
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值