伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP](2)——代码分析
ASP整个模块的结果如下:
.
├── COPYRIGHT
├── README.md
├── __init__.py
├── asp.py
├── permutation_lib.py
├── permutation_search_kernels
│ ├── CUDA_kernels
│ │ └── permutation_search_kernels.cu
│ ├── __init__.py
│ ├── call_permutation_search_kernels.py
│ ├── channel_swap.py
│ ├── exhaustive_search.py
│ └── permutation_utilities.py
├── permutation_tests
│ ├── README.md
│ ├── ablation_studies.sh
│ ├── permutation_test.py
│ ├── runtime_table.sh
│ └── unstructured_study.sh
├── sparse_masklib.py
└── test
├── checkpointing_test_part1.py
├── checkpointing_test_part2.py
├── checkpointing_test_reference.py
├── test_permutation_application.py
└── toy_problem.py
共包含三个主要文件:
- asp.py
- permutation_lib. py
- sparse_masklib.py
以及三个主要目录
- permutation_search_kernels
- permutation_tests
- test
其中目录test用于展示一些具体的实例,目录permutation_tests是一个单独的模块,用于复现论文中的实验,这两个目录不用关注。如果不需要使用通道置换算法的话,目录permutation_search_kernels和文件permutation_lib.py也不需要关注。
因此,ASP源代码中最主要的还是asp.py文件和sparse_masklib.py文件,如果需要使用通道置换算法的话,可以在此基础上探询一下permutation_search相关的算法和代码实现。
asp.py文件
ASP类
asp.py中主定义了ASP类,其成员函数定义了init_model_for_pruning、init_optimizer_for_pruning、compute_sparse_masks、already_init_asp_model、restore_pruned_weights、is_sparsity_enabled、prune_trained_model、set_permutation_saving_params八个静态方法,分别用于对模型、优化器进行稀疏初始化、计算稀疏mask、检查模型是否已经进行稀疏初始化,检查模型是否进行了稀疏化,恢复模型的权重以及为通道设置算法设置参数。其中最主要的是prune_trained_model及其调用的init_model_for_pruning、init_optimizer_for_pruning、compute_sparse_masks三个方法。
成员变量
__model = None # 待处理的模型
__verbosity = 0 # 输出信息的详细程度
__optimizer = None # 待处理的优化器
__sparse_parameters = [] # 用于保存稀疏参数信息
__calculate_mask = None # 一个函数指针,能够通过传入的tensor的shape为tensor生成相应的mask
__allow_permutation = True # 是否需要开启通道置换算法
__all_parameters = [] # 用于保存模型中所有参数的信息
__save_permutation_graph = False # 是否保存通道置换的graph
__permutation_output_dir = '' # 通道置换信息的输出目录
成员函数
- prune_trained_model
prune_trained_model是用法介绍中需要在模型训练文件中需要添加的两行代码之一,也是ASP模块的使用入口:
@classmethod
def prune_trained_model(cls, model, optimizer):
# add mask buffers to model (init_model_for_pruning), augment optimizer (init_optimizer_for_pruning) and compute masks (compute_sparse_masks)
cls.init_model_for_pruning(model, mask_calculator="m4n2_1d", verbosity=2, whitelist=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.MultiheadAttention], allow_recompute_mask=False)
cls.init_optimizer_for_pruning(optimizer)
cls.compute_sparse_masks()
prune_trained_model方法接受两个参数,分别是需要训练后的模型和优化器。
该方法中又分别调用了三个方法:首先使用init_model_for_pruning,init_optimizer_for_pruning方法分别对模型和优化器中的权重进行分析和初始化准备工作(为模型添加mask buffer),并通过compute_sparse_masks方法为每个权重计算生成对应的稀疏mask。
- init_model_for_pruning
def init_model_for_pruning(cls, model, mask_calculator="m4n2_1d",
verbosity=3,
whitelist=[torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.MultiheadAttention],
allowed_layer_names=None,
disallowed_layer_names=[],
allow_recompute_mask=False,
custom_layer_dict={
},
allow_permutation=True):
assert (cls.__model is None), "ASP has been initialized already."
cls.__model = model
cls.__verbosity = verbosity
cls.__allow_permutation = allow_permutation
if isinstance(mask_calculator, str):
def create_mask_from_pattern(param):
return create_mask(param, mask_calculator).bool()
cls.__calculate_mask = create_mask_from_pattern
else:
cls.__calculate_mask = mask_calculator #user defined function
# function to extract variables that will be sparsified.
# idea is that you will add one of these functions for each module type that can be sparsified.
if torchvision_imported:
print("[ASP] torchvision is imported, can work with the MaskRCNN/KeypointRCNN from torchvision.")
torchvision_version = str(torchvision.__version__)
torchvision_version_major = int(torchvision_version.split('.')[0])
torchvision_version_minor = int(torchvision_version.split('.')[1])
if torchvision_version_major == 0 and torchvision_version_minor < 12:
sparse_parameter_list = {
torch.nn.Linear: ['weight'], torch.nn.Conv1d: ['weight'], torch.nn.Conv2d: ['weight'], torch.nn.Conv3d: ['weight'], torch.nn.modules.linear.NonDynamicallyQuantizableLinear: ['weight'], torch.nn.MultiheadAttention: ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'], torchvision.ops.misc.Conv2d: ['weight']}

本文对英伟达结构化剪枝工具Nvidia Apex Automatic Sparsity [ASP]进行代码分析。介绍了ASP模块包含的文件和目录,重点分析了asp.py文件中ASP类的成员变量和函数,如prune_trained_model、init_model_for_pruning等方法的功能和实现,涉及模型和优化器的稀疏初始化、mask计算等。
最低0.47元/天 解锁文章
2649

被折叠的 条评论
为什么被折叠?



