tf.nn.static_bidirectional_rnn ; tf.contrib.rnn.static_bidirectional_rnn 讲解

本文深入解析了tf.nn.static_bidirectional_rnn函数,详细介绍了其参数设置和返回值含义,包括前向和反向RNNCell实例的使用,以及如何处理序列长度和初始状态。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.nn.static_bidirectional_rnn 和 tf.contrib.rnn.static_bidirectional_rnn 是一样的

tf.nn.static_bidirectional_rnn

别名:
tf.contrib.rnn.static_bidirectional_rnn
tf.nn.static_bidirectional_rnn

tf.nn.static_bidirectional_rnn(
    cell_fw,
    cell_bw,
    inputs,
    initial_state_fw=None,
    initial_state_bw=None,
    dtype=None,
    sequence_length=None,
    scope=None
)
'''
Args:
	cell_fw: An instance of RNNCell, to be used for forward direction.
	cell_bw: An instance of RNNCell, to be used for backward direction.
	inputs: A length T list of inputs, each a tensor of shape [batch_size, input_size], or a nested tuple of such elements.
	initial_state_fw: (optional) An initial state for the forward RNN. This must be a tensor of appropriate type and shape [batch_size, cell_fw.state_size]. If cell_fw.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell_fw.state_size.
	initial_state_bw: (optional) Same as for initial_state_fw, but using the corresponding properties of cell_bw.
	dtype: (optional) The data type for the initial state. Required if either of the initial states are not provided.
	sequence_length: (optional) An int32/int64 vector, size [batch_size], containing the actual lengths for each of the sequences.
	scope: VariableScope for the created subgraph; defaults to "bidirectional_rnn"
Returns:
	A tuple (outputs, output_state_fw, output_state_bw) where: outputs is a length T list of outputs (one for each input), which are depth-concatenated forward and backward outputs. output_state_fw is the final state of the forward rnn. output_state_bw is the final state of the backward rnn.
	'''
参数说明

cell_fw: An instance of RNNCell, 前向RNNCell 实例
cell_bw: An instance of RNNCell, 反向RNNCell实例
inputs: 输入,是一个长度为 T 的列表,每个元素是一个 tensor(shape 为 [batch_size, input_size])
initial_state_fw: (可选的)前向RNN的初始状态。这必须是类型和形状都满足的tensor([batch_size,cell_fw.state_size])。如果cell_fw.state_size是一个元组,那么它应该也是一个张量元组,元组的每个元素的shape为 [batch_size,s] for s in cell_fw.state_size
**initial_state_bw:**和initial_state_fw要求一样
sequence_length: (可选)int32/int64 类型的向量,大小为 [batch_size],包含每个序列的实际长度。

返回
一个元组:(outputs,output_state_fw,output_state_bw)
**outputs:** 是长度为 T 的 list(每个时间步对应一个元素),每个元素的大小为 [batch_size, 2*hidden_units],这里不理解的请看下面这个图
**output_state_fw:**是正向rnn的最后state(c,h)
**output_state_bw:** 是反向 rnn 的最后 state(c,h)

双向RNN

import os
import difflib
import numpy as np
import tensorflow as tf
import scipy.io.wavfile as wav
from tqdm import tqdm
from scipy.fftpack import fft
from python_speech_features import mfcc
from random import shuffle
from keras import backend as K


def data_hparams():
    """Return the default hyper-parameters used by the data pipeline."""
    params = tf.contrib.training.HParams(
        # vocab
        data_type='train',
        data_path='data/',
        thchs30=True,
        aishell=True,
        prime=True,
        stcmd=True,
        batch_size=1,
        data_length=10,
        shuffle=True)
    return params


class get_data():
    """Load (wav, pinyin, hanzi) triples from the index files and serve
    batches for acoustic-model (CTC) and language-model training.
    """

    def __init__(self, args):
        # `args` is the HParams object produced by data_hparams().
        self.data_type = args.data_type
        self.data_path = args.data_path
        self.thchs30 = args.thchs30
        self.aishell = args.aishell
        self.prime = args.prime
        self.stcmd = args.stcmd
        self.data_length = args.data_length
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle
        self.source_init()

    def source_init(self):
        """Read the dataset index files selected by the flags, fill
        wav_lst / pny_lst / han_lst, then build the three vocabularies.
        Each index line is expected to be: wav_path<TAB>pinyin<TAB>hanzi.
        """
        print('get source list...')
        read_files = []
        if self.data_type == 'train':
            if self.thchs30 == True:
                read_files.append('thchs_train.txt')
            if self.aishell == True:
                read_files.append('aishell_train.txt')
            if self.prime == True:
                read_files.append('prime.txt')
            if self.stcmd == True:
                read_files.append('stcmd.txt')
        elif self.data_type == 'dev':
            if self.thchs30 == True:
                read_files.append('thchs_dev.txt')
            if self.aishell == True:
                read_files.append('aishell_dev.txt')
        elif self.data_type == 'test':
            if self.thchs30 == True:
                read_files.append('thchs_test.txt')
            if self.aishell == True:
                read_files.append('aishell_test.txt')
        self.wav_lst = []
        self.pny_lst = []
        self.han_lst = []
        for file in read_files:
            print('load ', file, ' data...')
            sub_file = 'data/' + file
            with open(sub_file, 'r', encoding='utf8') as f:
                data = f.readlines()
            for line in tqdm(data):
                wav_file, pny, han = line.split('\t')
                self.wav_lst.append(wav_file)
                self.pny_lst.append(pny.split(' '))
                self.han_lst.append(han.strip('\n'))
        # Optionally truncate the dataset for quick experiments.
        if self.data_length:
            self.wav_lst = self.wav_lst[:self.data_length]
            self.pny_lst = self.pny_lst[:self.data_length]
            self.han_lst = self.han_lst[:self.data_length]
        print('make am vocab...')
        self.am_vocab = self.mk_am_vocab(self.pny_lst)
        print('make lm pinyin vocab...')
        self.pny_vocab = self.mk_lm_pny_vocab(self.pny_lst)
        print('make lm hanzi vocab...')
        self.han_vocab = self.mk_lm_han_vocab(self.han_lst)

    def get_am_batch(self):
        """Infinite generator of padded acoustic-model batches in the
        Keras-CTC dict format: ({the_inputs, the_labels, input_length,
        label_length}, {ctc: zeros}).
        """
        shuffle_list = [i for i in range(len(self.wav_lst))]
        while 1:
            if self.shuffle == True:
                shuffle(shuffle_list)
            for i in range(len(self.wav_lst) // self.batch_size):
                wav_data_lst = []
                label_data_lst = []
                begin = i * self.batch_size
                end = begin + self.batch_size
                sub_list = shuffle_list[begin:end]
                for index in sub_list:
                    fbank = compute_fbank(self.data_path + self.wav_lst[index])
                    # Pad the time axis up to the next multiple of 8
                    # (the AM downsamples by 8 via pooling).
                    pad_fbank = np.zeros((fbank.shape[0] // 8 * 8 + 8, fbank.shape[1]))
                    pad_fbank[:fbank.shape[0], :] = fbank
                    label = self.pny2id(self.pny_lst[index], self.am_vocab)
                    label_ctc_len = self.ctc_len(label)
                    # Skip samples whose downsampled length cannot fit
                    # the CTC label (CTC requires input_len >= label_len).
                    if pad_fbank.shape[0] // 8 >= label_ctc_len:
                        wav_data_lst.append(pad_fbank)
                        label_data_lst.append(label)
                pad_wav_data, input_length = self.wav_padding(wav_data_lst)
                pad_label_data, label_length = self.label_padding(label_data_lst)
                inputs = {'the_inputs': pad_wav_data,
                          'the_labels': pad_label_data,
                          'input_length': input_length,
                          'label_length': label_length,
                          }
                outputs = {'ctc': np.zeros(pad_wav_data.shape[0], )}
                yield inputs, outputs

    def get_lm_batch(self):
        """Generator of (pinyin-ids, hanzi-ids) batches, zero-padded to
        the longest pinyin sequence in each batch.
        """
        batch_num = len(self.pny_lst) // self.batch_size
        for k in range(batch_num):
            begin = k * self.batch_size
            end = begin + self.batch_size
            input_batch = self.pny_lst[begin:end]
            label_batch = self.han_lst[begin:end]
            max_len = max([len(line) for line in input_batch])
            input_batch = np.array(
                [self.pny2id(line, self.pny_vocab) + [0] * (max_len - len(line))
                 for line in input_batch])
            label_batch = np.array(
                [self.han2id(line, self.han_vocab) + [0] * (max_len - len(line))
                 for line in label_batch])
            yield input_batch, label_batch

    def pny2id(self, line, vocab):
        """Map a list of pinyin tokens to their indices in `vocab`."""
        return [vocab.index(pny) for pny in line]

    def han2id(self, line, vocab):
        """Map a string of hanzi characters to their indices in `vocab`."""
        return [vocab.index(han) for han in line]

    def wav_padding(self, wav_data_lst):
        """Zero-pad fbank features to a common length.

        Returns (padded array of shape [B, max_len, 200, 1],
        per-sample lengths already divided by the downsample factor 8).
        """
        wav_lens = [len(data) for data in wav_data_lst]
        wav_max_len = max(wav_lens)
        wav_lens = np.array([leng // 8 for leng in wav_lens])
        new_wav_data_lst = np.zeros((len(wav_data_lst), wav_max_len, 200, 1))
        for i in range(len(wav_data_lst)):
            new_wav_data_lst[i, :wav_data_lst[i].shape[0], :, 0] = wav_data_lst[i]
        return new_wav_data_lst, wav_lens

    def label_padding(self, label_data_lst):
        """Zero-pad label id sequences; returns (padded array, lengths)."""
        label_lens = np.array([len(label) for label in label_data_lst])
        max_label_len = max(label_lens)
        new_label_data_lst = np.zeros((len(label_data_lst), max_label_len))
        for i in range(len(label_data_lst)):
            new_label_data_lst[i][:len(label_data_lst[i])] = label_data_lst[i]
        return new_label_data_lst, label_lens

    def mk_am_vocab(self, data):
        """Acoustic-model vocab: unique pinyin tokens in order of first
        appearance, with the CTC blank '_' appended last.
        """
        vocab = []
        for line in tqdm(data):
            for pny in line:
                if pny not in vocab:
                    vocab.append(pny)
        vocab.append('_')
        return vocab

    def mk_lm_pny_vocab(self, data):
        """LM input vocab: '<PAD>' at index 0, then unique pinyin tokens."""
        vocab = ['<PAD>']
        for line in tqdm(data):
            for pny in line:
                if pny not in vocab:
                    vocab.append(pny)
        return vocab

    def mk_lm_han_vocab(self, data):
        """LM output vocab: '<PAD>' at index 0, then unique hanzi chars."""
        vocab = ['<PAD>']
        for line in tqdm(data):
            line = ''.join(line.split(' '))
            for han in line:
                if han not in vocab:
                    vocab.append(han)
        return vocab

    def ctc_len(self, label):
        """Minimum input length CTC needs for `label`: its length plus one
        extra frame for every pair of adjacent repeated labels (a blank
        must separate repeats).
        """
        add_len = 0
        label_len = len(label)
        for i in range(label_len - 1):
            if label[i] == label[i + 1]:
                add_len += 1
        return label_len + add_len


# Extract MFCC features from an audio file.
def compute_mfcc(file):
    """Return MFCC features (26 cepstra, every 3rd frame, transposed)."""
    fs, audio = wav.read(file)
    mfcc_feat = mfcc(audio, samplerate=fs, numcep=26)
    mfcc_feat = mfcc_feat[::3]
    mfcc_feat = np.transpose(mfcc_feat)
    return mfcc_feat


# Compute the time-frequency (spectrogram) features of a signal.
def compute_fbank(file):
    """Return log-magnitude spectrogram frames of shape [n_frames, 200].

    400-sample (25 ms at 16 kHz — assumed sample rate, TODO confirm)
    Hamming-windowed frames with a 160-sample (10 ms) hop; only the first
    half of the symmetric FFT magnitude is kept.
    """
    x = np.linspace(0, 400 - 1, 400, dtype=np.int64)
    w = 0.54 - 0.46 * np.cos(2 * np.pi * (x) / (400 - 1))  # Hamming window
    fs, wavsignal = wav.read(file)  # wav waveform
    # Apply the time window with a 10 ms shift.
    time_window = 25  # in ms
    wav_arr = np.array(wavsignal)
    # Number of windows that fit, i.e. the loop's end position.
    range0_end = int(len(wavsignal) / fs * 1000 - time_window) // 10 + 1
    # NOTE: np.float was removed in NumPy >= 1.20; plain `float` is the
    # identical float64 dtype.
    data_input = np.zeros((range0_end, 200), dtype=float)  # final frequency features
    data_line = np.zeros((1, 400), dtype=float)
    for i in range(0, range0_end):
        p_start = i * 160
        p_end = p_start + 400
        data_line = wav_arr[p_start:p_end]
        data_line = data_line * w  # apply the window
        data_line = np.abs(fft(data_line))
        # Keep 200 bins (400 / 2): the FFT magnitude is symmetric.
        data_input[i] = data_line[0:200]
    data_input = np.log(data_input + 1)
    # data_input = data_input[::]
    return data_input


# word error rate------------------------------------
def GetEditDistance(str1, str2):
    """Levenshtein edit distance between two sequences via difflib opcodes."""
    leven_cost = 0
    s = difflib.SequenceMatcher(None, str1, str2)
    for tag, i1, i2, j1, j2 in s.get_opcodes():
        if tag == 'replace':
            leven_cost += max(i2 - i1, j2 - j1)
        elif tag == 'insert':
            leven_cost += (j2 - j1)
        elif tag == 'delete':
            leven_cost += (i2 - i1)
    return leven_cost


# CTC decoder------------------------------------
def decode_ctc(num_result, num2word):
    """Greedy-decode a softmax output tensor [1, T, V] and map the ids to
    words via `num2word`; returns (id array, word list).
    """
    result = num_result[:, :, :]
    in_len = np.zeros((1), dtype=np.int32)
    in_len[0] = result.shape[1]
    r = K.ctc_decode(result, in_len, greedy=True, beam_width=10, top_paths=1)
    r1 = K.get_value(r[0][0])
    r1 = r1[0]
    text = []
    for i in r1:
        text.append(num2word[i])
    return r1, text
最新发布
05-31
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值