简介
最近忽然看到不是基于kaldi的ASR代码,尝试了一下发现效果还不错,搬上来记录一下。
源码地址:
https://pan.baidu.com/s/1tFlZkMJmrMTD05cd_zxmAg
提取码:ndrr
数据集需要自行下载。
1. 数据集
数据集使用的是清华大学的thchs30中文数据,data文件夹中包含(.wav文件和.trn文件;trn文件里存放的是.wav文件的文字描述:第一行为词,第二行为拼音,第三行为音素).
2. 模型预测
先直接解释有了训好的模型后如何使用,代码如下:
# -*- coding: utf-8 -*-
from keras.models import load_model
from keras import backend as K
import numpy as np
import librosa
from python_speech_features import mfcc
import pickle
import glob
wavs = glob.glob('A2_8.wav')
print(wavs)
with open('dictionary.pkl', 'rb') as fr:
[char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)
mfcc_dim = 13
model = load_model('asr.h5')
index = np.random.randint(len(wavs))
print(wavs[index])
## 读取数据,并去除掉没说话的起始结束时间
audio, sr = librosa.load(wavs[index])
energy = librosa.feature.rmse(audio)
frames = np.nonzero(energy >= np.max(energy) / 5)
indices = librosa.core.frames_to_samples(frames)[1]
audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14)
print(X_data.shape)
pred = model.predict(np.expand_dims(X_data, axis=0))
pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
pred_ids = pred_ids.flatten().tolist()
print(''.join([id2char[i] for i in pred_ids]))
3. 模型训练
模型采用了 TDNN 网络结构,并直接通过字符级别来预测,直接根据常见度将字符对应成数字标签。整个流程而言,
- 先将一个个语音样本变成MFCC特征,即一个样本的维度为time*num_MFCC,time维度将被补齐到batch里最长的time。
- 将批量样本送入网络,采用1d卷积,仅在时间轴上卷积,一个样本的输出维度为time*(num_words+1),加的1代表预测了空状态。
- 通过CTC Loss计算损失
# -*- coding: utf-8 -*-
#导入相关的库
from keras.models import Model
from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
from keras.optimizers import Adam, SGD
from keras import backend as K
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import random
import pickle
import glob
from tqdm import tqdm
import os
from python_speech_features import mfcc
import scipy.io.wavfile as wav
import librosa
from IPython.display import Audio
#读取数据集文件
text_paths = glob.glob('data/*.trn')
total = len(text_paths)
print(total)
with open(text_paths[0], 'r', encoding='utf8') as fr:
lines = fr.readlines()
print(lines)
#数据集文件trn内容读取保存到数组中
texts = []
paths = []
for path in text_paths:
with open(path, 'r', encoding='utf8') as fr:
lines = fr.readlines()
li

本文介绍了使用非Kaldi的ASR系统,基于清华大学的THCHS30中文数据集训练和预测的流程。首先,数据预处理为MFCC特征,然后利用TDNN网络结构和CTC损失函数进行模型训练。模型预测时,通过加载预训练模型,将MFCC特征输入模型,解码得到文字转录。最后,文章讨论了CTC在端到端ASR系统中的应用和发展趋势。
最低0.47元/天 解锁文章
9161

被折叠的 条评论
为什么被折叠?



