After studying the FastSpeech2 paper, I worked through a reproduction repository on GitHub to better understand the details of the implementation. The chosen repository is a PyTorch implementation, available at https://github.com/ming024/FastSpeech2. It builds on the FastSpeech reproduction code at https://github.com/xcmyz/FastSpeech, and much of the code is essentially identical. I have previously annotated and analyzed that FastSpeech repository; interested readers can refer to that column.
These notes annotate and analyze the FastSpeech2 reproduction code using the LJSpeech dataset. The analysis of the data-processing and model-construction code can be found in the companion notes on data preparation (fastspeech2复现github项目–数据准备) and model construction (fastspeech2复现github项目–模型构建). This note focuses on the training-related code of FastSpeech2, with the validation and synthesis code included along the way.
model/loss.py
During training, FastSpeech2 trains the duration predictor, pitch predictor, and energy predictor jointly. In addition, as with the earlier autoregressive models, losses are computed on the mel spectrogram both before and after the postnet, so five losses are computed in total. The loss class is defined in loss.py.
import torch
import torch.nn as nn


# Custom loss: the total loss of the model is the sum of five terms: mel_loss, postnet_mel_loss, duration_loss, pitch_loss and energy_loss
class FastSpeech2Loss(nn.Module):
    """ FastSpeech2 Loss """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2Loss, self).__init__()
        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, inputs, predictions):
        # Targets (the labels); inputs[6:] is (mels, mel_lens, max_mel_len, pitches, energies, durations) in the order produced by Dataset.reprocess
        mel_targets, _, _, pitch_targets, energy_targets, duration_targets = inputs[6:]
        # Model outputs
        mel_predictions, postnet_mel_predictions, pitch_predictions, energy_predictions, log_duration_predictions, _, \
            src_masks, mel_masks, _, _ = predictions
        src_masks = ~src_masks  # invert the masks so that True marks valid (non-padded) positions
        mel_masks = ~mel_masks
        log_duration_targets = torch.log(duration_targets.float() + 1)  # take the log of the target durations
        mel_targets = mel_targets[:, : mel_masks.shape[1], :]
        mel_masks = mel_masks[:, : mel_masks.shape[1]]

        log_duration_targets.requires_grad = False
        pitch_targets.requires_grad = False
        energy_targets.requires_grad = False
        mel_targets.requires_grad = False

        if self.pitch_feature_level == "phoneme_level":
            pitch_predictions = pitch_predictions.masked_select(src_masks)
            pitch_targets = pitch_targets.masked_select(src_masks)
        elif self.pitch_feature_level == "frame_level":
            pitch_predictions = pitch_predictions.masked_select(mel_masks)
            pitch_targets = pitch_targets.masked_select(mel_masks)

        if self.energy_feature_level == "phoneme_level":
            energy_predictions = energy_predictions.masked_select(src_masks)
            energy_targets = energy_targets.masked_select(src_masks)
        if self.energy_feature_level == "frame_level":
            energy_predictions = energy_predictions.masked_select(mel_masks)
            energy_targets = energy_targets.masked_select(mel_masks)

        log_duration_predictions = log_duration_predictions.masked_select(src_masks)
        log_duration_targets = log_duration_targets.masked_select(src_masks)

        mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
        postnet_mel_predictions = postnet_mel_predictions.masked_select(mel_masks.unsqueeze(-1))
        mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))

        mel_loss = self.mae_loss(mel_predictions, mel_targets)  # loss on the mel spectrogram predicted by the decoder
        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)  # loss on the decoder output after postnet refinement
        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)  # pitch loss
        energy_loss = self.mse_loss(energy_predictions, energy_targets)  # energy loss
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)  # duration loss (in the log domain)

        total_loss = mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss

        return (
            total_loss,
            mel_loss,
            postnet_mel_loss,
            pitch_loss,
            energy_loss,
            duration_loss,
        )
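As a quick sanity check of the input/output contract of this class, it can be exercised with random tensors. The snippet below is a hypothetical smoke test, not code from the repository; the shapes (batch of 2, 5 phonemes, 20 mel frames, 80 mel bins) follow the unpacking in forward() and assume phoneme-level pitch and energy features.

import torch

# Hypothetical smoke test for FastSpeech2Loss (illustrative only, not from the repository)
preprocess_config = {"preprocessing": {"pitch": {"feature": "phoneme_level"},
                                       "energy": {"feature": "phoneme_level"}}}
loss_fn = FastSpeech2Loss(preprocess_config, model_config={})  # model_config is not used in __init__

B, T_src, T_mel, n_mels = 2, 5, 20, 80
src_masks = torch.zeros(B, T_src, dtype=torch.bool)  # True would mark padded positions
mel_masks = torch.zeros(B, T_mel, dtype=torch.bool)
inputs = (
    None, None, None, None, None, None,  # ids, raw_texts, speakers, texts, text_lens, max_text_len (unused by the loss)
    torch.randn(B, T_mel, n_mels),       # mel targets
    None, None,                          # mel_lens, max_mel_len (unused by the loss)
    torch.randn(B, T_src),               # pitch targets (phoneme level)
    torch.randn(B, T_src),               # energy targets (phoneme level)
    torch.randint(1, 5, (B, T_src)),     # duration targets in frames
)
predictions = (
    torch.randn(B, T_mel, n_mels),       # decoder mel output
    torch.randn(B, T_mel, n_mels),       # postnet mel output
    torch.randn(B, T_src),               # pitch predictions
    torch.randn(B, T_src),               # energy predictions
    torch.randn(B, T_src),               # log-duration predictions
    None, src_masks, mel_masks, None, None,  # apart from the masks, the remaining entries are unused by the loss
)
total_loss, mel_loss, postnet_mel_loss, pitch_loss, energy_loss, duration_loss = loss_fn(inputs, predictions)
print(total_loss.item())  # the total loss is the plain sum of the five terms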
model/optimizer.py
This file wraps the optimizer in a learning-rate scheduling class that adjusts the learning rate dynamically, combining warmup with annealing.
import torch
import numpy as np


# Wrapper class for learning-rate scheduling
class ScheduledOptim:
    """ A simple wrapper class for learning rate scheduling """

    def __init__(self, model, train_config, model_config, current_step):
        self._optimizer = torch.optim.Adam(
            model.parameters(),
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
            weight_decay=train_config["optimizer"]["weight_decay"],
        )
        self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]  # number of warmup steps
        self.anneal_steps = train_config["optimizer"]["anneal_steps"]  # annealing steps
        self.anneal_rate = train_config["optimizer"]["anneal_rate"]  # annealing rate
        self.current_step = current_step  # current training step
        self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)  # initial learning-rate scale

    # Update the parameters using the scheduled learning rate
    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    # Clear the gradients
    def zero_grad(self):
        # print(self.init_lr)
        self._optimizer.zero_grad()

    # Load a saved optimizer state
    def load_state_dict(self, path):
        self._optimizer.load_state_dict(path)

    # Learning-rate schedule
    def _get_lr_scale(self):
        lr = np.min([
            np.power(self.current_step, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.current_step,
        ])
        for s in self.anneal_steps:  # once the current step exceeds an annealing step, scale the learning rate down further
            if self.current_step > s:
                lr = lr * self.anneal_rate
        return lr

    # Learning rate for each step under this schedule
    def _update_learning_rate(self):
        """ Learning rate scheduling per step """
        self.current_step += 1
        lr = self.init_lr * self._get_lr_scale()  # learning rate for the current step
        # Set the learning rate for all parameter groups
        for param_group in self._optimizer.param_groups:
            param_group["lr"] = lr
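Putting the scheduler together: writing d for transformer.encoder_hidden, w for warm_up_step, r for anneal_rate and S for the set of anneal_steps, the learning rate produced by the code above at training step t is

lr(t) = d^{-1/2} \cdot \min\left(t^{-1/2},\; t \cdot w^{-3/2}\right) \cdot r^{\lvert \{ s \in S : t > s \} \rvert}

This is the Transformer-style warmup schedule: the rate grows linearly over the first w steps, peaks at about d^{-1/2} w^{-1/2}, then decays as t^{-1/2}, and is additionally multiplied by r each time an annealing step is passed. As a rough example, assuming the LJSpeech defaults of encoder_hidden = 256 and warm_up_step = 4000 (my reading of the repository's configs, worth double-checking), the peak learning rate is about 256^{-1/2} * 4000^{-1/2} ≈ 1e-3.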
dataset.py
This file handles data loading and conversion: it loads the preprocessed phoneme sequences, durations, mel spectrograms, pitch sequences, and energy sequences, and converts them into the form the model consumes directly.
import json
import math
import os

import numpy as np
from torch.utils.data import Dataset

from text import text_to_sequence
from utils.tools import pad_1D, pad_2D


class Dataset(Dataset):
    def __init__(self, filename, preprocess_config, train_config, sort=False, drop_last=False):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]
        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(filename)  # load the text metadata for each utterance
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):  # fetch one sample by index
        basename = self.basename[idx]  # basename of the file
        speaker = self.speaker[idx]  # speaker name, which for LJSpeech is the dataset name
        speaker_id = self.speaker_map[speaker]  # numeric id of the speaker
        raw_text = self.raw_text[idx]  # raw text
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))  # phoneme sequence converted from the text
        mel_path = os.path.join(
            self.preprocessed_path,
            "mel",
            "{}-mel-{}.npy".format(speaker, basename),
        )
        mel = np.load(mel_path)  # load the mel spectrogram
        pitch_path = os.path.join(
            self.preprocessed_path,
            "pitch",
            "{}-pitch-{}.npy".format(speaker, basename),
        )
        pitch = np.load(pitch_path)  # load the pitch sequence
        energy_path = os.path.join(
            self.preprocessed_path,
            "energy",
            "{}-energy-{}.npy".format(speaker, basename),
        )
        energy = np.load(energy_path)  # load the energy sequence
        duration_path = os.path.join(
            self.preprocessed_path,
            "duration",
            "{}-duration-{}.npy".format(speaker, basename),
        )
        duration = np.load(duration_path)  # load the durations

        sample = {
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }
        return sample  # return the sample

    # Load the text metadata for each utterance
    def process_meta(self, filename):
        with open(os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8") as f:
            name = []
            speaker = []
            text = []
            raw_text = []
            for line in f.readlines():
                n, s, t, r = line.strip("\n").split("|")
                name.append(n)
                speaker.append(s)
                text.append(t)
                raw_text.append(r)
            return name, speaker, text, raw_text

    # Further convert a group of samples into one padded batch
    def reprocess(self, data, idxs):
        ids = [data[idx]["id"] for idx in idxs]
        speakers = [data[idx]["speaker"] for idx in idxs]
        texts = [data[idx]["text"] for idx in idxs]
        raw_texts = [data[idx]["raw_text"] for idx in idxs]
        mels = [data[idx]["mel"] for idx in idxs]
        pitches = [data[idx]["pitch"] for idx in idxs]
        energies = [data[idx]["energy"] for idx in idxs]
        durations = [data[idx]["duration"] for idx in idxs]
        text_lens = np.array([text.shape[0] for text in texts])  # lengths of the phoneme sequences
        mel_lens = np.array([mel.shape[0] for mel in mels])  # lengths of the mel spectrograms
        speakers = np.array(speakers)
        # Pad the following sequences along the corresponding dimensions
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        energies = pad_1D(energies)
        durations = pad_1D(durations)
        return (
            ids,
            raw_texts,
            speakers,
            texts,
            text_lens,
            max(text_lens),
            mels,
            mel_lens,
            max(mel_lens),
            pitches,
            energies,
            durations,
        )

    # collate_fn passed to the DataLoader
    def collate_fn(self, data):
        data_size = len(data)
        if self.sort:  # if sorting is enabled
            len_arr = np.array([d["text"].shape[0] for d in data])
            idx_arr = np.argsort(-len_arr)  # indices that sort the samples by text length, longest first
        else:
            idx_arr = np.arange(data_size)
        # When the number of incoming samples is not an exact multiple of batch_size, tail holds the leftover samples that do not fill a full batch
        tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size):]
        idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]  # indices of the samples that fill complete batches
        idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
        # (The article is cut off here; the remainder of collate_fn below is restored from the original repository.)
        # If drop_last is False, keep the leftover tail as one extra, smaller batch
        if not self.drop_last and len(tail) > 0:
            idx_arr += [tail.tolist()]
        # Convert each group of indices into a padded batch, so the DataLoader yields a list of real batches
        output = list()
        for idx in idx_arr:
            output.append(self.reprocess(data, idx))
        return output
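Note that collate_fn pads and batches the samples itself and returns a list of real batches. Accordingly, the training script creates the DataLoader with a batch_size that is group_size times the real batch size, so that length-based sorting has more samples to work with, and then iterates over the inner list. The sketch below follows the structure of the repository's train.py; the group_size value of 4 and the config paths are assumptions on my part and should be checked against the actual scripts.

import yaml
from torch.utils.data import DataLoader

# Assumed config locations for the LJSpeech recipe
preprocess_config = yaml.load(open("config/LJSpeech/preprocess.yaml", "r"), Loader=yaml.FullLoader)
train_config = yaml.load(open("config/LJSpeech/train.yaml", "r"), Loader=yaml.FullLoader)

dataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True)
group_size = 4  # assumed value; each DataLoader "batch" is split by collate_fn into group_size real batches
loader = DataLoader(
    dataset,
    batch_size=train_config["optimizer"]["batch_size"] * group_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)

for batchs in loader:      # each item yielded by the loader is a list of real batches
    for batch in batchs:   # each batch is the 12-tuple returned by reprocess
        (ids, raw_texts, speakers, texts, text_lens, max_text_len,
         mels, mel_lens, max_mel_len, pitches, energies, durations) = batch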
