AUTOVC 代码解析 —— data_loader.py
文章目录
简介
本项目一个基于 AUTOVC 模型的语音转换项目,它是使用 PyTorch 实现的(项目地址)。
AUTOVC 遵循自动编码器框架,只对自动编码器损耗进行训练,但它引入了精心调整的降维和时间下采样来约束信息流,这个简单的方案带来了显著的性能提高。(详情请参阅 AUTOVC 的详细介绍)。
由于 AUTOVC 项目较大,代码较多。为了方便学习与整理,将按照工程文件的结构依次介绍。
本文将介绍项目中的 data_loader.py 文件:设计了数据集,以及数据迭代器,使数据的读取变得非常简单,快捷。
类解析
Utterances
该类的意义为:自定义数据集,规范化数据的读取。
Utterances 继承 PyTorch 中数据集的基类: torch.utils.data.Dataset ;
故需要重载两个函数:__getitem__ 与 __len__ 。
下面依次介绍 Utterances 类的成员函数
__ init __
该函数的作用是: 创建 Utterances 自定义数据集需要的元素。
输入参数:
root_dir : 数据集文件目录
len_crop : 截取长度
输出参数: 无
代码详解:
def __init__(self, root_dir, len_crop):
""" 初始化和预处理话语数据集 """
# 创建变量 root_dir 保存数据集文件目录
self.root_dir = root_dir
# 创建变量 len_crop 保存截取长度
self.len_crop = len_crop
# 创建变量 step 保存步长
self.step = 10
# 将训练数据文件 train.pkl 的路径拼接
metaname = os.path.join(self.root_dir, "train.pkl")
# 使用 pickle.load 读取训练数据,保存于变量 meta
meta = pickle.load(open(metaname, "rb"))
""" 使用多处理加载数据 """
# 由 multiprocessing.Manager 返回的管理器对象控制一个服务器进程,该进程保存 Python 对象并允许其他进程使用代理操作它们
manager = Manager()
# 创建列表 meta 共享对象保存训练数据
meta = manager.list(meta)
# 创建一个与 meta 相同长度的空共享列表对象 dataset
dataset = manager.list(len(meta)*[None])
# 创建进程池列表 processes
processes = []
for i in range(0, len(meta), self.step):
# 通过创建一个 Process 对象创建新的进程
# 进程使用成员函数 load_data ,输入为部分训练数据、数据集地址与编号偏移量
p = Process(target=self.load_data,
args=(meta[i:i+self.step],dataset,i))
# 调用 start 方法启动进程
p.start()
# 将进程加入进程池
processes.append(p)
# 遍历进程池
for p in processes:
# 等待至进程终止
p.join()
# 将得到的数据集保存至成员列表变量 train_dataset
self.train_dataset = list(dataset)
# 计算数据集长度,并保存至成员变量 num_tokens
self.num_tokens = len(self.train_dataset)
# 打印结束语句至终端
print('Finished loading the dataset...')
load_data
该函数的作用是: 将训练元数据加载进数据集
输入参数:
submeta : 部分训练元数据
dataset : 数据集对象
idx_offset : 编号偏移量
输出参数: 无
代码详解:
def load_data(self, submeta, dataset, idx_offset):
# 遍历部分训练元数据 submeta ,k 为编号,sbmt 为当前说话人信息(包括说话人名、说话人编码、说话文件--梅尔频谱)的列表
for k, sbmt in enumerate(submeta):
# 创建与当前说话人信息列表相同长度的列表 uttrs
uttrs = len(sbmt)*[None]
# 遍历当前说话人信息列表 sbmt ,j 为编号,tmp 为当前信息
for j, tmp in enumerate(sbmt):
# 当 j=0 时,tmp 为说话人人名
# 当 j=1 时,tmp 为说话人编码
if j < 2:
# 保存说话人信息
uttrs[j] = tmp
# 当 j>2 时,tmp 为说话文件(梅尔频谱图)
else:
# 加载说话文件并保存
uttrs[j] = np.load(os.path.join(self.root_dir, tmp))
# 将组合的说话人信息保存至偏移后的正确位置
dataset[idx_offset+k] = uttrs
__ getitem __
该函数的作用是: 指定随机说话人,获取该说话人的说话人编码,以及长度为截取长度的说话数据
输入参数:
index : 说话人编号
输出参数:
uttr : 说话数据
emb_org : 说话人编码
代码详解:
def __getitem__(self, index):
# 取数据集
dataset = self.train_dataset
# 取出编号对应的说话人信息列表 list_uttrs
list_uttrs = dataset[index]
# 取出说话人编码 emb_org
emb_org = list_uttrs[1]
# 随机生成说话数据编号 a
a = np.random.randint(2, len(list_uttrs))
# 取出编号对应的说话数据 tmp
tmp = list_uttrs[a]
# 当说话数据 tmp 的长度小于截取长度 len_crop 时
if tmp.shape[0] < self.len_crop:
# 计算说话数据与截取长度的差距
len_pad = self.len_crop - tmp.shape[0]
# 在说话数据末尾填补说话数据,使说话数据长度与截取长度相等
uttr = np.pad(tmp, ((0,len_pad),(0,0)), 'constant')
# 当说话数据 tmp 的长度大于截取长度 len_crop 时
elif tmp.shape[0] > self.len_crop:
# 随机取说话数据左边界 left
left = np.random.randint(tmp.shape[0]-self.len_crop)
# 沿说话数据左边界截取长度为截取长度的说话数据片段
uttr = tmp[left:left+self.len_crop, :]
# 当说话数据 tmp 的长度等于截取长度 len_crop 时
else:
# 不做截取或填补处理,直接赋值
uttr = tmp
# 返回长度为截取长度的说话数据,以及说话人编码
return uttr, emb_org
__ len __
该函数的作用是: 获取数据集长度
输入参数: 无
输出参数:
self.num_tokens : 说话人数量
代码详解:
def __len__(self):
""" 返回说话人的数量 """
return self.num_tokens
函数解析
get_loader
该函数的作用是: 构建并返回数据迭代器
输入参数:
root_dir : 训练文件目录
batch_size : 批大小,默认为 16
len_crop : 截取长度,默认为 128
num_workers : 线程数量,默认为 0
输出参数:
data_loader : 返回数据迭代器
代码详解:
def get_loader(root_dir, batch_size=16, len_crop=128, num_workers=0):
# 构建自定义数据集,输入训练文件目录与截取长度
dataset = Utterances(root_dir, len_crop)
# worker_init_fn 将在每个 worker 子进程上以 worker id (int in [0, num_workers - 1]) 作为输入
worker_init_fn = lambda x: np.random.seed((torch.initial_seed()) % (2**32))
# 创建数据迭代器 data_loader
# 数据集为自定义数据集 dataset ,批大小为 batch_size ,设置打乱顺序加载
# 线程数量为 num_workers ,设置删除最后一个不完整的批处理,设置随机种子为 worker_init_fn
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
worker_init_fn=worker_init_fn)
# 返回数据迭代器 data_loader
return data_loader