【新手向】基于生成对抗网络算法(GAN)的手写稿生成专属字体文件教学

目录

项目背景与前言

解决方案

1.项目背景

1.1.字体家AI神笔:

1.2.手迹造字:

1.3.calligraphr

1.4.Handwriting Synthesis

1. 5.zi2zi(zi2zi-chain)

GAN简介

2.项目实现思路

3.准备工作

3.1电脑配置要求

3.1.1环境说明

3.2处理数据

3.2.1图片处理

3.2.2图片重命名(高版本Python)

3.2.3生成数据集(Python3.7)

3.2.3.1克隆项目

3.2.3.2图片数据规范命名

3.2.3.3数据集(图片生成)

3.2.3.4数据集(打包)

3.3上云(Python3.7+CUDA10.1)

3.3.1关于飞桨AI studio

3.3.2云端训练

3.3.2.1环境配置

3.3.2.2运行模型

 3.3.3可能的报错

3.3.3.1unrecognized arguments

3.3.3.2AssertionError from torch.cuda.is_available()

3.3.3.3FileNotFoundError: [Errno 2] No such file or directory: 'experiment/data/val.obj'

3.3.3.4KeyError

3.3.3.5__getitem__:KeyError

3.3.3.6PickledImageProvider相关

3.3.3.7数据集问题

结语


项目背景与前言

语文老师在我们的假期布置了少量的语文练字作业,突然想到想着能不能用科技改变生活,高打印质量完成作业(配合写字机器人食用效果更佳)

本项目核心算法参考GitHub上的zi2zi-chain项目,遵守Apache 2.0开源协议

叠个甲:作者的字帖其实已经写完了,本项目仅用于学习研究用途,请合理范围内利用本项目,不得将其应用到不合规的地方,并且后果自负

另外,这是我第一次写优快云,有任何写的有问题的地方还望大家指正

解决方案

1.项目背景

现目前网上有许多类似的案例,比如字体家AI神笔(付费),手迹造字,calligraphr,Handwriting Synthesis,zi2zi,相比传统的需要3500个常用字以上的方法,以上几款都可以使用较少的字符进行训练,对比如下。

1.1.字体家AI神笔:
  • 有免费单字生成和字体融合图功能,每日免费单字生成50字,字体融合图4次机会(约等于仅可体验)

  • 手写8个汉字即可生成包含6000多个汉字的字体库。

  • 能较好模拟手写笔触,生成的字体风格多样,但生成结果还原不算很高(会自动美化笔迹)。

1.2.手迹造字:
  • 基本功能免费使用。
  • 最少手写100字可生成字体,
  • 手机端APP即可使用,操作简单,生成效率也较好,但风格种类受限于手写样本,相对单一,不太适合用于本场景。
1.3.calligraphr
  • 免费版可设置最多75个字符。
  • 需下载模板手写后上传,生成流程相对复杂一些,且字符数量有限,效果不佳。
1.4.Handwriting Synthesis
  • 开源免费。
  • 需有一定技术基础来训练模型,操作复杂,生成效率取决于用户对技术的掌握和训练过程,便捷性较差。
  • 基于深度学习技术,可生成较为接近手写真实感的字体,若训练得当,能较好模拟特定手写风格。
1. 5.zi2zi(zi2zi-chain)
  • 开源免费。
  • 通过条件生成对抗网络,能够学习并生成多种字体样式,但是部署比较复杂。
  • 可实现很好的风格模拟效果(这个算法可以用来推算书法字体,效果很好),能处理多样、复杂的字体风格。

所以,本项目核心算法采用not-bald-owl大佬的生成对抗网络(GAN)算法(zi2zi-chain项目)进行训练。

新楷体——楷书-方正多宝塔碑 模型推理效果如下

新楷体——楷书-魏碑 模型推理效果如下

 

新楷体——楷书-赵孟頫三门记 模型推理效果如下

这个都不是像了,还原度太高了,这就是GAN算法在图片风格迁移上的应用

GAN简介

那么还是来简单介绍一下

生成对抗网络(Generative Adversarial Networks,GAN)是深度学习领域创新性无监督学习框架,由Ian Goodfellow等2014年提出。其灵感源于博弈论“零和博弈”,通过构建生成器与判别器两个对抗的神经网络,实现从随机噪声生成逼真数据的目标,改变了传统生成模型设计范式,是近年AI研究最具影响力方向之一。

GAN训练本质是生成器与判别器的动态博弈。生成器将随机噪声(如高斯或均匀分布向量)映射为与真实数据分布相似的合成数据(如图像、文本、音频);判别器是二分类模型,判断输入数据来自真实集还是生成器。两者目标对立:生成器“欺骗”判别器,判别器提升鉴别能力。这种对抗关系用极小极大博弈形式化表达。

训练时两网络交替优化:先固定生成器参数,训练判别器以最大化正确分类真实与生成数据的能力;再固定判别器参数,训练生成器以最小化判别器对生成数据的正确判断概率。理想状态下,训练达纳什均衡时,生成器完美拟合真实数据分布,判别器无法区分数据来源,生成样本与真实数据统计特性高度一致。

2.项目实现思路

由于我的电脑显卡为GeForce MX250,2GB显存,不满足本地推力计算要求中的显存≥8GB,但是支持CUDA10.2,并且有Python3.7和3.13双版本,所以本项目演示实现过程如下

  • 本地(无配置与环境要求):扫描软件+Photoshop切片,生成简单处理好的数据
  • 本地(Python3.13):使用OCR(这里是Python的easyOCR库)对图片按照{汉字}_[变体ID].png的格式重命名,处理好切片的图片

easyOCR:部署非常简单的Python库,识别准确率比较一般,有很多需要自己修改的,但也能大大加快我们处理的速度,所以选择了这个(我都做完了,后来才知道用百度飞桨会更好,悲)

  • 本地(Python3.7):这一步用于处理数据集,但不需要GPU,600张图片只算生成的时间可能一分钟左右,无性能要求
  • 云端(百度飞桨AI studio):这一步是在云端使用云上GPU计算,推理生成字体
    • 特别注意
    • 云端为Linux环境,部分终端中的代码与Windows的不一样
    • 百度飞桨每天使用GPU有4小时限制,不调用的时候请使用CPU,具体方法见3.3.1关于飞桨AI studio

为什么选择百度飞桨

  1. 每天上线免费送的算力可以在GPU环境中运行4小时,并且CPU环境不限制时长
  2. 提供的框架中有Python3.7+CUDA10.1的环境(就找这个环境就花了我两天)

除了百度飞桨还有哪些算力平台?

  1. 阿里云:阿里云提供的交互式建模(DSW)可以免费试用(貌似是750算力时分为250*3月发放),提供的官方镜像中有cu101(CUDA10.1)的,但是是Python3.6的环境,不知道能不能用conda创建需要的环境,或者能不能降级Python
  2. Google Colab:很著名的算力平台,但是配置很麻烦(我花了3后来放弃了),没有现成的符合条件的环境,需要降级Python(这个比较简单),但是配置低版本的CUDA和cuDNN就比较麻烦了,使用conda老是出问题,不知道有没有懂的朋友帮忙解释一下,每天有12h的GPU使用额度,并且需要科学上网,网速一般都比较恶心
  3. 智星云:这个有对应的CUDA版本,但不完全免费
  4. BML:也是百度的,支持自定义环境(为数不多框架镜像里面有pytorch1.5.1+Python3.7+CUDA10.1),但是付费

3.准备工作

3.1电脑配置要求
  • 以下两点二选一
    • 有独立显卡,支持CUDA10.1/10.2cuDNN7.6.5(硬性要求,若有高版本必须降级,或使用conda创建虚拟环境,各版本互不兼容)(支持的架构有Volta架构,Turing架构,Pascal架构等),详见这篇帖子,显存≥4GB(可能报错,建议≥8GB),进行本地推理计算
    • 若没有(英伟达)独显/显存不足/架构不支持/内存不足,使用云服务在云端计算(免费/试用的有:阿里云,腾讯云,百度飞桨,Google Colab(需要科学上网,且配置比较复杂))
  • Python环境:
    • 若使用本地推力计算:Python3.7(必须),Python3.13(高版本,使用easyOCR进行重命名时使用,可以不装,但需要你能根据扫描并切片后的几百张图片逐个重命名),多个版本对于每个不同的项目需要使用虚拟环境进行隔离
    • 若使用云服务在云端计算:Python3.7(建议,3.2.3生成数据集(Python3.7)部分需要,在云端比较麻烦),Python3.13(同上)
  • Photoshop(建议):用于将处理自己的手写稿样本(比如我的是已经写字帖上,需要抠掉稿纸的线,并且将文字单独切片),否则制作切片比较麻烦
  • 平板+触控笔(这个有没有都无所谓,可以提高效率):如果你已经写了一些在字帖上,或者想用之前已经写好的任何笔记内容作为样本,需要干净的纯白背景,可以用这个快速擦除不需要的部分
3.1.1环境说明

训练时需要的库:

  • pytorch 1.5.1
  • pillow 7.1.2
  • numpy 1.18.1
  • scipy 1.4.1
  • imageio 2.8.0

另外建议

  • Git(建议,安装在本地):用于克隆GitHub/Gitee项目,下载链接,网上有教程
  • VsCode(强烈建议,本地):全称Visual Studio Code,是著名的集成开发环境(IDE),在里面装上插件codebuddy(强烈推荐:一款对于小型项目能直接帮你写代码,运行代码,修改报错的AI
3.2处理数据
3.2.1图片处理

样本可以来源于任何地方,但必须扫描之后保证背景为纯白色,不含任何无关线条(允许少量污点),样本数量推荐300(最低,重复的字算作一次)以上,500以上为佳,建议选用风格统一的,不然需要更多样本。所以,建议准备约4页字帖,大概处理成这样

另外,保留所有变体!例如“的”字有很多,全部保留,另外,标点符号可以全部擦除(这里不擦后面图像处理的时候也会擦,毕竟标点容易露馅)

处理流程大概是:(网上有详细教程)

  1. 扫描手写稿
  2. 导入后取消背景层锁定-Ctrl+shift+U去色-上方菜单栏:图像-调整-阈值:128
  3. 在这一层下方新建图层并填充为白色,对原图层不需要的部分框选-删除或使用橡皮擦擦除
  4. 顶部菜单栏:视图-标尺→打勾,从上和左两侧拉出标尺(类似图中的)
  5. 左侧切片工具-按参考线切片
  6. Alt+shift+Ctrl+S导出为旧版所用格式-鼠标拉动全选左侧预览的切片-预设:png-24-储存
  7. 记住刚才保存的路径(比如C:\zi2zi\pictures),进入,找到里面的image文件夹(C:\zi2zi\pictures\image),其中zi2zi文件夹为图片处理部分的根目录
3.2.2图片重命名(高版本Python

这里以Python3.13做演示,报错后建议直接复制信息问AI

打开VScode,文件-打开文件夹,选择项目根目录,确定后右下角环境-创建虚拟环境-venv-python3.13

新建rename.py,粘贴以下代码,告诉codebuddy在虚拟环境中安装依赖并运行,以及自己的图片保存位置(或者自己安装,并把最后几行换成实际的路径)

import os
import cv2
import easyocr
from pathlib import Path
import numpy as np
import logging

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def save_debug_image(img, filename, output_dir="debug_images"):
    """保存调试图像"""
    debug_dir = Path(output_dir)
    debug_dir.mkdir(exist_ok=True)
    output_path = debug_dir / filename
    cv2.imwrite(str(output_path), img)
    logger.info(f"已保存调试图像: {output_path}")

def show_debug_info(img, results):
    """显示调试信息"""
    logger.info(f"识别到 {len(results)} 个文本区域")
    for i, (bbox, text, prob) in enumerate(results):
        logger.info(f"区域 {i+1}: 文本='{text}' 置信度={prob:.2f}")
        # 绘制识别结果但不显示
        debug_img = img.copy()
        cv2.rectangle(debug_img, 
                     tuple(map(int, bbox[0])), 
                     tuple(map(int, bbox[2])), 
                     (0, 255, 0), 2)
        cv2.putText(debug_img, f"{text} {prob:.2f}", 
                   tuple(map(int, bbox[0])),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
        save_debug_image(debug_img, f"debug_{i}.png")

def preprocess_image(img_path):
    """改进的图像预处理"""
    # 读取图像
    img_array = np.fromfile(img_path, dtype=np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"无法读取图像: {img_path}")
    
    # 保存原始图像
    save_debug_image(img, "original.png")
    
    # 转换为灰度图
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    save_debug_image(gray, "gray.png")
    
    # 直方图均衡化
    gray = cv2.equalizeHist(gray)
    save_debug_image(gray, "equalized.png")
    
    # 自适应阈值
    thresh = cv2.adaptiveThreshold(gray, 255, 
                                 cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                 cv2.THRESH_BINARY_INV, 21, 5)
    save_debug_image(thresh, "threshold.png")
    
    # 形态学操作
    kernel = np.ones((3,3), np.uint8)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    save_debug_image(thresh, "morphology.png")
    
    return img, thresh

def rename_outputs(input_dir, output_dir):
    # 初始化EasyOCR阅读器
    reader = easyocr.Reader(['ch_sim'], 
                          gpu=False,
                          download_enabled=False)
    
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)
    
    for filename in os.listdir(input_dir):
        if not filename.lower().endswith('.png'):
            continue
            
        try:
            variant_num = filename.split('_')[1].split('~')[0]
        except IndexError:
            variant_num = "000"
            
        input_path = Path(input_dir) / filename
        
        if not input_path.exists():
            logger.warning(f"文件不存在: {input_path}")
            continue
            
        try:
            logger.info(f"\n处理文件: {filename}")
            
            # 读取并预处理图像
            img, processed_img = preprocess_image(str(input_path))
            
            # 识别文本
            results = reader.readtext(processed_img,
                                   width_ths=0.7,
                                   height_ths=0.7,
                                   text_threshold=0.4,
                                   low_text=0.3)
            
            # 显示调试信息
            show_debug_info(img.copy(), results)
            
            # 获取识别结果
            recognized_text = results[0][1] if results else "未识别"
            
            new_filename = f"{recognized_text}_{variant_num}~用户手写体.png"
            output_path = Path(output_dir) / new_filename
            cv2.imwrite(str(output_path), img)
            logger.info(f"重命名结果: {filename} -> {new_filename}")
            
        except Exception as e:
            logger.error(f"处理文件 {filename} 时出错: {str(e)}")
            continue

if __name__ == "__main__":
    input_dir = "output_images"    #按照实际情况更改输入路径(相对路径)
    output_dir = "debug_outputs"    #输出路径
    rename_outputs(input_dir, output_dir)

接下来处理好所有的图片,并把识别不准确的名称改正确,若变体ID出错不用管

如果大面积无法识别,请将预处理-反色部分注释掉,或者保存调试图像以检查问题

若没有安装Python3.13

  • 可将所有图片打包到一个文件夹,上传至云端处理(上云方法见后文)
  • 手动重命名,命名规则:{汉字}_[变体ID].png
    • 例如:假设“的”有5个,则分别为:的_01~用户手写体.png……的_05~用户手写体.png
    • 假设“地”只有一个,则为:地_01~用户手写体.png

处理好的大概是这样

 

3.2.3生成数据集(Python3.7)
  • 如果你的显卡支持CUDA10.1/10.2,并且已经安装,那么建议在本地生成数据集,此过程不需要GPU性能,但pytorch库需要CUDA环境,(不需要cuDNN)
  • 如果不支持,或者电脑没有Python3.7,则将图片数据打包,上云(详见后文3.3上云(Python3.7+CUDA10.1)),但是,大量图片上传在云端可能比较慢(尤其是Google Colab,并且在云端是Linux环境,部分代码与实例的不同,例如where Python和which Python)

 在新的根目录中,新建虚拟环境(python3.7),也可以用之前的在VScode中创建并选择虚拟环境

python3.7 -m venv myenv
myenv\Scripts\Activate.ps1
3.2.3.1克隆项目

在项目根目录里打开终端,运行

git clone -b master https://github.com/not-bald-owl/zi2zi-chain.git

一定要加 -b master,这个项目的默认分支里面只有一个许可证声明(雾)

如果访问卡顿可以科学上网,但更推荐转到Gitee上,这样国内的云平台访问也更快

运行后如果终端不是在zi2zi-chain中,那么使用

cd zi2zi-chain

再安装依赖库pytorch 1.5.1 pillow 7.1.2 numpy 1.18.1 scipy 1.4.1 imageio 2.8.0

 注意,安装时一定要先注意版本,根据我踩坑的经验

先运行一次

pip list

然后安装pytorch 1.5.1

pip install pytorch==1.5.1

再运行一次

pip list

如果此时对比上次好多了 pillow 8.x,那么直接 pip uninstall

最后安装其他的库就好了

3.2.3.2图片数据规范命名

新建reset_naming.py,安装依赖并运行

import os
import re
from pathlib import Path
import logging
from collections import defaultdict

# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    filename='rename_log.txt'
)
logger = logging.getLogger(__name__)

def extract_pure_text_id(filename):
    """提取纯文字ID(完全忽略现有序号)"""
    # 匹配所有可能的格式变体
    patterns = [
        r"^(.+?)_\d+~用户手写体",  # 含数字序号
        r"^(.+?)~用户手写体(?:\s\(\d+\))?",  # 无序号或带重复标记
        r"^(.+?)\.png"  # 极简格式
    ]
    
    for pattern in patterns:
        match = re.match(pattern, filename)
        if match:
            return match.group(1).strip()  # 去除前后空格
    return None

def reset_all_filenames(input_dir):
    """完全重置所有文件名编号"""
    text_files = defaultdict(list)
    
    # 1. 扫描并分类所有文件(按纯文字ID)
    for filename in os.listdir(input_dir):
        if not filename.lower().endswith('.png'):
            continue
            
        text_id = extract_pure_text_id(filename)
        if not text_id:
            logger.error(f"无法解析文件名: {filename}")
            continue
            
        text_files[text_id].append(filename)
    
    # 2. 重命名所有文件(包括当前"规范"的文件)
    total_renamed = 0
    rename_plan = []
    
    for text_id, files in text_files.items():
        # 按文件名排序(可根据需要修改排序逻辑)
        sorted_files = sorted(files)
        
        for idx, filename in enumerate(sorted_files, start=1):
            new_name = f"{text_id}_{idx:02d}~用户手写体.png"
            old_path = Path(input_dir) / filename
            new_path = Path(input_dir) / new_name
            
            rename_plan.append((old_path, new_path, filename, new_name))
    
    # 3. 执行重命名(先记录后执行)
    logger.info("="*50)
    logger.info(f"即将重命名 {len(rename_plan)} 个文件")
    
    for old_path, new_path, old_name, new_name in rename_plan:
        try:
            # 跳过同名文件(避免重复处理)
            if old_path.name == new_path.name:
                continue
                
            # 处理冲突(先临时后缀)
            temp_path = new_path.with_name(f"temp_{new_path.name}")
            if new_path.exists():
                os.rename(new_path, temp_path)
                
            os.rename(old_path, new_path)
            logger.info(f"重置命名: {old_name} → {new_name}")
            total_renamed += 1
            
            # 清理临时文件
            if temp_path.exists():
                os.remove(temp_path)
                
        except Exception as e:
            logger.error(f"重命名失败 {old_name}: {str(e)}")
    
    return total_renamed

if __name__ == "__main__":
    input_dir = r"D:\img\renamed_outputs"
    print(f"开始完全重置目录: {input_dir}")
    print("详细日志请查看: rename_log.txt")
    
    total = reset_all_filenames(input_dir)
    print(f"\n操作完成,共重置 {total} 个文件名")
    print("注意:所有文件已重新编号,原序号信息已丢弃")

 效果大概这样

3.2.3.3数据集(图片生成)

在新的文件夹中创建training_images.py,运行要求同上

import os
from PIL import Image, ImageDraw, ImageFont
import numpy as np

# 配置参数
SRC_FONT_PATH = 'zi2zi-chain/charset/cjk/simkai.ttf'
HANDWRITING_DIR = 'img_outputs'
OUTPUT_DIR = 'zi2zi-chain/data/train'
WRITER_ID = 10  # 用户确认的writer_dict值
FONT_SIZE = 185  # 缩小20%(256 * 0.8 ≈ 200)
IMAGE_SIZE = (256, 256)
BG_COLOR = (255, 255, 255)  # 白色背景
TEXT_COLOR = (0, 0, 0)      # 黑色文字

# 创建输出目录
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 加载字体
font = ImageFont.truetype(SRC_FONT_PATH, FONT_SIZE)

def process_character(char, handwriting_path):
    # 生成标准字体图片
    std_img = Image.new('RGB', IMAGE_SIZE, BG_COLOR)
    draw = ImageDraw.Draw(std_img)
    text_width, text_height = draw.textsize(char, font=font)
    position = ((IMAGE_SIZE[0] - text_width) // 2, (IMAGE_SIZE[1] - text_height) // 2)
    draw.text(position, char, font=font, fill=TEXT_COLOR)
    
    # 加载手写体图片并正确处理透明度
    print(f"Loading handwriting image from: {handwriting_path}")
    try:
        hw_img = Image.open(handwriting_path)
        print(f"Original image size: {hw_img.size}, mode: {hw_img.mode}")
        
        # 处理RGBA图片:将透明背景转为白色
        if hw_img.mode == 'RGBA':
            background = Image.new('RGB', hw_img.size, (255, 255, 255))
            background.paste(hw_img, mask=hw_img.split()[3])  # 使用alpha通道作为mask
            hw_img = background
        
        # 确保手写体图片尺寸正确
        if hw_img.size != IMAGE_SIZE:
            print(f"Resizing image to {IMAGE_SIZE}")
            hw_img = hw_img.resize(IMAGE_SIZE, Image.LANCZOS)
        
        # 验证图片数据
        print(f"Processed image size: {hw_img.size}, mode: {hw_img.mode}")
        print(f"Sample pixel values: {np.array(hw_img)[0,0]}")
    except Exception as e:
        print(f"Error loading image: {e}")
        raise
    
    # 水平拼接图片
    combined = Image.new('RGB', (IMAGE_SIZE[0]*2, IMAGE_SIZE[1]))
    combined.paste(std_img, (0, 0))
    combined.paste(hw_img, (IMAGE_SIZE[0], 0))
    
    # 保存结果
    output_path = os.path.join(OUTPUT_DIR, f'{char}~{WRITER_ID}.png')
    combined.save(output_path)
    print(f'Generated: {output_path}')

def main():
    # 遍历手写体图片
    for filename in os.listdir(HANDWRITING_DIR):
        if filename.endswith('.png') and '~' in filename:
            # 提取汉字 (格式: 汉字_变体ID~用户手写体.png)
            char = filename.split('_')[0]
            handwriting_path = os.path.join(HANDWRITING_DIR, filename)
            
            # 处理所有变体
            process_character(char, handwriting_path)

if __name__ == '__main__':
    main()

其中,请适当调整FONT_SIZE的值,让左侧的标准字体与右侧的手写字体大小相近

处理好的大概这样

 

3.2.3.4数据集(打包)

创建package.py,安装依赖并运行

# -*- coding: utf-8 -*-
import argparse
import glob
import json
import os
import pickle
import random
from tqdm import tqdm
import re


def pickle_examples_with_split_ratio(paths, train_path, val_path, train_val_split=0.1):
    """
    Compile a list of examples into pickled format, so during
    the training, all io will happen in memory
    """
    with open(train_path, 'wb') as ft:
        with open(val_path, 'wb') as fv:
            for p, label in tqdm(paths):
                label = int(label)
                with open(p, 'rb') as f:
                    img_bytes = f.read()
                    r = random.random()
                    example = (label, img_bytes)
                    if r < train_val_split:
                        pickle.dump(example, fv)
                    else:
                        pickle.dump(example, ft)


def pickle_examples_with_file_name(paths, obj_path):
    with open(obj_path, 'wb') as fa:
        for p, label in tqdm(paths):
            label = int(label)
            with open(p, 'rb') as f:
                img_bytes = f.read()
                example = (label, img_bytes)
                pickle.dump(example, fa)


parser = argparse.ArgumentParser(description='Compile list of images into a pickled object for training')
parser.add_argument('--dir', required=True, help='path of examples')
parser.add_argument('--save_dir', required=True, help='path to save pickled files')
parser.add_argument('--split_ratio', type=float, default=0.1, dest='split_ratio',
                    help='split ratio between train and val')

parser.add_argument('--dst_json', type=str, default=None)
parser.add_argument('--type_file', type=str, default='type/宋黑类字符集.txt')

parser.add_argument('--save_obj_dir', type=str, default=None)

args = parser.parse_args()


def get_special_type():

    with open(args.type_file, 'r', encoding='utf-8') as fp:
        font_list = [line.strip() for line in fp]
    # print("Font list:", font_list)

    '''
    font_list = os.listdir(args.type_dir)
    font_list = [f[:f.find('.test.jpg')] for f in font_list]
    '''
    font_set = set(font_list)
    font_dict = {v: k for v, k in enumerate(font_list)}
    inv_font_dict = {k: v for v, k in font_dict.items()}
    return font_set, font_dict, inv_font_dict


if __name__ == "__main__":
    if not os.path.isdir(args.save_dir):
        os.mkdir(args.save_dir)

    font_set, font_dict, inv_font_dict = get_special_type()

    # from total label to type label
    label_map = dict()
    ok_fonts = None

    dst_json = args.dst_json
    if not dst_json is None:
        with open(dst_json, 'r', encoding='utf-8') as fp:
            dst_fonts = json.load(fp)

        for idx, dst_font in enumerate(dst_fonts):
            font_name = dst_font['font_name']
            font_name = os.path.splitext(font_name)[0]
            if font_name in font_set:
                label_map[idx] = inv_font_dict[font_name]
            else:
                continue
        ok_fonts = set(label_map.keys())
    

    train_path = os.path.join(args.save_dir, "train.obj")
    val_path = os.path.join(args.save_dir, "val.obj")

    total_file_list = sorted(
        glob.glob(os.path.join(args.dir, "*.jpg")) +
        glob.glob(os.path.join(args.dir, "*.png")) +
        glob.glob(os.path.join(args.dir, "*.tif"))
    )
    # '%d_%05d.png'
    cur_file_list = []
    for file_name in tqdm(total_file_list):
        label = os.path.basename(file_name).split('_')[0]
        label = int(label)
        if ok_fonts is None:
            cur_file_list.append((file_name, label))
        else:
            if label in ok_fonts:
                cur_file_list.append((file_name, label_map[label]))

    if args.split_ratio == 0 and args.save_obj_dir is not None:
        pickle_examples_with_file_name(cur_file_list, args.save_obj_dir)
    else:
        pickle_examples_with_split_ratio(
            cur_file_list,
            train_path=train_path,
            val_path=val_path,
            train_val_split=args.split_ratio
        )

pkl和obj是可以互相转换的

在根目录下,也可以通过终端运行

python package.py --dir=C:\zi2zi\zi2zi-chain\data\raw_images --save_dir=C:\zi2zi\zi2zi-chain\data\bin --split_ratio=0.1

其中:

  • --dir:包含.jpg/.png/.tif的源图片目录,我们处理的是png
  • --save_dir:输出目录(自动生成train.obj和val.obj)
  • --save_obj_dir:直接指定输出的obj文件路径(待会要用
  • --split_ratio:验证集比例(0表示不分割,即只生成训练集,不生成验证集

注意:训练样本过少可能会报错!

3.3上云(Python3.7+CUDA10.1)
3.3.1关于飞桨AI studio

每天上线启动项目时会赠送8算力时,可以运行V100 16GB (完全够用)4h/day,但是每天用完之后就只能使用CPU,且好像不能充值加时

里面的环境是Linux,预装了Python3.7,CUDA10.1,cudnn 7.6.5还有conda

但是在这个封装好的环境中不能使用conda新建虚拟环境,不能使用sudo命令,所以我们隔离环境就只能使用venv,而venv不能隔离CUDA,所以必须选用预装CUDA10.1的飞桨2.2.2框架

conda和venv的区别与联系

  • Conda:既是包管理器(支持Python和非Python依赖),也是环境管理器。可安装Python解释器本身,适合科学计算和跨语言项目,可以隔离CUDA
  • Venv:仅用于管理Python环境,依赖系统已安装的Python解释器,无法管理非Python包,依赖pip安装包,不能隔离CUDA

详细的上云过程我会放在另一篇文章中来讲(也是百度飞桨AI studio)

3.3.2云端训练

当我们新建好之后,点击启动环境,然后会有如下界面

这里我们第一次建议选择CPU,第一次我们要进行环境的配置,详见3.1.1环境说明

3.3.2.1环境配置

如果IDE选择的是VScode,那么在左侧右键work文件夹(work文件夹中内容的每次变更都会被保存,所以虚拟环境放在这里比较方便,下次开机可以直接使用),选择在集成终端中打开,然后下面会出现终端aistudio@jupyter-17752757-9391612:~/work$在终端中直接右键可以粘贴文本

先来检查环境

python3.7 --version

然后和之前一样,克隆GitHub项目

git clone -b master https://github.com/not-bald-owl/zi2zi-chain.git

然后配置并激活虚拟环境

python3.7 -m venv myenv
source myenv/bin/activate

如果遇到意外需要退出虚拟环境,则使用

deactivate 

每次退出重进后需要使用source myenv/bin/activate重新进入虚拟环境,无需重新创建

然后先安装pytorch 1.5.1

pip install torch==1.5.1

然后检查库的内容

pip list

这里会有几个库的版本比要求的高,要求见,其他的库不要卸载

需要先卸载,比如

pip uninstall numpy

然后再安装指定版本

pip install pillow==7.1.2 numpy==1.18.1 imageio==2.8.0 scipy==1.4.1 matplotlib

在检验一下

pip list

 预期为

(myenv)aistudio@jupyter-17752757-9391612:~/work/zi2zi-chain$ pip list
Package           Version
----------------- -----------
cycler            0.11.0
fonttools         4.38.0
future            1.0.0
imageio           2.8.0
kiwisolver        1.4.5
matplotlib        3.5.3
numpy             1.18.1
packaging         24.0
Pillow            7.1.2
pip               24.0
pyparsing         3.1.4
python-dateutil   2.9.0.post0
scipy             1.4.1
setuptools        40.8.0
six               1.17.0
torch             1.5.1
torchvision       0.6.1
typing_extensions 4.7.1

如果确认无误可以下一步了

3.3.2.2运行模型

GPU环境中运行

然后要加入我们的数据集,添加到项目根目录(/home/aistudio/work/zi2zi-chain)里面

具体来说,在项目根目录里面新建experiment文件夹,布局为

experiment/
└── data
    ├── train.obj
    └── val.obj

obj和pkl是等效的,但是在train.py里面调用文件作者只写了.obj,所以需要进行转换

python -c "import pickle, os; [pickle.dump(pickle.load(open(f, 'rb')), open(f.replace('.pkl', '.obj'), 'wb')) for f in os.listdir() if f.endswith('.pkl')]"

启动虚拟环境之后进入项目根目录

cd zi2zi-chain

最后就是运行了

python train.py --experiment_dir=experiment --gpu_ids=cuda:0 --batch_size=32 --epoch=100 --sample_steps=200 --checkpoint_steps=500
 3.3.3可能的报错

(修改仅供参考,不一定能全部解决)

3.3.3.1unrecognized arguments

如果报错unrecognized arguments,请参考确认参数之间用--链接

3.3.3.2AssertionError from torch.cuda.is_available()

如果报错AssertionError from torch.cuda.is_available()检查CUDA安装,是否启用GPU

3.3.3.3FileNotFoundError: [Errno 2] No such file or directory: 'experiment/data/val.obj'

如果报错FileNotFoundError: [Errno 2] No such file or directory: 'experiment/data/val.obj'

# 找到这行代码
val_dataset = DatasetFromObj(os.path.join(data_dir, 'val.obj'), input_nc=args.input_nc)

# 修改为
val_dataset = DatasetFromObj(os.path.join(data_dir, 'val.pkl'), input_nc=args.input_nc)  # 

也可以直接把路径改为完整路径,不通过参数调用

3.3.3.4KeyError

如果报错KeyError,可以在dataset.py中的__init__中添加

self.image_provider = [(0, x) if not isinstance(x, (list, tuple)) else x 
                      for x in PickledImageProvider(obj_path)]
3.3.3.5__getitem__:KeyError

如果dataset.py中的__getitem__报错:KeyError

可以检查数据文件内容

import pickle
with open('/home/aistudio/work/zi2zi-data/data/bin/train.pkl', 'rb') as f:
    data = pickle.load(f)
print(f"数据键: {data.keys()}")
print(f"样本数量: {len(data['train'])}" if 'train' in data else "无'train'键")
3.3.3.6PickledImageProvider相关

如果PickledImageProvider类有问题,可以尝试修改 bytesIO.py 

class PickledImageProvider:
    def __init__(self, obj_path):
        self.obj_path = obj_path
        self.examples = self.load_pickled_examples()
        
    def load_pickled_examples(self):
        with open(self.obj_path, "rb") as f:
            examples = pickle.load(f)
        
        # 标准化数据格式
        if isinstance(examples, dict):
            if 'train' in examples:
                return examples['train']
            elif 'data' in examples:
                return examples['data']
        return examples

    def __getitem__(self, index):
        return self.examples[index]

    def __len__(self):
        return len(self.examples)

dataset.py中

class DatasetFromObj(data.Dataset):
    def __init__(self, obj_path, input_nc=3):
        super().__init__()
        self.image_provider = PickledImageProvider(obj_path.replace('.obj', '.pkl'))
        self.input_nc = input_nc

    def __getitem__(self, index):
        try:
            item = self.image_provider[index]
            # 处理不同的数据格式
            if isinstance(item, (list, tuple)) and len(item) >= 2:
                return self.process(item[1])
            elif hasattr(item, 'keys') and 'image' in item:
                return self.process(item['image'])
            else:
                return self.process(item)
        except Exception as e:
            print(f"Error processing item {index}: {str(e)}")
            # 返回空数据或重新抛出异常
            raise

    def __len__(self):
        return len(self.image_provider)

    def process(self, img):
        # 确保图像处理逻辑正确
        if isinstance(img, (list, tuple)):
            img = img[0]  # 取第一个元素
        # 添加您的图像处理代码
        return img, img  # 示例返回
3.3.3.7数据集问题

如果数据集有问题可以添加调试信息,如

def __init__(self, obj_path, input_nc=3):
    print(f"Loading data from: {obj_path}")
    self.image_provider = PickledImageProvider(obj_path)
    print(f"Total examples: {len(self.image_provider)}")

如果数据集还是有问题,可以使用这个生成

import os
import shutil
import random
from pathlib import Path
import logging

# 设置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("prepare_dataset.log", mode='w'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def create_dataset_structure(output_dir="dataset"):
    """创建zi2zi所需的数据集目录结构"""
    dirs = [
        os.path.join(output_dir, "train", "A"),
        os.path.join(output_dir, "train", "B"),
        os.path.join(output_dir, "train", "test_A"),
        os.path.join(output_dir, "train", "test_B"),
        os.path.join(output_dir, "val", "A"),
        os.path.join(output_dir, "val", "B")
    ]
    
    for dir_path in dirs:
        os.makedirs(dir_path, exist_ok=True)
        logger.info(f"创建目录: {dir_path}")
    
    return dirs

def get_character_variants(img_dir):
    """获取每个汉字及其变体"""
    char_variants = {}
    
    for filename in os.listdir(img_dir):
        if filename.endswith(".png"):
            # 提取汉字部分(文件名格式:汉字_序号~用户手写体.png)
            char = filename.split("_")[0]
            if char not in char_variants:
                char_variants[char] = []
            char_variants[char].append(filename)
    
    logger.info(f"找到 {len(char_variants)} 个不同的汉字")
    return char_variants

def prepare_dataset(source_dir, output_dir="dataset", train_ratio=0.8, test_ratio=0.1):
    """准备zi2zi数据集"""
    # 创建目录结构
    create_dataset_structure(output_dir)
    
    # 获取所有汉字及其变体
    char_variants = get_character_variants(source_dir)
    
    # 统计信息
    total_chars = len(char_variants)
    train_count = int(total_chars * train_ratio)
    test_count = int(total_chars * test_ratio)
    val_count = total_chars - train_count - test_count
    
    logger.info(f"数据集划分: 训练 {train_count} 个汉字, 测试 {test_count} 个, 验证 {val_count} 个")
    
    # 随机打乱汉字顺序
    chars = list(char_variants.keys())
    random.shuffle(chars)
    
    # 分配训练、测试和验证集
    train_chars = chars[:train_count]
    test_chars = chars[train_count:train_count+test_count]
    val_chars = chars[train_count+test_count:]
    
    # 复制文件到相应目录
    for char in train_chars:
        for variant in char_variants[char]:
            # 训练集B(手写字体)
            src = os.path.join(source_dir, variant)
            dst = os.path.join(output_dir, "train", "B", variant)
            shutil.copy2(src, dst)
            
            # 训练集A(标准字体) - 这里需要您提供标准字体图片
            # 暂时复制手写字体作为占位符
            dst = os.path.join(output_dir, "train", "A", variant)
            shutil.copy2(src, dst)
    
    for char in test_chars:
        for variant in char_variants[char]:
            # 测试集B
            src = os.path.join(source_dir, variant)
            dst = os.path.join(output_dir, "train", "test_B", variant)
            shutil.copy2(src, dst)
            
            # 测试集A
            dst = os.path.join(output_dir, "train", "test_A", variant)
            shutil.copy2(src, dst)
    
    for char in val_chars:
        for variant in char_variants[char]:
            # 验证集B
            src = os.path.join(source_dir, variant)
            dst = os.path.join(output_dir, "val", "B", variant)
            shutil.copy2(src, dst)
            
            # 验证集A
            dst = os.path.join(output_dir, "val", "A", variant)
            shutil.copy2(src, dst)
    
    logger.info("数据集准备完成")

if __name__ == "__main__":
    # 配置参数
    source_dir = "img_outputs"  # 手写字体图片目录
    output_dir = "dataset"      # 输出目录
    
    prepare_dataset(source_dir, output_dir)

结语

因为我已经把字帖写完了,所以最后就没有将训练好的字体文件排版后打印,本项目通过GAN算法,在本地+云端环境下生成一套属于自己的手写体文件,祝大家暑假快乐

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值