This article is based on the "Image Scene Classification Challenge" published by AI研习社 (AI Yanxishe). It is written from the perspective of a complete beginner getting an image classification task running in code, and it also documents my own learning process.
Image Scene Classification Challenge: https://god.yanxishe.com/97?from=god_home_list
Code for this article: an image classification code template based on pretrained models
When you first get the competition problem, start with the following observations:
- First, analyze the task: it is a simple task of correctly classifying a set of scenery images.
- Next, look at the official dataset and labels:
1. The data are scenery photos from around the world, in 6 classes: buildings, street, forest, sea, mountain, glacier. The training set has 13,627 images and the test set has 3,407 images; the images are RGB JPEG files.
2. The label file is a CSV with two columns, filename (e.g. '0.jpg') and label (e.g. 'forest'), both of type String.
- Then look at the format of the submission file: it is a CSV with no header row; the first column is the image index (e.g. 0, an Int) and the second column is the class name (e.g. 'street', a String). A couple of illustrative rows are shown right after this list.
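For illustration only (these rows are hypothetical, not real predictions), a valid submission file would start with lines like:
0,street
1,forest
2,sea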
After analyzing the basic competition content and requirements, we can start implementing it in code. This article uses the Colab platform. The implementation steps are as follows:
- Set up Google Drive and Colab (Colab is a Jupyter notebook environment with PyTorch preinstalled; it requires no setup and runs entirely in the cloud. For usage, see: https://www.cnblogs.com/lfri/p/10471852.html . Colab is currently not accessible from mainland China; you can install tools such as Ghelper to access it: http://googlehelper.net/ )
- Upload the official dataset archive to Google Drive (upload the zip archive itself; do not unzip it first, since uploading the extracted files is very slow)
- Code implementation (preliminary work):
1. Mount Google Drive (load Google Drive into Colab)
2. Unzip the file (extract the dataset archive into the current runtime environment)
3. Create a folder to store the trained models
- Code implementation (main work):
1. Imports (import all the packages you need; you can add them one at a time as you write the code)
2. Check whether a GPU is being used
3. Read the label file (the labeled file for the training set, a CSV file here)
4. Define the dataset classes (for both the training set and the test set)
5. Preprocessing (preprocess the datasets)
6. Instantiate the dataset classes (for both the training set and the test set)
7. Initialize the pretrained model
8. Define the training routine
9. Train (train using the pretrained model and the training routine)
10. Test (run the trained model on the test set to produce a CSV-format result file)
Mount Google Drive (load Google Drive into Colab)
from google.colab import drive
drive.mount('/content/drive')
Unzip the file (extract the dataset archive into the current runtime environment)
!cp -r /content/drive/My\ Drive/Scene/Image_Classification.zip ./  # copy the dataset archive from Google Drive into the current runtime
!unzip Image_Classification.zip  # unzip the archive; this creates the 'train' folder, the 'test' folder and 'train.csv' in the current runtime
Create a folder to store the trained models
! mkdir /content/drive/My\ Drive/Scene/checkpoint
Imports (import all the packages you need; you can add them one at a time as you write the code)
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import random_split, DataLoader
import os
import torch.nn as nn
import time
import torch.optim as optim
Check whether a GPU is being used
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
Read the label file (the labeled file for the training set, a CSV file here)
def readLabelFile():
    label_file = pd.read_csv('train.csv')
    return label_file['filename'], label_file['label']
filename, filelabel = readLabelFile()
map = ['buildings', 'street', 'forest', 'sea', 'mountain', 'glacier']
num_class = len(map)
# convert the string labels into integer class indices
for i in range(len(map)):
    filelabel[filelabel == map[i]] = i
# convert the pandas Series into numpy arrays
filename = filename.values
filelabel = filelabel.values
Define the dataset classes (for both the training set and the test set)
class TrainDataset(torch.utils.data.Dataset):
    def __init__(self, root, img_list, label_list, transform=None):
        self.root = root
        self.img_list = img_list
        self.label_list = label_list
        self.transform = transform
    def __getitem__(self, index):
        # load the image by filename and return (image, label)
        img = Image.open(self.root + self.img_list[index]).convert('RGB')
        label = self.label_list[index]
        if self.transform:
            img = self.transform(img)
        return img, label
    def __len__(self):
        return len(self.img_list)

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, img_path, transform=None):
        self.img_path = img_path
        self.transform = transform
    def __getitem__(self, index):
        # the test set has no labels, so return (image, index) instead
        img = Image.open(self.img_path[index]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, index
    def __len__(self):
        return len(self.img_path)
Preprocessing (preprocess the datasets)
transform = {
    'train': transforms.Compose([
        transforms.Resize((224, 224), interpolation=2),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ]),
    # no random augmentation at validation/test time
    'val': transforms.Compose([
        transforms.Resize((224, 224), interpolation=2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
}
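One optional tweak (not part of the original setup above): since the models used below ship with ImageNet-pretrained weights, the standard ImageNet channel statistics are a common alternative to the 0.5/0.5 normalization. Whether it actually improves accuracy here would need to be checked on the validation split.
imagenet_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # standard ImageNet statistics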
Instantiate the dataset classes (for both the training set and the test set)
train_dataset = TrainDataset('./train/', filename, filelabel, transform['train'])
tra_dataset, val_dataset = random_split(train_dataset, [10000, 3627])
# sort the test paths numerically so the dataset index matches the image id in the filename (e.g. '0.jpg')
test_paths = sorted([x.path for x in os.scandir('./test/')], key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
test_dataset = TestDataset(test_paths, transform['val'])
tra_loader = DataLoader(tra_dataset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)
tra_dataset_num = len(tra_dataset)
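As an optional sanity check (a small sketch using the loaders defined above), you can pull one batch and confirm the tensor shapes before training:
imgs, labels = next(iter(tra_loader))
print(imgs.shape, labels.shape)  # expect torch.Size([64, 3, 224, 224]) and torch.Size([64])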
Initialize the pretrained model
def initializeModel(model_name, num_class, finetuning=False, pretrained=True):
    if model_name == 'alexnet':
        model = models.alexnet(pretrained=pretrained)
    elif model_name == 'vgg11':
        model = models.vgg11(pretrained=pretrained)
    elif model_name == 'vgg11_bn':
        model = models.vgg11_bn(pretrained=pretrained)
    elif model_name == 'vgg13':
        model = models.vgg13(pretrained=pretrained)
    elif model_name == 'vgg13_bn':
        model = models.vgg13_bn(pretrained=pretrained)
    elif model_name == 'vgg16':
        model = models.vgg16(pretrained=pretrained)
    elif model_name == 'vgg16_bn':
        model = models.vgg16_bn(pretrained=pretrained)
    elif model_name == 'vgg19':
        model = models.vgg19(pretrained=pretrained)
    elif model_name == 'vgg19_bn':
        model = models.vgg19_bn(pretrained=pretrained)
    elif model_name == 'resnet18':
        model = models.resnet18(pretrained=pretrained)
    elif model_name == 'resnet34':
        model = models.resnet34(pretrained=pretrained)
    elif model_name == 'resnet50':
        model = models.resnet50(pretrained=pretrained)
    elif model_name == 'resnet101':
        model = models.resnet101(pretrained=pretrained)
    elif model_name == 'resnet152':
        model = models.resnet152(pretrained=pretrained)
    elif model_name == 'squeezenet1_0':
        model = models.squeezenet1_0(pretrained=pretrained)
    elif model_name == 'squeezenet1_1':
        model = models.squeezenet1_1(pretrained=pretrained)
    elif model_name == 'densenet121':
        model = models.densenet121(pretrained=pretrained)
    elif model_name == 'densenet169':
        model = models.densenet169(pretrained=pretrained)
    elif model_name == 'densenet161':
        model = models.densenet161(pretrained=pretrained)
    elif model_name == 'densenet201':
        model = models.densenet201(pretrained=pretrained)
    elif model_name == 'inception_v3':
        model = models.inception_v3(pretrained=pretrained)
    elif model_name == 'googlenet':
        model = models.googlenet(pretrained=pretrained)
    elif model_name == 'shufflenet_v2_x0_5':
        model = models.shufflenet_v2_x0_5(pretrained=pretrained)
    elif model_name == 'shufflenet_v2_x1_0':
        model = models.shufflenet_v2_x1_0(pretrained=pretrained)
    elif model_name == 'shufflenet_v2_x1_5':
        model = models.shufflenet_v2_x1_5(pretrained=pretrained)
    elif model_name == 'shufflenet_v2_x2_0':
        model = models.shufflenet_v2_x2_0(pretrained=pretrained)
    elif model_name == 'mobilenet_v2':
        model = models.mobilenet_v2(pretrained=pretrained)
    elif model_name == 'resnext50_32x4d':
        model = models.resnext50_32x4d(pretrained=pretrained)
    elif model_name == 'resnext101_32x8d':
        model = models.resnext101_32x8d(pretrained=pretrained)
    elif model_name == 'wide_resnet50_2':
        model = models.wide_resnet50_2(pretrained=pretrained)
    elif model_name == 'wide_resnet101_2':
        model = models.wide_resnet101_2(pretrained=pretrained)
    elif model_name == 'mnasnet0_5':
        model = models.mnasnet0_5(pretrained=pretrained)
    elif model_name == 'mnasnet0_75':
        model = models.mnasnet0_75(pretrained=pretrained)
    elif model_name == 'mnasnet1_0':
        model = models.mnasnet1_0(pretrained=pretrained)
    elif model_name == 'mnasnet1_3':
        model = models.mnasnet1_3(pretrained=pretrained)
    else:
        raise ValueError('No such Model %s' % model_name)
    if finetuning:
        # fine-tuning: keep all parameters trainable
        for param in model.parameters():
            param.requires_grad = True
    else:
        # feature extraction: freeze the pretrained backbone
        for param in model.parameters():
            param.requires_grad = False
    fc_features = model.fc.in_features  # input size of the model's final fc layer (ResNet-style models; see the note after this function for other families)
    model.fc = nn.Linear(fc_features, num_class)  # replace the final layer so its output size equals the number of classes in our dataset
    model = model.to(device)  # move the model onto the selected device (GPU if available)
    return model
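Note that replacing model.fc as above only works for architectures whose classification head is named fc (the ResNet / ResNeXt / Wide ResNet, GoogLeNet, Inception and ShuffleNet families). This article only uses resnet152, so that is enough; for other families the head has a different name, and a rough, hypothetical adjustment would look like the sketch below (SqueezeNet has a convolutional head and needs separate handling). Check print(model) to see the actual layer names.
# hypothetical replacement of the classification head for other model families
if hasattr(model, 'fc'):                       # ResNet-style models
    model.fc = nn.Linear(model.fc.in_features, num_class)
elif isinstance(model.classifier, nn.Linear):  # DenseNet: classifier is a single Linear
    model.classifier = nn.Linear(model.classifier.in_features, num_class)
else:                                          # AlexNet / VGG / MobileNetV2 / MNASNet: classifier is a Sequential
    model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_class)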
Define the training routine
def traWay(model, criterion, optimizer, epochs):
    begin_time = time.time()
    once_begin_time = begin_time
    model.train()  # put BatchNorm/Dropout layers into training mode
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch+1, epochs))
        print('-' * 10)
        running_loss = 0.0
        running_corrects = 0.0
        # iterate over the training set
        for img, labels in tra_loader:
            img = img.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()  # reset the gradients to zero
            outputs = model(img)   # forward pass to get the predictions
            preds = torch.argmax(outputs, dim=1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()       # update the parameters
            running_loss += loss.item() * img.size(0)
            running_corrects += torch.sum(preds == labels.data)
        epoch_loss = running_loss / tra_dataset_num
        epoch_acc = running_corrects / tra_dataset_num
        print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
        print('Training Time per Epoch {}'.format(time.time() - once_begin_time))
        once_begin_time = time.time()
    end_time = time.time() - begin_time
    print('Training complete in {:.0f}m {:.0f}s'.format(end_time // 60, end_time % 60))
    return model
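The random_split above also produced a validation split (val_loader) that traWay never uses. A minimal evaluation sketch (assuming the model, device and val_loader defined earlier) that can be called between epochs to watch for overfitting:
def evaluate(model, data_loader):
    # accuracy on a labeled loader such as val_loader
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for img, labels in data_loader:
            img, labels = img.to(device), labels.to(device)
            preds = torch.argmax(model(img), dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    model.train()
    return correct / total

# example usage after (or during) training:
# print('Val Acc: {:.4f}'.format(evaluate(model, val_loader)))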
Training
model = initializeModel('resnet152', num_class, True)
# note: although finetuning=True leaves all layers trainable, the optimizer below only updates the new fc layer
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
pre_check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best.tar'
if '152_state_best.tar' in os.listdir(r'/content/drive/My Drive/Scene/checkpoint'):
    print('loading previous state......')
    checkpoint = torch.load(pre_check_name)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
model = traWay(model, criterion, optimizer, 1)
check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best.tar'
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
}, check_name)
Testing
model = initializeModel('resnet152', num_class, False)
check_name = r'/content/drive/My Drive/Scene/checkpoint/152_state_best.tar'
checkpoint = torch.load(check_name)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch BatchNorm/Dropout layers to evaluation mode
with open('./result.txt', mode='w') as result_file:
    with torch.no_grad():  # no gradients are needed at inference time
        for img, index in test_loader:
            img = img.to(device)
            outputs = model(img)
            preds = torch.argmax(outputs, dim=1)
            for i in range(index.shape[0]):
                line = str(index[i].item()) + ',' + map[preds[i]]
                print(line)
                result_file.write(line + '\n')
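As a final optional check (a sketch that assumes the result file written above and the pandas import from earlier), verify that the file has one prediction per test image and contains only valid class names; since it is plain comma-separated text with no header, it can simply be renamed to result.csv for submission:
result = pd.read_csv('./result.txt', header=None, names=['id', 'label'])
assert len(result) == 3407                       # one row per test image
assert set(result['label']).issubset(set(map))   # only the six known class names
print(result.head())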