所用代码地址:crnn.caffe
1. 数据准备
1.1 图像数据准备
这里需要将图像数据统一转换到12832(宽度高度)上来,当然也可以更改为其它类型的长度,这里只是使用了该尺寸。
1.2 Label数据准备
Label数据是与图像数据对应的数据,其中包含了图像中的具体字符数据。在制作label数据的时候需要将不同的字符转换到不同的数字标号上来,这里需要将字符映射表的最后一位设置为“_blank_”。这里还需要注意的是Label的长度应该和最大label的长度对应否则会超过label的表示范围。
2. 修改crnn.prototxt
这里首先假设需要的分类数目为N,再加上背景那么就是N+1类。所以就要在CRNN中首先就要修改的便是:
layer {
name: "fc1"
type: "InnerProduct"
bottom: "drop1"
top: "fc1"
param {
lr_mult: 1
decay_mult: 1
}
param {
lr_mult: 2
decay_mult: 0
}
inner_product_param {
num_output: N+1
axis: 2
weight_filler {
type: "xavier"
}
bias_filler {
type: "constant"
value: 0
}
}
}
还有
layer {
name: "ctc_loss"
type: "CtcLoss"
bottom: "fc1"
bottom: "label"
top: "ctc_loss"
loss_weight: 1.0
ctc_loss_param {
blank_label: N
alphabet_size: N+1
time_step: 32
}
}
还有精度层
layer {
name: "accuracy"
type: "LabelsequenceAccuracy"
bottom: "premuted_fc"
bottom: "label"
top: "accuracy"
labelsequence_accuracy_param {
blank_label: N
}
}
其中还需要修改的便是time_step的值,也就是CRNN中的帧率。将time_step都设置为32,因为我们的输入图像宽度为128。
修改reshape层
layer {
name: "reshape"
type: "Reshape"
bottom: "conv6"
top: "reshape"
reshape_param {
shape {
#n*c*(w*h)
dim: 64
dim: 512
dim: 32
}
}
}
这里的batch_size为64,所以也需要修改前面的data层。
3. 开始训练网络
训练用的代码
# -*- coding=utf-8 -*-
import numpy as np
import sys
sys.path.append('~/Desktop/crnn.caffe/python')
import caffe
# training
caffe.set_device(2)
caffe.set_mode_gpu()
solver = caffe.SGDSolver('solver.prototxt')
net = solver.net
print net.blobs['label'].data[0].shape
iter_nums = 100000
for _ in range(iter_nums):
solver.step(1)
在这里需要注意solver参数的选择,否则会出现不收敛的情况-_-||,这也算是一个坑吧…
4. 预测结果
deploy文件就自己生成了哈,也记得修改其中的batch_size…
# -*- coding=utf-8 -*-
import sys
sys.path.append('~/Desktop/crnn.caffe/python')
import caffe
from PIL import Image
import numpy as np
model_file = './snapshot/_iter_60000.caffemodel'
deploy_file = 'crnn_deploy.prototxt'
test_img = '2.jpg'
# set device
caffe.set_device(2)
caffe.set_mode_gpu()
# load model
net = caffe.Net(deploy_file, model_file, caffe.TEST)
# load test img
img = Image.open(test_img)
img = img.resize((128, 32), Image.ANTIALIAS)
in_ = np.array(img, dtype=np.float32)
in_ = in_[:,:,::-1]
in_ = in_.transpose((2,0,1))
# 执行上面设置的图片预处理操作,并将图片载入到blob中
# shape for input (data blob is N x C x H x W), set data
net.blobs['data'].reshape(1, *in_.shape)
net.blobs['data'].data[...] = in_
# run net
net.forward()
# get result
res = net.blobs['probs'].data
print('result shape is:', res.shape)
# 取出标签文档
char_set = []
with open('label.txt', 'r') as f:
line = f.readline()
while line:
line = line.strip('\n\r')
# print(line)
char_set.append(str(line))
line = f.readline()
# 取出最多可能的label标签
for i in range(32):
data = res[i, :, :]
index = np.argmax(data)
#print(index, data[0, index])
print(char_set[index])