PyTorch:获取层权重、对特定层注入 hook、提取中间层输出

# Print every entry of the model's state dict (parameter and buffer tensors),
# keyed by its dotted name.
# `.iteritems()` only exists on Python 2 dicts; `.items()` works on the
# OrderedDict returned by `state_dict()` under Python 3 as well.
for k, v in model_2.state_dict().items():
    print("Layer {}".format(k))
    print(v)
# Print the weight tensor of every nn.Linear layer in model_2.
# `modules()` walks all nested sub-modules recursively (including model_2
# itself), so Linear layers inside containers are found too.
linear_layers = (m for m in model_2.modules() if isinstance(m, nn.Linear))
for linear in linear_layers:
    print(linear.weight)
# Load one model's weights into another model (partial weight transfer).
# NOTE(review): this snippet is the body of a model-factory function whose
# `def` line (which would define `pretrained` and `kwargs`) is not shown;
# `return model` is only valid inside that function.
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
    # NOTE(review): hard-coded, user-specific checkpoint path.
    load = torch.load('/home/huangqk/.torch/models/vgg19-dcbb9e9d.pth')
    # Drop the six classifier.* weight/bias entries so they are NOT copied;
    # presumably the new model's classifier differs in shape from the
    # checkpoint's — TODO confirm.
    load_state = {k: v for k, v in load.items() if k not in ['classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias']}
    # Start from the new model's own state dict so the excluded keys keep
    # their freshly-initialized values, then overwrite the kept entries.
    model_state = model.state_dict()
    model_state.update(load_state)
    model.load_state_dict(model_state)
return model
# Register a forward hook on a specific layer.
def hook_layers(model):
    """Register a forward hook on the first sub-module of ``model.features``.

    Each time that layer runs forward, the hook passes the layer's first
    positional input to ``recreate_image`` (defined elsewhere).

    Args:
        model: a module exposing a ``features`` container, e.g. a
            VGG-style ``nn.Sequential``.

    Returns:
        The ``RemovableHandle`` produced by ``register_forward_hook``; call
        ``.remove()`` on it to detach the hook again.

    Raises:
        IndexError: if ``model.features`` has no sub-modules.
    """
    def hook_function(module, inputs, outputs):
        # Forward hooks receive the layer's positional inputs as a tuple.
        recreate_image(inputs[0])

    print(model.features._modules)
    # Public API instead of indexing into the private ``_modules`` dict.
    first_layer = next(iter(model.features.children()), None)
    if first_layer is None:
        raise IndexError("model.features has no sub-modules to hook")
    # Return the handle so the caller can remove the hook later.
    return first_layer.register_forward_hook(hook_function)
# Run the feature extractor one layer at a time.
# `Module.modules()` is recursive and yields the container itself first, so
# iterating `vgg.features.modules()` would run the whole Sequential and then
# every layer a second time (and element 0 of that list would be the
# container). `children()` yields only the direct sub-layers, in order.
modulelist = list(vgg.features.children())

# Full forward pass through every layer.
# `someinput` is a placeholder for the real input tensor.
x = someinput
for layer in modulelist:
    x = layer(x)

# The same pass split in two, keeping the activation after the first 5 layers.
# Restart from the input so this pass does not consume the output above.
x = someinput
for layer in modulelist[:5]:
    x = layer(x)
keep = x
for layer in modulelist[5:]:
    x = layer(x)
# 提取vgg模型的中间层输出
# coding:utf8
import torch
import torch.nn as nn
from torchvision.models import vgg16
from collections import namedtuple


class Vgg16(torch.nn.Module):
    """Truncated pretrained VGG16 that returns four intermediate activations.

    Only the first 23 modules (indices 0-22) of torchvision's
    ``vgg16().features`` are kept. ``forward`` collects the outputs of
    modules 3, 8, 15 and 22, which are relu1_2, relu2_2, relu3_3 and
    relu4_3 respectively.
    """

    # Names of the captured activations, in forward order.
    _TAP_NAMES = ("relu1_2", "relu2_2", "relu3_3", "relu4_3")
    # Indices into ``self.features`` whose outputs are captured.
    _TAP_INDICES = frozenset({3, 8, 15, 22})
    # Built once at class creation instead of on every forward call, so all
    # calls return instances of the same namedtuple type.
    VggOutputs = namedtuple("VggOutputs", _TAP_NAMES)

    def __init__(self):
        super(Vgg16, self).__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=True`` in
        # favour of the ``weights=`` argument.
        features = list(vgg16(pretrained=True).features)[:23]
        # Modules 3, 8, 15 and 22 of ``features`` are relu1_2, relu2_2,
        # relu3_3 and relu4_3 respectively.
        # ``.eval()`` only switches the stack to inference mode; it does not
        # freeze the parameters (they still require grad).
        self.features = nn.ModuleList(features).eval()

    def forward(self, x):
        """Run ``x`` through the truncated feature stack.

        Returns:
            A ``VggOutputs`` namedtuple holding the activations after modules
            3, 8, 15 and 22, in that order.
        """
        captured = []
        for idx, layer in enumerate(self.features):
            x = layer(x)
            if idx in self._TAP_INDICES:
                captured.append(x)
        return self.VggOutputs(*captured)
<think>在PyTorch中,访问模型是常见的操作,例如为了提取中间特征、修改特定或进行模型分析等。以下是几种访问模型的方法及其应用场景。###1.**直接通过模型属性访问**如果模型是通过`nn.Module`类定义的,并且每一都被赋予了一个属性名,那么可以直接通过属性名访问。```pythonimporttorchimporttorch.nnasnn#定义一个简单的模型classSimpleModel(nn.Module):def__init__(self):super().__init__()self.conv1=nn.Conv2d(3,16,3)self.pool=nn.MaxPool2d(2,2)self.fc1=nn.Linear(16*13*13,10)#假设输入图像为28x28,经过卷积和池化后为13x13defforward(self,x):x=self.pool(torch.relu(self.conv1(x)))x=x.view(-1,16*13*13)x=self.fc1(x)returnxmodel=SimpleModel()#直接访问print(model.conv1)#访问卷积print(model.fc1)#访问全连接```###2.**使用`named_children`或`children`方法迭代访问**-`children()`:返回模型直接子模块的迭代器。-`named_children()`:返回直接子模块的迭代器,同时提供模块的名称。```python#遍历直接子模块print("Children:")forchildinmodel.children():print(child)print("\nNamedChildren:")forname,moduleinmodel.named_children():print(f"{name}:{module}")```###3.**使用`named_modules`或`modules`方法递归访问**-`modules()`:返回模型中所有模块(包括当前模块和所有子模块)的迭代器,递归遍历。-`named_modules()`:返回所有模块的迭代器,同时提供模块的名称(包括路径)。```python#递归遍历所有模块print("Allmodules:")formoduleinmodel.modules():print(module)print("\nNamedmodules:")forname,moduleinmodel.named_modules():print(f"{name}:{module}")```###4.**通过索引访问(适用于`nn.Sequential`模型)**如果模型是`nn.Sequential`定义的,则可以通过索引来访问。```python#定义一个Sequential模型seq_model=nn.Sequential(nn.Conv2d(3,16,3),nn.ReLU(),nn.MaxPool2d(2,2),nn.Flatten(),nn.Linear(16*13*13,10))#通过索引访问print(seq_model[0])#第一:卷积print(seq_model[4])#第五:全连接```###5.**通过名称访问(使用`get_submodule`方法,PyTorch1.9+)**PyTorch1.9引入了`get_submodule`方法,可以通过字符串路径获取子模块。```python#通过路径访问子模块conv1=model.get_submodule('conv1')print(conv1)#对于嵌套模块classNestedModel(nn.Module):def__init__(self):super().__init__()self.features=nn.Sequential(nn.Conv2d(3,16,3),nn.ReLU())self.classifier=nn.Linear(16*13*13,10)defforward(self,x):x=self.features(x)x=x.view(-1,16*13*13)x=self.classifier(x)returnxnested_model=NestedModel()#访问嵌套模块中的卷积conv_in_nested=nested_model.get_submodule('features.0')print(conv_in_nested)```###6.**修改模型**有时需要替换模型中的某一。可以通过直接赋值或使用`nn.Module`的`__setattr__`来实现。```python#替换模型中的#将Simple
Model中的ReLU激活函数替换为LeakyReLUclassSimpleModelWithActivation(nn.Module):def__init__(self):super().__init__()self.conv1=nn.Conv2d(3,16,3)self.activation=nn.ReLU()self.pool=nn.MaxPool2d(2,2)self.fc1=nn.Linear(16*13*13,10)defforward(self,x):x=self.pool(self.activation(self.conv1(x)))x=x.view(-1,16*13*13)x=self.fc1(x)returnxmodel_act=SimpleModelWithActivation()#替换激活model_act.activation=nn.LeakyReLU(0.1)print(model_act)```###7.**提取中间层输出(特征提取)**在模型前向传播过程中,我们可能需要获取中间层输出。可以通过以下方法实现:####方法1:修改前向传播函数```pythonclassSimpleModelWithOutputs(SimpleModel):defforward(self,x):#保存中间输出conv1_out=self.conv1(x)relu_out=torch.relu(conv1_out)pool_out=self.pool(relu_out)#如果需要返回中间输出,可以返回多个值returnpool_out,self.fc1(pool_out.view(-1,16*13*13))model_with_outputs=SimpleModelWithOutputs()x=torch.randn(1,3,28,28)intermediate,output=model_with_outputs(x)print(intermediate.shape)#torch.Size([1,16,13,13])```####方法2:使用钩子(Hook)钩子可以在不修改模型定义的情况下获取中间层输出。```python#注册前向钩子activations={}defget_activation(name):defhook(model,input,output):activations[name]=output.detach()returnhookmodel=SimpleModel()model.conv1.register_forward_hook(get_activation('conv1'))model.fc1.register_forward_hook(get_activation('fc1'))x=torch.randn(1,3,28,28)output=model(x)#访问保存的输出print(activations['conv1'].shape)#torch.Size([1,16,26,26])print(activations['fc1'].shape)#torch.Size([1,10])```###8.**根据条件查找特定**例如,查找模型中所有的卷积并替换为新的卷积(比如进行模型量化时替换为量化卷积)。```python#查找所有卷积并替换defreplace_layers(model,layer_type,new_layer_func):forname,moduleinmodel.named_children():ifisinstance(module,layer_type):#替换该setattr(model,name,new_layer_func(module))elifisinstance(module,nn.Module):#递归处理子模块replace_layers(module,layer_type,new_layer_func)#示例:将卷积替换为带权重标准化的卷积model=SimpleModel()print("Beforereplacement:")print(model)replace_layers(model,nn.Conv2d,lambdaconv:nn.utils.weight_norm(conv))print("\nAfterreplacement:")print(model)```###9.**使用`apply`方法遍历所有模块**`apply`方法递归地对每个模块应用一个函数。```python#初始化权重definit_weights(m):ifisinstance(m,nn.Conv2d):nn.init.kaiming_normal_(m.weight
)ifm.biasisnotNone:nn.init.zeros_(m.bias)elifisinstance(m,nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.zeros_(m.bias)model.apply(init_weights)```###10.**访问模型参数**除了访问,有时需要直接访问参数(权重和偏置)。```python#访问特定的参数forname,paraminmodel.named_parameters():if'weight'inname:print(f"{name}:{param.shape}")#获取某一权重conv1_weight=model.conv1.weightprint(conv1_weight.shape)```###总结PyTorch提供了多种灵活的方式来访问和操作模型,包括:-直接属性访问-迭代方法(`children`,`named_children`,`modules`,`named_modules`)-索引访问(针对`Sequential`模型)-钩子机制获取中间输出-递归替换-使用`apply`初始化权重这些方法使得模型的操作和特征提取变得非常方便。---§§相关问题§§:1.如何在PyTorch中实现模型的并行计算(如DataParallel)?2.PyTorch中如何保存和加载模型权重?3.如何实现自定义的PyTorch?4.在PyTorch中如何实现梯度裁剪(gradientclipping)?</think>### PyTorch 模型访问详解 在 PyTorch 中,访问模型是调试、特征提取和模型修改的关键操作。以下是 5 种主要的访问方法,每种方法都有不同的适用场景: #### 1. **直接属性访问(最简单直接)** ```python import torch import torch.nn as nn # 创建简单模型 class MyModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 16, 3) self.pool = nn.MaxPool2d(2, 2) self.fc = nn.Linear(16 * 13 * 13, 10) # 假设输入为224x224 def forward(self, x): x = self.pool(torch.relu(self.conv1(x))) x = torch.flatten(x, 1) x = self.fc(x) return x model = MyModel() # 直接访问 print("直接访问:") print("卷积:", model.conv1) print("全连接权重:", model.fc.weight.shape) ``` #### 2. **使用 `children()` 和 `named_children()`(访问直接子模块)** ```python # 遍历直接子模块 print("\n使用children():") for child in model.children(): print(child) # 带名称访问 print("\n使用named_children():") for name, module in model.named_children(): print(f"{name}: {module}") ``` #### 3. **使用 `modules()` 和 `named_modules()`(递归访问所有)** ```python # 递归访问所有模块(包括容器本身) print("\n使用modules():") for module in model.modules(): print(module.__class__.__name__) # 带完整路径访问 print("\n使用named_modules():") for name, module in model.named_modules(): print(f"{name}: {module}") ``` #### 4. 
**使用 `get_submodule()`(PyTorch 1.9+ 按路径访问)** ```python # 按路径访问特定 conv_layer = model.get_submodule('conv1') print("\n使用get_submodule():", conv_layer) # 访问嵌套模型中的 class NestedModel(nn.Module): def __init__(self): super().__init__() self.features = nn.Sequential( nn.Conv2d(3, 16, 3), nn.ReLU(), nn.MaxPool2d(2, 2) ) self.classifier = nn.Linear(16 * 13 * 13, 10) model_nested = NestedModel() pool_layer = model_nested.get_submodule('features.2') print("嵌套模型访问:", pool_layer) ``` #### 5. **使用钩子(Hook)访问中间层输出** ```python # 注册钩子捕获中间输出 activations = {} def get_activation(name): def hook(model, input, output): activations[name] = output.detach() return hook # 注册钩子 model.conv1.register_forward_hook(get_activation('conv1')) model.fc.register_forward_hook(get_activation('fc')) # 前向传播触发钩子 input_tensor = torch.randn(1, 3, 224, 224) output = model(input_tensor) print("\n钩子捕获的中间输出:") for name, tensor in activations.items(): print(f"{name}: {tensor.shape}") ``` ### 实际应用场景 #### 场景1:特征提取提取中间层输出) ```python class FeatureExtractor(nn.Module): def __init__(self, model, layer_names): super().__init__() self.model = model self.layer_names = layer_names self.features = {} # 注册钩子 for name in layer_names: layer = model.get_submodule(name) layer.register_forward_hook(self.save_features(name)) def save_features(self, name): def hook(module, input, output): self.features[name] = output.detach() return hook def forward(self, x): _ = self.model(x) return self.features # 使用示例 extractor = FeatureExtractor(model, ['conv1', 'fc']) features = extractor(input_tensor) print("\n提取的特征:") for name, feat in features.items(): print(f"{name}: {feat.shape}") ``` #### 场景2:模型修改(冻结特定) ```python # 冻结所有卷积 def freeze_conv_layers(model): for name, module in model.named_modules(): if isinstance(module, nn.Conv2d): print(f"冻结: {name}") for param in module.parameters(): param.requires_grad = False freeze_conv_layers(model) # 检查冻结状态 print("\n参数冻结状态:") for name, param in model.named_parameters(): print(f"{name}: 
需要梯度={param.requires_grad}")
```

#### 场景3:模型可视化

```python
import matplotlib.pyplot as plt

def visualize_filters(layer):
    if not isinstance(layer, nn.Conv2d):
        return
    filters = layer.weight.detach().cpu()
    num_filters = filters.size(0)
    plt.figure(figsize=(12, 6))
    for i in range(min(16, num_filters)):
        plt.subplot(2, 8, i+1)
        # 归一化显示
        filter_img = filters[i].mean(0)  # 对输入通道取平均
        plt.imshow(filter_img, cmap='viridis')
        plt.axis('off')
    plt.suptitle(f'卷积核可视化 (共{num_filters}个)')
    plt.tight_layout()
    plt.show()

# 可视化第一层卷积核
visualize_filters(model.conv1)
```

### 高级技巧:动态访问和修改

#### 1. 按类型查找层

```python
def find_layers_by_type(model, layer_type):
    return [name for name, module in model.named_modules()
            if isinstance(module, layer_type)]

# 查找所有线性层
linear_layers = find_layers_by_type(model, nn.Linear)
print("\n所有线性层:", linear_layers)
```

#### 2. 替换模型层

```python
# 替换激活函数
def replace_activations(model, old_type, new_layer):
    for name, module in model.named_modules():
        if isinstance(module, old_type):
            # 获取父模块和子模块名称
            path = name.split('.')
            parent = model
            for p in path[:-1]:
                parent = parent.get_submodule(p)
            # 替换层
            setattr(parent, path[-1], new_layer)
            print(f"已替换: {name}")

# 将所有ReLU替换为LeakyReLU
replace_activations(model, nn.ReLU, nn.LeakyReLU(0.1))
```

#### 3. 添加新层

```python
# 在特定位置添加层
class ModelWithDropout(nn.Module):
    def __init__(self, original_model):
        super().__init__()
        self.features = original_model.features
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            original_model.classifier
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# 使用示例
# 注意:original_model 需要具有 features 和 classifier 属性(例如 torchvision 的 VGG),
# 上文定义的 MyModel 并不满足这一点。
model_with_dropout = ModelWithDropout(model)
print("\n添加Dropout后的模型结构:")
print(model_with_dropout)
```

### 总结:模型层访问方法对比

| 方法 | 优点 | 缺点 | 适用场景 |
|------|------|------|----------|
| **直接属性访问** | 简单直接 | 无法访问嵌套层 | 简单模型、已知层名 |
| **children()** | 访问直接子模块 | 不递归 | 顶层模块遍历 |
| **named_modules()** | 递归访问所有层 | 包含容器本身 | 完整模型分析 |
| **get_submodule()** | 精确路径访问 | 需要知道完整路径 | 精确操作 |
| **钩子(Hook)** | 捕获中间输出 | 增加计算开销 | 特征提取、调试 |

### 最佳实践建议:

1. **调试时**:使用 `named_modules()` 查看完整结构
2. **特征提取**:使用钩子捕获中间输出
3. **模型修改**:结合 `get_submodule()` 和直接赋值
4. **大型模型**:使用按类型查找提高效率
5. **可视化**:结合权重访问和 matplotlib

---
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值