在完成fastspeech论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/xcmyz/FastSpeech。该仓库使用的数据集为LJSpeech,数据处理部分的代码见笔记“fastspeech复现github项目–数据准备”、模型构建的代码见笔记“fastspeech复现github项目–模型构建”。本笔记对FastSpeech模型训练相关代码进行详细注释,主要代码是仓库中的dataset.py、loss.py、optimizer.py、train.py、eval.py。
dataset.py
该文件是主要用于数据加载和数据转换,将文本、持续时间和mel谱图序列加载封装至定义的BufferDataset对象中,然后定义回调函数collate_fn_tensor将对数据进行pad等操作,转换为模型训练所需的格式
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import math
import time
import os
import hparams
import audio
from utils import process_text, pad_1D, pad_2D
from utils import pad_1D_tensor, pad_2D_tensor
from text import text_to_sequence
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_data_to_buffer():
buffer = list()
# 将全部的音频文本读取到一个列表对象中,text是一个列表,每一个元素是一个字符串,即一个音频对应的文本
text = process_text(os.path.join("data", "train.txt"))
start = time.perf_counter()
for i in tqdm(range(len(text))):
mel_gt_name = os.path.join(
hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (i+1))
mel_gt_target = np.load(mel_gt_name) # 加载文本对应的音频文件的mel谱图
duration = np.load(os.path.join(
hparams.alignment_path, str(i)+".npy")) # 加载对应的持续时间
character = text[i][0:len(text[i])-1] # 删除最后的换行符
character = np.array(
text_to_sequence(character, hparams.text_cleaners)) # 将英文文本转换为数值序列,相当于分词
print(sum(duration))
# character和duration的长度一致,即duration中的i的值,表示character中i位置的数值出现的次数
character = torch.from_numpy(character)
# dutation中所有数值之和与mel的长度相等,即character经过duration调整后,文本长度将于mel谱图长度对齐
duration = torch.from_numpy(duration)
mel_gt_target = torch.from_numpy(mel_gt_target)
# 将一个音频文件的文本、持续时间和mel谱图数据组合成一个元组对象存在在列表中
buffer.append({
"text": character, "duration": duration,
"mel_target": mel_gt_target})
end = time.perf_counter()
print("cost {:.2f}s to load all data into buffer.".format(end-start))
return buffer
class BufferDataset(Dataset):
def __init__(self, buffer):
self.buffer = buffer # 加载所有数据
self.length_dataset = len(self.buffer) # 数据集总数量
def __len__(self):
return self.length_dataset
def __getitem__(self, idx):
return self.buffer[idx]
def reprocess_tensor(batch, cut_list):
'''
以传入的batch数据和对应的序列索引,给文本序列、mel谱图序列建立位置信息,同时将其封装在一起输出
@param batch:一个大batch的数据
@param cut_list:一个real batch大小的索引列表,其对应的文本长度从达到小降序排列
@return:
'''
texts = [batch[ind]["text"] for ind in cut_list] # batch中的文本
mel_targets = [batch[ind]["mel_target"] for ind in cut_list] # batch中的gt梅尔谱图
durations = [batch[ind]["duration"] for ind in cut_list] # batch中的duration时间
length_text = np.array([]) # 存储所有文本序列的长度大小
for text in texts:
length_text = np.append(length_text, text.size(0))
src_pos = list()
max_len = int(max(length_text)) # 最大文本长度
for length_src_row in length_text:
# 给每个文本生成src_pos,从1到文本的长度,如果长度小于max_len,对应部分用0填充
src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
(0, max_len-int(length_src_row)), 'constant'))
src_pos = torch.from_numpy(np.array(src_pos))
length_mel = np.array(list()) # 存储所有mel谱图序列的长度大小
for mel in mel_targets:
length_mel = np.append(length_mel, mel.size(0))
mel_pos = list()
max_mel_len = int(max(length_mel)) # 最大mel谱图序列长度
for length_mel_row in length_mel:
# 给每个mel谱图序列生成mel_pos,从1到序列的长度,如果长度小于max_mel_len,对应部分用0填充
mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
(0, max_mel_len-int(length_mel_row)), 'constant'))
mel_pos = torch.from_numpy(np.array(mel_pos))
texts = pad_1D_tensor(texts) # 将所有的文本都pad到文本的最大长度
durations = pad_1D_tensor(durations) # 将所有的duration持续时间pad到最大长度
mel_targets = pad_2D_tensor(mel_targets) # 将所有mel谱图序列pad到最大长度
out = {
"text": texts,
"mel_target": mel_targets,

本文详细介绍FastSpeech模型的复现过程,包括数据处理、模型构建及训练代码解析。涵盖PyTorch实现细节,如数据集加载、损失函数定义、优化器配置等。
最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



