fastspeech复现github项目--模型训练

本文详细介绍FastSpeech模型的复现过程,包括数据处理、模型构建及训练代码解析。涵盖PyTorch实现细节,如数据集加载、损失函数定义、优化器配置等。

在完成fastspeech论文学习后,对github上一个复现的仓库进行学习,帮助理解算法实现过程中的一些细节;所选择的复现仓库是基于pytorch实现,链接为https://github.com/xcmyz/FastSpeech。该仓库使用的数据集为LJSpeech,数据处理部分的代码见笔记“fastspeech复现github项目–数据准备”、模型构建的代码见笔记“fastspeech复现github项目–模型构建”。本笔记对FastSpeech模型训练相关代码进行详细注释,主要代码是仓库中的dataset.py、loss.py、optimizer.py、train.py、eval.py。

dataset.py

该文件是主要用于数据加载和数据转换,将文本、持续时间和mel谱图序列加载封装至定义的BufferDataset对象中,然后定义回调函数collate_fn_tensor将对数据进行pad等操作,转换为模型训练所需的格式

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import math
import time
import os

import hparams
import audio

from utils import process_text, pad_1D, pad_2D
from utils import pad_1D_tensor, pad_2D_tensor
from text import text_to_sequence
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def get_data_to_buffer():
    buffer = list()
    # 将全部的音频文本读取到一个列表对象中,text是一个列表,每一个元素是一个字符串,即一个音频对应的文本
    text = process_text(os.path.join("data", "train.txt"))

    start = time.perf_counter()
    for i in tqdm(range(len(text))):

        mel_gt_name = os.path.join(
            hparams.mel_ground_truth, "ljspeech-mel-%05d.npy" % (i+1))
        mel_gt_target = np.load(mel_gt_name)  # 加载文本对应的音频文件的mel谱图
        duration = np.load(os.path.join(
            hparams.alignment_path, str(i)+".npy"))  # 加载对应的持续时间
        character = text[i][0:len(text[i])-1]  # 删除最后的换行符
        character = np.array(
            text_to_sequence(character, hparams.text_cleaners))  # 将英文文本转换为数值序列,相当于分词
        print(sum(duration))

        # character和duration的长度一致,即duration中的i的值,表示character中i位置的数值出现的次数
        character = torch.from_numpy(character)
        # dutation中所有数值之和与mel的长度相等,即character经过duration调整后,文本长度将于mel谱图长度对齐
        duration = torch.from_numpy(duration)
        mel_gt_target = torch.from_numpy(mel_gt_target)
        # 将一个音频文件的文本、持续时间和mel谱图数据组合成一个元组对象存在在列表中
        buffer.append({
   
   "text": character, "duration": duration,
                       "mel_target": mel_gt_target})

    end = time.perf_counter()
    print("cost {:.2f}s to load all data into buffer.".format(end-start))

    return buffer


class BufferDataset(Dataset):
    def __init__(self, buffer):
        self.buffer = buffer  # 加载所有数据
        self.length_dataset = len(self.buffer)  # 数据集总数量

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, idx):
        return self.buffer[idx]


def reprocess_tensor(batch, cut_list):
    '''
    以传入的batch数据和对应的序列索引,给文本序列、mel谱图序列建立位置信息,同时将其封装在一起输出
    @param batch:一个大batch的数据
    @param cut_list:一个real batch大小的索引列表,其对应的文本长度从达到小降序排列
    @return:
    '''
    texts = [batch[ind]["text"] for ind in cut_list]  # batch中的文本
    mel_targets = [batch[ind]["mel_target"] for ind in cut_list]  # batch中的gt梅尔谱图
    durations = [batch[ind]["duration"] for ind in cut_list]  # batch中的duration时间

    length_text = np.array([])  # 存储所有文本序列的长度大小
    for text in texts:
        length_text = np.append(length_text, text.size(0))

    src_pos = list()
    max_len = int(max(length_text))  # 最大文本长度
    for length_src_row in length_text:
        # 给每个文本生成src_pos,从1到文本的长度,如果长度小于max_len,对应部分用0填充
        src_pos.append(np.pad([i+1 for i in range(int(length_src_row))],
                              (0, max_len-int(length_src_row)), 'constant'))
    src_pos = torch.from_numpy(np.array(src_pos))

    length_mel = np.array(list())  # 存储所有mel谱图序列的长度大小
    for mel in mel_targets:
        length_mel = np.append(length_mel, mel.size(0))

    mel_pos = list()
    max_mel_len = int(max(length_mel))  # 最大mel谱图序列长度
    for length_mel_row in length_mel:
        # 给每个mel谱图序列生成mel_pos,从1到序列的长度,如果长度小于max_mel_len,对应部分用0填充
        mel_pos.append(np.pad([i+1 for i in range(int(length_mel_row))],
                              (0, max_mel_len-int(length_mel_row)), 'constant'))
    mel_pos = torch.from_numpy(np.array(mel_pos))

    texts = pad_1D_tensor(texts)  # 将所有的文本都pad到文本的最大长度
    durations = pad_1D_tensor(durations)  # 将所有的duration持续时间pad到最大长度
    mel_targets = pad_2D_tensor(mel_targets)  # 将所有mel谱图序列pad到最大长度

    out = {
   
   "text": texts,
           "mel_target": mel_targets,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值