微调
例如:
修改模型(将1000类改为10类输出)
model = timm.create_model('resnet34',num_classes=10,pretrained=True) x = torch.randn(1,3,224,224) output = model(x) output.shape
输出为:
torch.Size([1, 10])
改变输入通道数(比如我们传入的图片是单通道的,但是模型需要的是三通道图片) 我们可以通过添加in_chans=1
来改变
model = timm.create_model('resnet34',num_classes=10,pretrained=True,in_chans=1) x = torch.randn(1,1,224,224) output = model(x)
PS:关于float16与float32的区别 