英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

本文对英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]进行代码分析。介绍了ASP模块包含的文件和目录,重点分析了asp.py文件中ASP类的成员变量和函数,如prune_trained_model、init_model_for_pruning等方法的功能和实现,涉及模型和优化器的稀疏初始化、mask计算等。

伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析

ASP整个模块的结果如下:

.
├── COPYRIGHT
├── README.md
├── __init__.py
├── asp.py
├── permutation_lib.py
├── permutation_search_kernels
│   ├── CUDA_kernels
│   │   └── permutation_search_kernels.cu
│   ├── __init__.py
│   ├── call_permutation_search_kernels.py
│   ├── channel_swap.py
│   ├── exhaustive_search.py
│   └── permutation_utilities.py
├── permutation_tests
│   ├── README.md
│   ├── ablation_studies.sh
│   ├── permutation_test.py
│   ├── runtime_table.sh
│   └── unstructured_study.sh
├── sparse_masklib.py
└── test
    ├── checkpointing_test_part1.py
    ├── checkpointing_test_part2.py
    ├── checkpointing_test_reference.py
    ├── test_permutation_application.py
    └── toy_problem.py

共包含三个主要文件:

  • asp.py
  • permutation_lib. py
  • sparse_masklib.py

以及三个主要目录

  • permutation_search_kernels
  • permutation_tests
  • test

其中目录test用于展示一些具体的实例,目录permutation_tests是一个单独的模块,用于复现论文中的实验,这两个目录不用关注。如果不需要使用通道置换算法的话,目录permutation_search_kernels和文件permutation_lib.py也不需要关注。

因此,ASP源代码中最主要的还是asp.py文件和sparse_masklib.py文件,如果需要使用通道置换算法的话,可以在此基础上探询一下permutation_search相关的算法和代码实现。

asp.py文件
ASP类

asp.py中主定义了ASP类,其成员函数定义了init_model_for_pruninginit_optimizer_for_pruningcompute_sparse_masksalready_init_asp_modelrestore_pruned_weightsis_sparsity_enabledprune_trained_modelset_permutation_saving_params八个静态方法,分别用于对模型、优化器进行稀疏初始化、计算稀疏mask、检查模型是否已经进行稀疏初始化,检查模型是否进行了稀疏化,恢复模型的权重以及为通道设置算法设置参数。其中最主要的是prune_trained_model及其调用的init_model_for_pruninginit_optimizer_for_pruningcompute_sparse_masks三个方法。

成员变量
    __model = None                           	# 待处理的模型
    __verbosity = 0							 	# 输出信息的详细程度
     __optimizer = None							# 待处理的优化器
    __sparse_parameters = []					# 用于保存稀疏参数信息
    __calculate_mask = None						# 一个函数指针,能够通过传入的tensor的shape为tensor生成相应的mask
    __allow_permutation = True					# 是否需要开启通道置换算法
    __all_parameters = []						# 用于保存模型中所有参数的信息
    __save_permutation_graph = False			# 是否保存通道置换的graph
    __permutation_output_dir = ''				# 通道置换信息的输出目录
成员函数
  • prune_trained_model

prune_trained_model是用法介绍中需要在模型训练文件中需要添加的两行代码之一,也是ASP模块的使用入口:

@classmethod
def prune_trained_model(cls, model, optimizer):
    # add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
    cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
    cls.init_optimizer_for_pruning(optimizer)
    cls.compute_sparse_masks()

prune_trained_model方法接受两个参数,分别是需要训练后的模型和优化器。

该方法中又分别调用了三个方法:首先使用init_model_for_pruninginit_optimizer_for_pruning方法分别对模型和优化器中的权重进行分析和初始化准备工作(为模型添加mask buffer),并通过compute_sparse_masks方法为每个权重计算生成对应的稀疏mask。

  • init_model_for_pruning
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
         verbosity=3,
         whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention], 
         allowed_layer_names=None, 
         disallowed_layer_names=[],
         allow_recompute_mask=False, 
         custom_layer_dict={
   
   },
         allow_permutation=True):

    assert (cls.__model is None), "ASP has been initialized already."
    cls.__model = model
    cls.__verbosity = verbosity
    cls.__allow_permutation = allow_permutation

    if isinstance(mask_calculator, str):
        def create_mask_from_pattern(param):
            return create_mask(param, mask_calculator).bool()
        cls.__calculate_mask = create_mask_from_pattern
    else:
        cls.__calculate_mask = mask_calculator #user defined function

    # function to extract variables that will be sparsified. 
    # idea is that you will add one of these functions for each module type that can be sparsified.
    if torchvision_imported:
        print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
        torchvision_version = str(torchvision.__version__)
        torchvision_version_major = int(torchvision_version.split('.')[0])
        torchvision_version_minor = int(torchvision_version.split('.')[1])
        if torchvision_version_major == 0 and torchvision_version_minor < 12:
            sparse_parameter_list = {
   
   torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}
   
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值