在完成fastspeech论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的仓库复现仓库是基于pytorch实现,链接为https://github.com/ming024/FastSpeech2。该仓库是基于https://github.com/xcmyz/FastSpeech中的FastSpeech复现代码完成的,很多代码基本一致。作者前期已对该FastSpeech复现仓库进行注释分析,感兴趣的读者可见此专栏。
通过论文可知,FastSpeech2模型整体架构与FastSpeech基本一致,只是除了Duration Predicator外,还增加了Pitch Predictor和Energy Predictor两部分,并且此三部分的网络架构是一样的。所以,本仓库中transformer路径下的文件基本与https://github.com/xcmyz/FastSpeech中基本一致,在搭建FastSpeech2模型时,主要使用到其中定义的Encoder, Decoder, PostNet模块,可以进入专栏中详细了解。在本仓库中,FastSpeech2模型搭建主要涉及的两个文件为fastspeech.py和model路径下的modules.py文件。
model/modules.py
本文件主要是定义Variance Adaptor,其中主要包括Duration Predictor、Length Regulator、Pitch Predictor和Energy Predictor,详细代码和注释解析如下所示
import os
import json
import copy
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils.tools import get_mask_from_lengths, pad
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 完整Variance Adaptor
class VarianceAdaptor(nn.Module):
"""Variance Adaptor"""
def __init__(self, preprocess_config, model_config):
super(VarianceAdaptor, self).__init__()
self.duration_predictor = VariancePredictor(model_config)
self.length_regulator = LengthRegulator()
self.pitch_predictor = VariancePredictor(model_config)
self.energy_predictor = VariancePredictor(model_config)
# 设置pitch和energy的级别
self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
assert self.energy_feature_level in ["phoneme_level", "frame_level"]
# 设置pitch和energy的量化方式
pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
energy_quantization = model_config["variance_embedding"]["energy_quantization"]
n_bins = model_config["variance_embedding"]["n_bins"]
assert pitch_quantization in ["linear", "log"]
assert energy_quantization in ["linear", "log"]
# 加载pitch和energy的正则化所需参数
with open(os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")) as f:
stats = json.load(f)
pitch_min, pitch_max = stats["pitch"][:2]
energy_min, energy_max = stats["energy"][:2]
if pitch_quantization == "log":
self.pitch_bins = nn.Parameter(
torch.exp(torch.linspace(np.log(pitch_min), np.log(pitch_max), n_bins - 1)),
requires_grad=False,)
else:
self.pitch_bins = nn.Parameter(
torch.linspace(pitch_min, pitch_max, n_bins - 1),
requires_grad=False,)
if energy_quantization == "log":
self.energy_bins = nn.Parameter(
torch.exp(torch.linspace(np.log(energy_min), np.log(energy_max), n_bins - 1)),
requires_grad=False,)
else:
self.energy_bins = nn.Parameter(
torch.linspace(energy_min, energy_max, n_bins -

最低0.47元/天 解锁文章
2万+

被折叠的 条评论
为什么被折叠?



