# PyTorch 调用 ResNet、AlexNet、VGG、SqueezeNet、DenseNet、Inception 预训练模型
from __future__ import print_function
from __future__ import division
import torch.nn as nn
from torchvision import datasets, models
# NOTE(review): an earlier comment here described a "top level data
# directory" laid out in torchvision ImageFolder format, but no such setting
# is defined in this section. TODO: confirm whether a data-directory
# assignment was dropped from the script.

# Number of classes in the dataset.
num_classes = 2

# Batch size for training (批处理尺寸 = "batch size").
# Lower it if you run out of memory.
BATCH_SIZE = 128

# Feature-extraction flag. When False, the whole model is fine-tuned.
# When True, only the parameters of the reshaped layer(s) are updated.
feature_extract = True
# 是否改变卷积层的参数
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when feature extracting.

    When ``feature_extracting`` is truthy, ``requires_grad`` is set to False
    on every tensor yielded by ``model.parameters()``, so autograd no longer
    tracks gradients for them. When it is falsy, the model is left untouched.
    Note that this path does not re-enable parameters that are already
    frozen.

    Args:
        model: an ``nn.Module`` (anything exposing ``parameters()``).
        feature_extracting: if truthy, freeze all of the model's parameters.

    Returns:
        None. ``model`` is modified in place.
    """
    if not feature_extracting:
        # Fine-tuning mode: leave every parameter's requires_grad as it is.
        return
    for param in model.parameters():
        param.requires_grad = False
# 初始化模型
def initialize_model(mo