【语音识别】基于keras的简易语音识别

本文介绍了使用非Kaldi的ASR系统,基于清华大学的THCHS30中文数据集训练和预测的流程。首先,数据预处理为MFCC特征,然后利用TDNN网络结构和CTC损失函数进行模型训练。模型预测时,通过加载预训练模型,将MFCC特征输入模型,解码得到文字转录。最后,文章讨论了CTC在端到端ASR系统中的应用和发展趋势。


简介

最近忽然看到不是基于kaldi的ASR代码,尝试了一下发现效果还不错,搬上来记录一下。

源码地址:

https://pan.baidu.com/s/1tFlZkMJmrMTD05cd_zxmAg
提取码:ndrr
数据集需要自行下载。


1. 数据集

数据集使用的是清华大学的thchs30中文数据,data文件夹中包含(.wav文件和.trn文件;trn文件里存放的是.wav文件的文字描述:第一行为词,第二行为拼音,第三行为音素).

2. 模型预测

先直接解释有了训好的模型后如何使用,代码如下:

# -*- coding: utf-8 -*-
from keras.models import load_model
from keras import backend as K
import numpy as np
import librosa
from python_speech_features import mfcc
import pickle
import glob

wavs = glob.glob('A2_8.wav')
print(wavs)
with open('dictionary.pkl', 'rb') as fr:
    [char2id, id2char, mfcc_mean, mfcc_std] = pickle.load(fr)

mfcc_dim = 13
model = load_model('asr.h5')

index = np.random.randint(len(wavs))
print(wavs[index])

## 读取数据,并去除掉没说话的起始结束时间
audio, sr = librosa.load(wavs[index])
energy = librosa.feature.rmse(audio)
frames = np.nonzero(energy >= np.max(energy) / 5)
indices = librosa.core.frames_to_samples(frames)[1]
audio = audio[indices[0]:indices[-1]] if indices.size else audio[0:0]
X_data = mfcc(audio, sr, numcep=mfcc_dim, nfft=551)
X_data = (X_data - mfcc_mean) / (mfcc_std + 1e-14)
print(X_data.shape)

pred = model.predict(np.expand_dims(X_data, axis=0))
pred_ids = K.eval(K.ctc_decode(pred, [X_data.shape[0]], greedy=False, beam_width=10, top_paths=1)[0][0])
pred_ids = pred_ids.flatten().tolist()
print(''.join([id2char[i] for i in pred_ids]))

3. 模型训练

模型采用了 TDNN 网络结构,并直接通过字符级别来预测,直接根据常见度将字符对应成数字标签。整个流程而言,

  • 先将一个个语音样本变成MFCC特征,即一个样本的维度为time*num_MFCC,time维度将被补齐到batch里最长的time。
  • 将批量样本送入网络,采用1d卷积,仅在时间轴上卷积,一个样本的输出维度为time*(num_words+1),加的1代表预测了空状态。
  • 通过CTC Loss计算损失
# -*- coding: utf-8 -*-
#导入相关的库
from keras.models import Model
from keras.layers import Input, Activation, Conv1D, Lambda, Add, Multiply, BatchNormalization
from keras.optimizers import Adam, SGD
from keras import backend as K
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import random
import pickle
import glob
from tqdm import tqdm
import os

from python_speech_features import mfcc
import scipy.io.wavfile as wav
import librosa
from IPython.display import Audio



#读取数据集文件
text_paths = glob.glob('data/*.trn')
total = len(text_paths)
print(total)

with open(text_paths[0], 'r', encoding='utf8') as fr:
    lines = fr.readlines()
    print(lines)



#数据集文件trn内容读取保存到数组中
texts = []
paths = []
for path in text_paths:
    with open(path, 'r', encoding='utf8') as fr:
        lines = fr.readlines()
        li
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值