# Assume input range is [0, 1]
class ResNet101FeatureExtractor(nn.Module):
def __init__(self, use_input_norm=True, device=torch.device('cpu')):
super(ResNet101FeatureExtractor, self).__init__()
model = torchvision.models.resnet101(pretrained=True)
self.use_input_norm = use_input_norm
if self.use_input_norm:
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
# [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
# [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
self.register_buffer('mean', mean)
self.register_buffer('std', std)
self.features = nn.Sequential(*list(model.children())[:8])
# No need to BP to variable
for k, v in self.features.named_parameters():
v.requires_grad = False
def forward(self, x):
if self.use_input_norm:
x = (x - self.mean) / self.std
output = self.features(x)
return output
加载resNet预训练模型
最新推荐文章于 2025-11-03 11:32:02 发布
本文介绍了一个基于预训练ResNet101模型的特征提取器类实现,详细解析了输入归一化处理及如何冻结模型参数,避免反向传播。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
675

被折叠的 条评论
为什么被折叠?



