fastspeech2复现github项目--模型训练

本文详细解析 FastSpeech2 的训练代码,涵盖损失函数、优化器、数据加载及处理等核心模块。针对 PyTorch 实现,适用于语音合成领域的研究者。

在完成fastspeech论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的仓库复现仓库是基于pytorch实现,链接为https://github.com/ming024/FastSpeech2。该仓库是基于https://github.com/xcmyz/FastSpeech中的FastSpeech复现代码完成的,很多代码基本一致。作者前期已对该FastSpeech复现仓库进行注释分析,感兴趣的读者可见此专栏

本笔记主要是基于LJSpeech数据集对FastSpeech2复现仓库代码进行注释分析,数据处理和模型搭建部分的代码分析可见笔记fastspeech2复现github项目–数据准备fastspeech2复现github项目–模型构建。本笔记主要对FastSpeech2模型训练相关代码进行注释分析,也附带贴出验证、生成等代码

model/loss.py

FastSpeech2在训练时会对duration predictor、pitch predictor和energy predictor同时训练,结合之前自回归模型均会对最后mel经过postnet处理的前后计算损失,故训练过程中会计算五个损失。loss.py文件中就定义了损失类

import torch
import torch.nn as nn


# 自定义的损失,整个模型的损失由五个不同损失组成,分别时是mel_loss,postnet_mel_loss,duration_loss,pitch_loss,energy_loss
class FastSpeech2Loss(nn.Module):
    """ FastSpeech2 Loss """

    def __init__(self, preprocess_config, model_config):
        super(FastSpeech2Loss, self).__init__()
        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"]["feature"]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"]["feature"]
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, inputs, predictions):
        mel_targets, _, _, pitch_targets, energy_targets, duration_targets = inputs[6:]  # 目标,相当于label
        mel_predictions, postnet_mel_predictions, pitch_predictions, energy_predictions, log_duration_predictions, _, \
        src_masks, mel_masks, _, _ = predictions  # 模型的输出
        src_masks = ~src_masks
        mel_masks = ~mel_masks
        log_duration_targets = torch.log(duration_targets.float() + 1)  # 对目标持续时间取log
        mel_targets = mel_targets[:, : mel_masks.shape[1], :]
        mel_masks = mel_masks[:, :mel_masks.shape[1]]

        log_duration_targets.requires_grad = False
        pitch_targets.requires_grad = False
        energy_targets.requires_grad = False
        mel_targets.requires_grad = False

        if self.pitch_feature_level == "phoneme_level":
            pitch_predictions = pitch_predictions.masked_select(src_masks)
            pitch_targets = pitch_targets.masked_select(src_masks)
        elif self.pitch_feature_level == "frame_level":
            pitch_predictions = pitch_predictions.masked_select(mel_masks)
            pitch_targets = pitch_targets.masked_select(mel_masks)

        if self.energy_feature_level == "phoneme_level":
            energy_predictions = energy_predictions.masked_select(src_masks)
            energy_targets = energy_targets.masked_select(src_masks)
        if self.energy_feature_level == "frame_level":
            energy_predictions = energy_predictions.masked_select(mel_masks)
            energy_targets = energy_targets.masked_select(mel_masks)

        log_duration_predictions = log_duration_predictions.masked_select(src_masks)
        log_duration_targets = log_duration_targets.masked_select(src_masks)

        mel_predictions = mel_predictions.masked_select(mel_masks.unsqueeze(-1))
        postnet_mel_predictions = postnet_mel_predictions.masked_select(mel_masks.unsqueeze(-1))
        mel_targets = mel_targets.masked_select(mel_masks.unsqueeze(-1))

        mel_loss = self.mae_loss(mel_predictions, mel_targets)  # 解码器预测的mel谱图的损失
        postnet_mel_loss = self.mae_loss(postnet_mel_predictions, mel_targets)  # 解码器预测的mel谱图经过postnet处理后的损失

        pitch_loss = self.mse_loss(pitch_predictions, pitch_targets)  # pitch损失
        energy_loss = self.mse_loss(energy_predictions, energy_targets)  # energy损失
        duration_loss = self.mse_loss(log_duration_predictions, log_duration_targets)  # duration损失

        total_loss = mel_loss + postnet_mel_loss + duration_loss + pitch_loss + energy_loss

        return (
            total_loss,
            mel_loss,
            postnet_mel_loss,
            pitch_loss,
            energy_loss,
            duration_loss,
        )

model/optimizer.py

该文件中封装了一个学习率优化类,其可以实现学习率动态变化,结合了退火处理

import torch
import numpy as np


# 为学习率更新封装的类
class ScheduledOptim:
    """ A simple wrapper class for learning rate scheduling """

    def __init__(self, model, train_config, model_config, current_step):

        self._optimizer = torch.optim.Adam(
            model.parameters(),
            betas=train_config["optimizer"]["betas"],
            eps=train_config["optimizer"]["eps"],
            weight_decay=train_config["optimizer"]["weight_decay"],
        )
        self.n_warmup_steps = train_config["optimizer"]["warm_up_step"]  # warmup的步数
        self.anneal_steps = train_config["optimizer"]["anneal_steps"]  # 退火步数
        self.anneal_rate = train_config["optimizer"]["anneal_rate"]  # 退火率
        self.current_step = current_step  # 训练时的当前步数
        self.init_lr = np.power(model_config["transformer"]["encoder_hidden"], -0.5)  # 初始学习率

    # 使用设置的学习率方案进行参数更新
    def step_and_update_lr(self):
        self._update_learning_rate()
        self._optimizer.step()

    # 清楚梯度
    def zero_grad(self):
        # print(self.init_lr)
        self._optimizer.zero_grad()

    # 加载保存的优化器参数
    def load_state_dict(self, path):
        self._optimizer.load_state_dict(path)

    # 学习率变化规则
    def _get_lr_scale(self):
        lr = np.min([np.power(self.current_step, -0.5),
                     np.power(self.n_warmup_steps, -1.5) * self.current_step])
        for s in self.anneal_steps:  # 如果当前训练步数大于设置的回火步数,进一步对学习率进行设置
            if self.current_step > s:
                lr = lr * self.anneal_rate
        return lr

    # 该学习方案中每步的学习率
    def _update_learning_rate(self):
        """ Learning rate scheduling per step """
        self.current_step += 1
        lr = self.init_lr * self._get_lr_scale()  # 计算当前步数的学习率
        # 给所有参数设置学习率
        for param_group in self._optimizer.param_groups:
            param_group["lr"] = lr

dataset.py

该文件主要用于数据加载和数据转换,将预处理好的文本音素、时序时间、mel谱图、pitch序列和energy序列等数据转换、加载为模型可以直接使用的形式。

import json
import math
import os

import numpy as np
from torch.utils.data import Dataset

from text import text_to_sequence
from utils.tools import pad_1D, pad_2D


class Dataset(Dataset):
    def __init__(self, filename, preprocess_config, train_config, sort=False, drop_last=False):
        self.dataset_name = preprocess_config["dataset"]
        self.preprocessed_path = preprocess_config["path"]["preprocessed_path"]
        self.cleaners = preprocess_config["preprocessing"]["text"]["text_cleaners"]
        self.batch_size = train_config["optimizer"]["batch_size"]

        self.basename, self.speaker, self.text, self.raw_text = self.process_meta(filename)  # 加载音频对应的文本数据
        with open(os.path.join(self.preprocessed_path, "speakers.json")) as f:
            self.speaker_map = json.load(f)
        self.sort = sort
        self.drop_last = drop_last

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):  # 通过下标索引获取数据
        basename = self.basename[idx]  # 文件的basaname
        speaker = self.speaker[idx]  # speaker名称,即数据集的名称
        speaker_id = self.speaker_map[speaker]  # speaker对应的数值序号
        raw_text = self.raw_text[idx]  # 原始文本
        phone = np.array(text_to_sequence(self.text[idx], self.cleaners))  # 文本处理后的音素序列
        mel_path = os.path.join(
            self.preprocessed_path,
            "mel",
            "{}-mel-{}.npy".format(speaker, basename),
        )
        mel = np.load(mel_path)  # 加载mel谱图
        pitch_path = os.path.join(
            self.preprocessed_path,
            "pitch",
            "{}-pitch-{}.npy".format(speaker, basename),
        )
        pitch = np.load(pitch_path)  # 加载pitch序列
        energy_path = os.path.join(
            self.preprocessed_path,
            "energy",
            "{}-energy-{}.npy".format(speaker, basename),
        )
        energy = np.load(energy_path)  # 加载energy序列
        duration_path = os.path.join(
            self.preprocessed_path,
            "duration",
            "{}-duration-{}.npy".format(speaker, basename),
        )
        duration = np.load(duration_path)  # 加载持续时间

        sample = {
   
   
            "id": basename,
            "speaker": speaker_id,
            "text": phone,
            "raw_text": raw_text,
            "mel": mel,
            "pitch": pitch,
            "energy": energy,
            "duration": duration,
        }

        return sample  # 返回数据

    # 加载每个音频对应的文本数据
    def process_meta(self, filename):
        with open(os.path.join(self.preprocessed_path, filename), "r", encoding="utf-8") as f:
            name = []
            speaker = []
            text = []
            raw_text = []
            for line in f.readlines():
                n, s, t, r = line.strip("\n").split("|")
                name.append(n)
                speaker.append(s)
                text.append(t)
                raw_text.append(r)
            return name, speaker, text, raw_text

    # 对数据进一步转换
    def reprocess(self, data, idxs):
        ids = [data[idx]["id"] for idx in idxs]
        speakers = [data[idx]["speaker"] for idx in idxs]
        texts = [data[idx]["text"] for idx in idxs]
        raw_texts = [data[idx]["raw_text"] for idx in idxs]
        mels = [data[idx]["mel"] for idx in idxs]
        pitches = [data[idx]["pitch"] for idx in idxs]
        energies = [data[idx]["energy"] for idx in idxs]
        durations = [data[idx]["duration"] for idx in idxs]

        text_lens = np.array([text.shape[0] for text in texts])  # 文本序列长度列表
        mel_lens = np.array([mel.shape[0] for mel in mels])  # mel图谱序列长度列表

        speakers = np.array(speakers)
        # 对一下的序列进行对应维度的pad
        texts = pad_1D(texts)
        mels = pad_2D(mels)
        pitches = pad_1D(pitches)
        energies = pad_1D(energies)
        durations = pad_1D(durations)

        return (
            ids,
            raw_texts,
            speakers,
            texts,
            text_lens,
            max(text_lens),
            mels,
            mel_lens,
            max(mel_lens),
            pitches,
            energies,
            durations,
        )

    # 定义数据集时使用的数据转换回调函数
    def collate_fn(self, data):
        data_size = len(data)

        if self.sort:  # 如果排序
            len_arr = np.array([d["text"].shape[0] for d in data])
            idx_arr = np.argsort(-len_arr)  # 返回文本序列长度从大到小排序的索引序列
        else:
            idx_arr = np.arange(data_size)

        # 当一个batch传入的数据量不是batch_size的整数倍时,tail就是最后不够一个batch_size的数据
        tail = idx_arr[len(idx_arr) - (len(idx_arr) % self.batch_size):]
        idx_arr = idx_arr[: len(idx_arr) - (len(idx_arr) % self.batch_size)]  # 前面batch_size的整数倍数据对应的序列列表
        idx_arr = idx_arr.reshape((-1, self.batch_size)).tolist()
        if n
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值