I have been working with LMDB-format datasets a lot lately, so I am sharing LMDB versions of several datasets in this area.
For the origin of these datasets and downloads of the raw data, you can refer to this post.
For a consolidated set, the 2020 UnrealText paper compiled a unified version with corrected labels, which can be downloaded here:
https://github.com/Jyouhou/Case-Sensitive-Scene-Text-Recognition-Datasets
Since Syn90 and ST80 are too large, those two will be uploaded at a later time.
SVT: CSDN download / Baidu Netdisk download (extraction code: 3sgh)
CUTE80: CSDN download
Syn90:
ST80:
COCO-Text: https://blog.youkuaiyun.com/zhaominyiz/article/details/106045449
1. Creating the dataset. You can modify the configuration in init_args().
import os
import re
import argparse

import cv2
import lmdb  # install lmdb with "pip install lmdb"
import numpy as np
def init_args():
    args = argparse.ArgumentParser()
    args.add_argument('-i', '--image_dir',
                      type=str,
                      help='Directory of the dataset that contains the images',
                      default='D:/mnt/ramdisk/max/90kDICT32px/')
    args.add_argument('-l', '--label_file',
                      type=str,
                      help='File that lists the image paths and the labels of the dataset',
                      default='D:/mnt/ramdisk/max/90kDICT32px/all.txt')
    args.add_argument('-s', '--save_dir',
                      type=str,
                      help='Directory in which the generated mdb files are saved',
                      default='D:/syn90k')
    args.add_argument('-m', '--map_size',
                      type=int,
                      help='map_size of the LMDB in bytes',
                      default=274877906944 // 4)  # 256 GB / 4 = 64 GB
    return args.parse_args()
def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    # np.fromstring is deprecated; frombuffer does the same thing for binary data
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True
def writeCache(env, cache):
    # Flush one batch of key/value pairs in a single write transaction
    txn = env.begin(write=True)
    for k, v in cache.items():
        txn.put(k.encode(), v)
    txn.commit()

def _is_difficult(word):
    # A label counts as difficult if it contains anything other than word characters
    assert isinstance(word, str)
    return not re.match(r'^[\w]+$', word)
def createDataset(outputPath, imagePathList, labelList, SIZ, lexiconList=None, checkValid=True):
    """
    Create an LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        SIZ           : map_size of the LMDB in bytes
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image
    """
    assert len(imagePathList) == len(labelList)
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=int(SIZ))
    cache = {}
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if len(label) == 0:
            continue
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        # Everything stored in the database is binary data
        imageKey = 'img_' + str(i + 1)
        labelKey = 'lab_' + str(i + 1)
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()
        if i % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print("SOLVE", i)
    cache['num-samples'] = str(nSamples).encode()
    writeCache(env, cache)
    env.close()
    print('Created dataset with %d samples' % nSamples)
if __name__ == '__main__':
    args = init_args()
    with open(args.label_file, mode='r', encoding='utf-8') as imgdata:
        lines = list(imgdata)
    imgDir = args.image_dir
    imgPathList = []
    labelList = []
    SIZ = 0
    for line in lines:
        imgPath = os.path.join(imgDir, line.split()[0]).replace('\\', '/')
        imgPathList.append(imgPath)
        # Syn90k file names look like "2194/2/334_EFFLORESCENT_24742.jpg",
        # so the ground-truth word sits between the two underscores
        word = line.split('_')[1]
        word = str.lower(word)
        labelList.append(word)
        SIZ += os.path.getsize(imgPath)
    print("SIZ=", SIZ, "ALL=", len(labelList))
    # Reserve roughly 2.1x the total image size as the LMDB map_size
    createDataset(args.save_dir, imgPathList, labelList, SIZ * 2.1)
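As a quick sanity check after the conversion, a minimal sketch like the one below (assuming the default save_dir D:/syn90k from above and the img_/lab_ key scheme used by the script) opens the resulting LMDB and prints the sample count and the first label:
import lmdb

env = lmdb.open('D:/syn90k', readonly=True, lock=False)
with env.begin() as txn:
    n = int(txn.get('num-samples'.encode()))  # total number of samples written
    print('num-samples =', n)
    print('first label =', txn.get('lab_1'.encode()).decode())
env.close()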
2. Reading the dataset, using PyTorch as an example.
import lmdb
import six
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class LmdbData(Dataset):
    def __init__(self, txt_path, maxlen, labeltool, transform=None, target_transform=None):
        super(LmdbData, self).__init__()
        # txt_path is the LMDB directory (the folder holding data.mdb / lock.mdb)
        self.labeltool = labeltool
        self.to_Tensor = transforms.ToTensor()
        self.maxlen = maxlen
        self.env = lmdb.open(txt_path, readonly=True)
        txn = self.env.begin()
        self.len = int(txn.get('num-samples'.encode()))
        self.transform = transform
        self.resize = transforms.Resize((32, 120))  # Resize expects an (h, w) tuple
        self.target_transform = target_transform
    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1  # keys in the LMDB are 1-based
        with self.env.begin(write=False) as txn:
            img_key = 'img_' + str(index)
            imgbuf = txn.get(img_key.encode('utf-8'))
            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('RGB')
            except IOError:
                print('Corrupted image for %d' % index)
                return self[index + 1]
            label_key = 'lab_' + str(index)
            label = txn.get(label_key.encode()).decode()
        # Apply the optional transforms before returning the sample
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return self.len
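A minimal usage sketch, assuming the LMDB from step 1 was written to D:/syn90k; maxlen=25 and labeltool=None are placeholder values for this example, since those two arguments depend on your own label-encoding code:
from torch.utils.data import DataLoader
from torchvision import transforms

tfm = transforms.Compose([
    transforms.Resize((32, 120)),  # resize so that all images batch together
    transforms.ToTensor(),
])
dataset = LmdbData('D:/syn90k', 25, None, transform=tfm)
# num_workers=0 because the LMDB environment is opened in __init__
# and does not survive being forked into worker processes
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
imgs, labels = next(iter(loader))
print(imgs.shape)   # torch.Size([32, 3, 32, 120])
print(labels[:4])   # the raw text labels as strings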