更多预训练模型:https://pytorch.org/vision/stable/models.html
环境准备
# --- Paths ---
# data_dir:   source dataset directory.
# target_dir: destination of the per-class split. Create this folder in
#             advance, together with its train/, test/ and val/ subfolders.
data_dir, target_dir = "../data/flowers", "../data/processed_byclass"
##环境准备
import os
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import shutil
import matplotlib.pyplot as plt
# Load the data.
# ref: https://github.com/pytorch/examples/blob/42e5b996718797e45c46a25c55b031e6768f8440/imagenet/main.py#L89-L101
# datasets.ImageFolder expects each split directory to contain one
# sub-directory per class: <target_dir>/{train,val,test}/<class_name>/<images>.
traindir = os.path.join(target_dir, 'train')
valdir = os.path.join(target_dir, 'val')
testdir = os.path.join(target_dir, 'test')

# Wrap the train/val/test splits in DataLoaders that yield batches of 32.
# Images are brought to 224x224 because the network being transferred
# expects that input size.
# Standard ImageNet per-channel mean/std, as used by torchvision's
# pretrained models.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: random resized crop + horizontal flip for augmentation.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: deterministic resize + center crop. Validation and
# test share this single object so both splits are guaranteed to be
# preprocessed identically (the steps below hold no per-sample state).
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(traindir, train_transform),
    batch_size=32, shuffle=True)
# Evaluation splits are not shuffled so results are reproducible run to run.
val_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(valdir, eval_transform),
    batch_size=32, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(testdir, eval_transform),
    batch_size=32, shuffle=False)
模型准备-框架函数
# Logging interval: print results once every 10 steps.
print_freq: int = 10
#train 的步骤
def train(train_loader, model, criterion, optimizer, epoch):
#每一步都要reset,在类AverageMeter中已经规定
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
# switch to train mode
model.train()
end = time.time()
for i, (input, target) in enumerate(train_loader):
# measure data loading time
data_time.update(time.time() - end)
# target = target.cuda(async=True)
input_var = torch.autograd.Variable(input)
target_var = torch.autograd.Variable(target)
# compute output
output = model(input_var)
loss = criterion(output, target_var)