背景
对于keras,没有tfrecord作为输入数据,通常使用yield (inputs, outputs)作为输入,以我的这次训练为例,crnn训练从txt中读入图片和label如下所示:
out/1358695.jpg from
out/650690.jpg If
out/2353853.jpg sections
以空格作为分割符,前面时图片路径,后面是label。此时可以
1.将图片和label的列表保存在内存中
2.将图片和label读入内存,这样训练时候省去了读图片的时间,加快训练速度。
而将300万张图片读入内存非常耗时,因此写了下面代码多线程读入。
实现
引入包,并且包括parameter,该类包含了路径等一些配置信息。
import cv2
import random
import numpy as np
import os
import time
from multiprocessing import Pool
from parameter import *
定义多线程类
class MyThread(object):
def __init__(self, func):
self.func = func
#耗时的读入操作和try操作,如果实现其它功能,可以修改下面内容
def long_time_task(self,i):
# i 内容:out/451073.jpg WARRANTIES
img_name,label = i.strip().split(' ')
try:
img = cv2.imread(img_root_dir+img_name,cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (img_w, img_h))
img = img.astype(np.float32)
img = (img / 255.0) * 2.0 - 1.0
except Exception as e:
print(img_name)
print(e)
return None
else:
return (img,label)
# 传入的img_list,包含图片地址和label
def parse_thread(self,img_list):
p = Pool()
results = []
for