WaveNet 代码解析 —— train.py
文章目录
简介
本项目是一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址)
代码解析
全局变量解析
以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。
BATCH_SIZE = 1 # 一批训练集中,样本音频的数量
DATA_DIRECTORY = './VCTK-Corpus' # 下载的VCTK数据集的路径
LOGDIR_ROOT = './logdir' # 训练日志的路径
CHECKPOINT_EVERY = 50 # 保存训练模型的检查点数量
NUM_STEPS = int(1e5) # 训练的总次数
LEARNING_RATE = 1e-3 # 学习率
WAVENET_PARAMS = './wavenet_params.json' # WaveNet 模型的相关参数路径
STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now()) # 当前日期格式化
SAMPLE_SIZE = 100000 # 样本数量大小
L2_REGULARIZATION_STRENGTH = 0 # L2正则化中的系数
SILENCE_THRESHOLD = 0.3 # 音量阈值大小
EPSILON = 0.001 # 精度设置
MOMENTUM = 0.9 # 优化器动量
MAX_TO_KEEP = 5 # 保存的最大检查点数量
METADATA = False # 高级调试信息存储标志
函数解析
main()
下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。
def main():
# 解析命令行功能参数
args = get_arguments()
try:
# 验证并整理与目录有关的参数
directories = validate_directories(args)
except ValueError as e:
print("Some arguments are wrong:")
print(str(e))
return
# 将整理好的文件路径赋给相应变量
logdir = directories['logdir']
restore_from = directories['restore_from']
# 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练
is_overwritten_training = logdir != restore_from
# 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dict
with open(args.wavenet_params, 'r') as f:
wavenet_params = json.load(f)
# 创建线程协调器,多线程协调器相关知识可参考文章地址如下:
# https://blog.youkuaiyun.com/weixin_42721167/article/details/112795491
coord = tf.train.Coordinator()
# 从VCTK数据集中加载原始波形
with tf.name_scope('create_inputs'):
# 允许通过指定接近零的阈值跳过静默修剪
silence_threshold = args.silence_threshold if args.silence_threshold > \
EPSILON else None
gc_enabled = args.gc_channels is not None
# 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中
reader = AudioReader(
args.data_dir,
coord,
sample_rate=wavenet_params['sample_rate'],
gc_enabled=gc_enabled,
receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
wavenet_params["dilations"],
wavenet_params["scalar_input"],
wavenet_params["initial_filter_width"]),
sample_size=args.sample_size,
silence_threshold=silence_threshold)
# 准备好的音频出队列
audio_batch = reader.dequeue(args.batch_size)
if gc_enabled:
gc_id_batch = reader.dequeue_gc(args.batch_size)
else:
gc_id_batch = None
# 创建 WaveNet 网络
net = WaveNetModel(
batch_size=args.batch_size,
dilations=wavenet_params["dilations"],
filter_width=wavenet_params["filter_width"],
residual_channels=wavenet_params["residual_channels"],
dilation_channels=wavenet_params["dilation_channels"],
skip_channels=wavenet_params["skip_channels"],
quantization_channels=wavenet_params["quantization_channels"],
use_biases=wavenet_params["use_biases"],
scalar_input=wavenet_params[

本文详细解析了一个基于WaveNet的语音合成项目中的train.py文件,该项目使用TensorFlow实现。主要介绍了全局变量,如训练参数、数据路径等;主函数main()的流程,包括数据预处理、模型创建、训练、模型保存和恢复;以及get_arguments()函数,用于获取命令行参数。此外,还涉及了验证目录、模型保存和恢复的辅助函数。
最低0.47元/天 解锁文章
6001





