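# 本文所实现的网络来源于 SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing 2019)
# The network implemented here follows SCAR: Spatial-/Channel-wise Attention Regression Networks for Crowd Counting (Neurocomputing, 2019).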
import torch;from torchvision import models
from torchvision.models import vgg16
import warnings;from torch import nn
# Install a process-wide "ignore" filter so that no warnings are displayed.
warnings.simplefilter("ignore")
# NOTE(review): this rebinds the module-level name ``vgg16`` from the
# torchvision constructor to a constructed model instance, so any later
# ``vgg16(...)`` call would call the model, not the constructor. With
# ``pretrained=True`` torchvision also loads pretrained weights (downloading
# them if they are not cached) as soon as this module is imported.
vgg16 = vgg16(pretrained=True)
def initialize_weights(models):
    """Run ``real_init_weights`` on every item yielded by *models*.

    Args:
        models: An iterable of modules (the class below passes
            ``self.modules()``). Inside this function the parameter name
            shadows the module-level ``torchvision.models`` import.
    """
    for module in models:
        real_init_weights(module)
# NOTE(review): redundant — this repeats the ``warnings`` import and the
# global "ignore" filter already set up near the top of the module; both
# lines can be deleted once confirmed nothing relies on the re-ordering.
import warnings
warnings.simplefilter("ignore")
def real_init_weights(m):
    """Recursively initialise the parameters of *m* in place.

    Dispatch by type:

    * ``list`` / ``tuple``: recurse into every element.
    * ``nn.Conv2d``: weight ~ N(0, 0.01**2); bias (if present) set to 0.
    * ``nn.Linear``: weight ~ N(0, 0.01**2); bias is left unchanged.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 when those
      parameters exist (``affine=False`` layers have none and are skipped).
    * any other ``nn.Module``: recurse into its direct children.
    * anything else: printed, so unexpected inputs are visible.

    Args:
        m: A module, or a (possibly nested) list/tuple of modules.
    """
    if isinstance(m, (list, tuple)):
        for child in m:
            real_init_weights(child)
    elif isinstance(m, nn.Conv2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        # NOTE(review): unlike Conv2d, the Linear bias is not zeroed here; it
        # keeps whatever value it had. Confirm this asymmetry is intended.
        nn.init.normal_(m.weight, mean=0.0, std=0.01)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) registers weight/bias as None; calling
        # nn.init.constant_ on None would raise, so only touch real tensors.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Module):
        for child in m.children():
            real_init_weights(child)
    else:
        print(m)
class SCAR(torch.nn.Module):
def __init__(self,loadwieght=False):
super(SCAR,self).__init__()
self.vgg10=vgg10
if loadwieght==False:
mod = models.vgg16(pretrained=True)
initialize_weights(self.modules())
self.vgg10.load_state_dict(mod.features[0:23].state_dict(