整体介绍:
环境python3.6+TensorFlow1.12 显卡是英伟达GTX1070(后头换个好些的显卡)训练了四天四夜
主要技术点CTC,BRNN,MFCC特征,全连接神经网络
CTC时序分类算法: 适合这种不知道输入输出是否对齐的情况(哪个字对应哪段声音)使用的算法,所以CTC适合语音识别和手写字符识别的任务,而传统的语音识别是基于语音学的方法,通常包含拼写、声学和语音模型等单独组件。训练模型的语料除了标注具体的文字外。还要标注按时间对应的音素,这就需要大量的人工成本;使用神经网络的语音识别就变得简单多了,通过能进行时序分类的连续时间分类目标函数(CTC),计算多个标签序列的概率,而序列是语音样本中所有可能的对应文字的集合。然后把预测结果跟实际比较,计算误差,不断更新网络权重。这样就丢弃音素的概念,就省了大量人工标注的成本,也不需要语言模型,只要有足够的样本,就可以训练各种语言的语音识别了。
BRNN双向循环神经网络:可参考我的这篇博文
MFCC梅尔频率倒谱系数:在TensorFlow中的使用方式通俗的说法就是他讲语音按时序转化(转化过程可能比较复杂)为一帧接一帧的,然后每一帧都是一个维度是20或更多,本文中的代码中设置的是26维度,训练的最终目的就是讲每一帧通过他的味独特征来定义它是哪个字符或者更像哪个字符。
全连接神经网络 这个不用多说了
介绍完毕,说干就干
第一步收集数据:
在实际的项目中这一步我们可能要消耗不少精力去做,而此处我们没有必要,可以借鉴清华大学提供的数据进行模型的训练
下载地址 http://www.openslr.org/18/ 或 http://166.111.134.19:8081/data/thchs30-openslr/README.html
只需这一个6.4g的就够了,下载下来解压后是这样的:
data中是所有的数据集,train是训练集,test是测试集
每个文件夹中都是一个wav(语音文件)和对应的trn(语音对应的文字)文件,由于train和test文件中trn中记录的不是语音对应的文字而是一个对应语音文字所在data中的那个trn的文件地址,所以在在代码中我们我们训练集的语音使用train中的,寻找对应的trn文字时在data中找
第二部编写代码训练模型
# coding: UTF-8
# 训练数据下载地址 http://www.openslr.org/18/ 或 http://166.111.134.19:8081/data/thchs30-openslr/README.html
# 原博客中的程序将空格算入label中,导致训练程序中无法准确定位空格字符的特征(因为任何两个发音之间的空格的mfcc特征都不一样),
# 导致程序收敛极慢甚至无法收敛,手动将tran_texts中的空格都去除掉后程序开始收敛了
# 还有一点程序中输出的末尾的"龚"字是字符集中出现频率最小的那个,我们可以手动将字符集中的末尾再加一个空格words+=[""]
import numpy as np
from python_speech_features import mfcc
import scipy.io.wavfile as wav
import os
import time
import tensorflow as tf
from tensorflow.python.ops import ctc_ops
from collections import Counter
# 获取文件夹下所有的WAV文件
def get_wav_files(wav_path):
wav_files = []
for (dirpath, dirnames, filenames) in os.walk(wav_path):
for filename in filenames:
if filename.endswith('.wav') or filename.endswith('.WAV'):
# print(filename)
filename_path = os.path.join(dirpath, filename)
# print(filename_path)
wav_files.append(filename_path)
return wav_files
# 获取wav文件对应的翻译文字
def get_tran_texts(wavfiles, tran_path):
tran_texts = []
wav_files = []
for wav_file in wavfiles:
(wav_path, wav_filename) = os.path.split(wav_file)
tran_file = os.path.join(tran_path, wav_filename + '.trn')
# print(tran_file)
if os.path.exists(tran_file) is False:
return None
fd = open(tran_file, 'r', encoding='UTF-8')
text = fd.readline()
wav_files.append(wav_file)
# 不知为何原数据中的文字中有很多空格,去除空格干扰因子(通俗的解释就是空格的MFCC特性种类太多,因为任何两个文字语音间的空格特征都不一样,模型无法定位空格到底长啥样,导致模型收敛极慢甚至不收敛)
tran_texts.append(text.split('\n')[0].replace(' ',''))
fd.close()
return wav_files,tran_texts
# 获取wav和对应的翻译文字
def get_wav_files_and_tran_texts(wav_path, tran_path):
wavfiles = get_wav_files(wav_path)
wav_files,tran_texts = get_tran_texts(wavfiles, tran_path)
return wav_files, tran_texts
# 旧的训练集使用该方法获取音频文件名和译文
def get_wavs_lables(wav_path, label_file):
wav_files = []
for (dirpath, dirnames, filenames) in os.walk(wav_path):
for filename in filenames:
if filename.endswith('.wav') or filename.endswith('.WAV'):
filename_path = os.sep.join([dirpath, filename])
if os.stat(filename_path).st_size < 240000: # 剔除掉一些小文件
continue
wav_files.append(filename_path)
labels_dict = {}
with open(label_file, 'rb') as f:
for label in f:
label = label.strip(b'\n')
label_id = label.split(b' ', 1)[0]
label_text = label.split(b' ', 1)[1]
labels_dict[label_id.decode('ascii')] = label_text.decode('utf-8')
labels = []
new_wav_files = []
for wav_file in wav_files:
wav_id = os.path.basename(wav_file).split('.')[0]
if wav_id in labels_dict:
labels.append(labels_dict[wav_id])
new_wav_files.append(wav_file)
return new_wav_files, labels
# 将稀疏矩阵的字向量转成文字
# tuple是sparse_tuple_from函数的返回值
def sparse_tuple_to_texts_ch(tuple, words):
# 索引
indices = tuple[0]
values = tuple[1]
results = [''] * tuple[2][0]
for i in range(l