计算机视觉(一)-----街景字符识别简介及代码实例

1.简介

数据集来源为Google街景图像中的门牌号数据集(The Street View House Numbers Dataset, SVHN),并根据一定方式采样得到数据集。
该数据来自真实场景的门牌号。训练集数据包括3W张照片,验证集数据包括1W张照片,每张照片包括颜色图像和对应的编码类别和具体位置;为了保证比赛的公平性,测试集A包括4W张照片,测试集B包括4W张照片。

1.1 思路

将不定长字符转换为定长字符的识别问题,并使用CNN完成训练和验证,具体包括以下几个步骤:

  • 数据读取(封装为Pytorch的Dataset和DataLoder)

  • 构建CNN模型(使用Pytorch搭建)

  • 模型训练与验证

  • 模型结果预测

1.2 运行环境及安装示例

  • 运行环境要求:Python2/3,Pytorch1.x,内存4G,有无GPU都可以。

下面给出python3.7+ torch1.3.1gpu版本的环境安装示例:

  • 首先在Anaconda中创建一个专用虚拟环境。

$conda create -n py37_torch131 python=3.7

  • 激活环境,并安装pytorch1.3.1

$source activate py37_torch131
$conda install pytorch=1.3.1 torchvision cudatoolkit=10.0

  • 通过下面的命令一键安装所需其它依赖库

$pip install jupyter tqdm opencv-python matplotlib pandas

  • 启动notebook,即可开始baseline代码的学习

$jupyter-notebook

2.代码实例

为了方便使用,代码运行在google colab 上运行,数据存储在google driver上。

2.1. 导入常用的包:

import os, sys, glob, shutil, json
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import cv2

from PIL import Image
import numpy as np

from tqdm import tqdm, tqdm_notebook

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

2.2. 定义好读取图像的Dataset

class SVHNDataset(Dataset):
    def __init__(self, img_path, img_label, transform=None):
        self.img_path = img_path
        self.img_label = img_label 
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None

    def __getitem__(self, index):
        img = Image.open(self.img_path[index]).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        
        # 设置最长的字符长度为5个
        lbl = np.array(self.img_label[index], dtype=np.int)
        lbl = list(lbl)  + (5 - len(lbl)) * [10]
        return img, torch.from_numpy(np.array(lbl[:5]))

    def __len__(self):
        return len(self.img_path)

2.3. 定义好训练数据和验证数据的Dataset

train_path = glob.glob('/content/drive/My Drive/mchar_train.zip (Unzipped Files)/mchar_train/*.png')
train_path.sort()
train_json = json.load(open('/content/drive/My Drive/SVHN/mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
print(len(train_path), len(train_label))

train_loader = torch.utils.data.DataLoader(
    SVHNDataset(train_path, train_label,
                transforms.Compose([
                    transforms.Resize((64, 128)),
                    transforms.RandomCrop((60, 120)),
                    transforms.ColorJitter(0.3, 0.3, 0.2),
                    transforms.RandomRotation(10),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])), 
    batch_size
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值