一、数据集HWDB
由于HWDB数据集解压之后是gnt格式的,需要将其转换为png格式。
https://blog.youkuaiyun.com/yql_617540298/article/details/82740382博客介绍了如何转换。
https://pan.baidu.com/s/1o84jIrg#list/path=%2F百度云中可以直接下载转换后的结果。
二、源码
参考github:https://github.com/burness/tensorflow-101/tree/master/chinese_hand_write_rec/src
import tensorflow as tf
import os
import random
import tensorflow.contrib.slim as slim
import time
import logging
import numpy as np
import pickle
from PIL import Image
logger = logging.getLogger('Training a chinese write char recognition')
logger.setLevel(logging.INFO)
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
logger.addHandler(ch)
tf.app.flags.DEFINE_boolean('random_flip_up_down', False, "Whether to random flip up down")
tf.app.flags.DEFINE_boolean('random_brightness', True, "whether to adjust brightness")
tf.app.flags.DEFINE_boolean('random_contrast', True, "whether to random constrast")
tf.app.flags.DEFINE_integer('charset_size', 3755, "Choose the first `charset_size` character to conduct our experiment.")
tf.app.flags.DEFINE_integer('image_size', 64, "Needs to provide same value as in training.")
tf.app.flags.DEFINE_boolean('gray', True, "whether to change the rbg to gray")
tf.app.flags.DEFINE_integer('max_steps', 12002, 'the max training steps ')
tf.app.flags.DEFINE_integer('eval_steps', 50, "the step num to eval")
tf.app.flags.DEFINE_integer('save_steps', 2000, "the steps to save")
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoint/', 'the checkpoint dir')
tf.app.flags.DEFINE_string('train_data_dir', './data/train/', 'the train dataset dir')
tf.app.flags.DEFINE_string('test_data_dir', './data/test/', 'the test dataset dir')
tf.app.flags.DEFINE_string('log_dir', './log', 'the logging dir')
tf.app.flags.DEFINE_boolean('restore', False, 'whether to restore from checkpoint')
tf.app.flags.DEFINE_boolean('epoch', 1, 'Number of epoches')
tf.app.flags.DEFINE_boolean('batch_size', 128, 'Validation batch size')
tf.app.flags.DEFINE_string('mode', 'train', 'Running mode. One of {"train", "valid", "test"}')
FLAGS = tf.app.flags.FLAGS
class DataIterator:
def __init__(self, data_dir):
# Set FLAGS.charset_size to a small value if available computation power is limited.
truncate_path = data_dir + ('%05d' % FLAGS.charset_size)
print(truncate_path)
self.image_names = []
for root, sub_folder, file_list in os.walk(data_dir):
if root < truncate_path:
self.image_names += [os.path.join(root, file_path) for file_path in file_list]
random.shuffle(self.image_names)
self.labels = [int(file_name[len(data_dir):].split(os.sep)[0]) for file_name in self.image_names]
@property
def size(self):
retur