【项目实战】WaveNet 代码解析 —— train.py 【更新中】

本文详细解析了一个基于WaveNet的语音合成项目中的train.py文件,该项目使用TensorFlow实现。主要介绍了全局变量,如训练参数、数据路径等;主函数main()的流程,包括数据预处理、模型创建、训练、模型保存和恢复;以及get_arguments()函数,用于获取命令行参数。此外,还涉及了验证目录、模型保存和恢复的辅助函数。

WaveNet 代码解析 —— train.py

  简介

       本项目是一个基于 WaveNet 生成神经网络体系结构的语音合成项目,它是使用 TensorFlow 实现的(项目地址)。
       
       WaveNet神经网络体系结构能直接生成原始音频波形,在文本到语音和一般音频生成方面显示了出色的结果(详情请参阅 WaveNet 的详细介绍)。
       
       由于 WaveNet 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
       
       本文将介绍项目中的 train.py 文件:基于VCTK语料库的小波网络训练脚本。
       
       本脚本使用来自VCTK语料库的数据,用WaveNet训练网络(下载地址
       

  代码解析

    全局变量解析

       以下变量主要作为各功能参数的默认值,辅助开发人员对训练过程进行配置。

		BATCH_SIZE = 1								# 一批训练集中,样本音频的数量
		DATA_DIRECTORY = './VCTK-Corpus'			# 下载的VCTK数据集的路径
		LOGDIR_ROOT = './logdir'					# 训练日志的路径
		CHECKPOINT_EVERY = 50						# 保存训练模型的检查点数量
		NUM_STEPS = int(1e5)						# 训练的总次数
		LEARNING_RATE = 1e-3						# 学习率
		WAVENET_PARAMS = './wavenet_params.json'	# WaveNet 模型的相关参数路径
		STARTED_DATESTRING = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())				# 当前日期格式化
		SAMPLE_SIZE = 100000						# 样本数量大小
		L2_REGULARIZATION_STRENGTH = 0				# L2正则化中的系数
		SILENCE_THRESHOLD = 0.3						# 音量阈值大小
		EPSILON = 0.001								# 精度设置
		MOMENTUM = 0.9								# 优化器动量
		MAX_TO_KEEP = 5								# 保存的最大检查点数量
		METADATA = False							# 高级调试信息存储标志

    函数解析

      main()

        下面这段代码是 train.py 的主函数,主要作用是提取样本进行预处理、创建网络、训练模型、存取模型以及记录日志。

	def main():
	    # 解析命令行功能参数
	    args = get_arguments()
	
	    try:
	        # 验证并整理与目录有关的参数
	        directories = validate_directories(args)
	    except ValueError as e:
	        print("Some arguments are wrong:")
	        print(str(e))
	        return
	
	    # 将整理好的文件路径赋给相应变量
	    logdir = directories['logdir']
	    restore_from = directories['restore_from']
	
	    # 即使我们恢复了模型,如果训练的模型被写入到任意位置,我们也会把它当作新的训练
	    is_overwritten_training = logdir != restore_from
	
	    # 使用 josn 库的 load 函数读取 WaveNet 模型相关参数,将 json 格式的字符转换为 dict
	    with open(args.wavenet_params, 'r') as f:
	        wavenet_params = json.load(f)
	
	    # 创建线程协调器,多线程协调器相关知识可参考文章地址如下:
	    # https://blog.youkuaiyun.com/weixin_42721167/article/details/112795491
	    coord = tf.train.Coordinator()
	
	    # 从VCTK数据集中加载原始波形
	    with tf.name_scope('create_inputs'):
	        # 允许通过指定接近零的阈值跳过静默修剪
	        silence_threshold = args.silence_threshold if args.silence_threshold > \
	                                                      EPSILON else None
	        gc_enabled = args.gc_channels is not None
	        # 通用的后台音频读取器,对音频文件进行预处理并将它们排队到TensorFlow队列中
	        reader = AudioReader(
	            args.data_dir,
	            coord,
	            sample_rate=wavenet_params['sample_rate'],
	            gc_enabled=gc_enabled,
	            receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
	                                                                   wavenet_params["dilations"],
	                                                                   wavenet_params["scalar_input"],
	                                                                   wavenet_params["initial_filter_width"]),
	            sample_size=args.sample_size,
	            silence_threshold=silence_threshold)
	        # 准备好的音频出队列
	        audio_batch = reader.dequeue(args.batch_size)
	        if gc_enabled:
	            gc_id_batch = reader.dequeue_gc(args.batch_size)
	        else:
	            gc_id_batch = None
	
	    # 创建 WaveNet 网络
	    net = WaveNetModel(
	        batch_size=args.batch_size,
	        dilations=wavenet_params["dilations"],
	        filter_width=wavenet_params["filter_width"],
	        residual_channels=wavenet_params["residual_channels"],
	        dilation_channels=wavenet_params["dilation_channels"],
	        skip_channels=wavenet_params["skip_channels"],
	        quantization_channels=wavenet_params["quantization_channels"],
	        use_biases=wavenet_params["use_biases"],
	        scalar_input=wavenet_params[
评论 14
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值