# For background on PyTorch quantization, see https://blog.youkuaiyun.com/zlgahu/article/details/104662203/
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import time
import sys
import torch.quantization
import requests
import warnings
# QuantStub/DeQuantStub mark where tensors enter/leave the quantized region of the model
from torch.quantization import QuantStub,DeQuantStub
#确保通道数能被divisor整除
def _make_divisible(v,divisor,min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value,int(v + divisor / 2) // divisor * divisor)
#防止损失过多通道
if new_v < 0.9 * v:
new_v += divisor
return new_v
#定义几个网络的组合层
class ConvBNReLU(nn.Sequential):
def __init__(self,in_planes,out_planes,kernel_size=3,stride=1,groups=1):
padding = (kernel_size - 1) // 2
super().__init__(
nn.Conv2d(in_planes,out_planes,kernel_size,stride,padding,groups=groups,bias=False),
nn.BatchNorm2d(out_planes,momentum=0.1),
nn.ReLU(inplace=False)
)
#定义残差块
class InvertedResidual(nn.Module):
def __init__(self,inp,oup,stride,expand_ratio):
super().__init__()
self.stride = stride
assert stride in [1,2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
layers.append(ConvBNReLU(inp,hidden_dim,kernel_size=1))
layers.extend([
ConvBNReLU(hidden_dim,hidden_dim,stride=stride,groups=hidden_dim),
nn.Conv2d(hidden_dim,oup,1,1,0,bias),
nn.BatchNorm2d(oup,momentum=0.1),
])
self.conv = nn.Sequential(*layers)
#将无状态的相加操作转化为FloatFunctional,猜测是因为量化后加法操作不能正常运作
self.skip_add = nn.quantized.FloadFunctional()
def forward(self,x):
if self.use_res_connect:
return self.skip_add.add(x,self.conv(x))
else:
return self.conv(x)
#定义网络结构
class MobileNetV2(nn.Module):
def __init__(self,num_classes=1000,width_mult=1.0,inverted_residual_setting=None,round_nearest=8):
super().__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
#定义网络结构参数
if inverted_residual_setting is None:
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]