目录
1.简介
数据集来源为Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到数据集。
该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。
1.1 思路
将不定长字符转换为定长字符的识别问题,并使用CNN完成训练和验证,具体包括以下几个步骤:
-
数据读取(封装为Pytorch的Dataset和DataLoder)
-
构建CNN模型(使用Pytorch搭建)
-
模型训练与验证
-
模型结果预测
1.2 运行环境及安装示例
- 运行环境要求:Python2/3,Pytorch1.x,内存4G,有无GPU都可以。
下面给出python3.7+ torch1.3.1gpu版本的环境安装示例:
- 首先在Anaconda中创建一个专用虚拟环境。
$conda create -n py37_torch131 python=3.7
- 激活环境,并安装pytorch1.3.1
$source activate py37_torch131
$conda install pytorch=1.3.1 torchvision cudatoolkit=10.0
- 通过下面的命令一键安装所需其它依赖库
$pip install jupyter tqdm opencv-python matplotlib pandas
- 启动notebook,即可开始baseline代码的学习
$jupyter-notebook
2.代码实例
为了方便使用,代码运行在google colab 上运行,数据存储在google driver上。
2.1. 导入常用的包:
import os, sys, glob, shutil, json
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import cv2
from PIL import Image
import numpy as np
from tqdm import tqdm, tqdm_notebook
import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
2.2. 定义好读取图像的Dataset
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 设置最长的字符长度为5个
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)
2.3. 定义好训练数据和验证数据的Dataset
train_path = glob.glob('/content/drive/My Drive/mchar_train.zip (Unzipped Files)/mchar_train/*.png')
train_path.sort()
train_json = json.load(open('/content/drive/My Drive/SVHN/mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.RandomCrop((60, 120)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size