RNN API

This article walks through how RNN and LSTM models are built: how the input layer is handled, how weights and biases are initialized, how different RNN structures are implemented, and how a dynamic RNN handles variable-length input sequences.


1. Input
1.1 Fixed-size input (images)

x = tf.placeholder(tf.float32, [None, height, width])
y = tf.placeholder(tf.float32)
# RNN model
y_conv = rnn_graph(x, rnn_size, out_size, width, height)
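The rnn_graph function is not defined in this excerpt. A minimal sketch of what such a graph function could look like, assuming the weight_variable/bias_variable helpers from section 2.1 below and treating each image row as one time step (hypothetical, not the author's exact implementation):

import tensorflow as tf

def rnn_graph(x, rnn_size, out_size, width, height):
    # x: [batch, height, width] -- each of the `height` rows is one time
    # step carrying `width` features.
    w = weight_variable([rnn_size, out_size])
    b = bias_variable([out_size])
    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size)
    # dynamic_rnn consumes the [batch, time, features] tensor directly.
    outputs, _ = tf.nn.dynamic_rnn(lstm_cell, x, dtype=tf.float32)
    # Only the last time step's output is used for the prediction.
    return tf.add(tf.matmul(outputs[:, -1, :], w), b)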

1.2 Variable-length input (lines of poetry)

inputs = tf.placeholder(tf.int32, shape=(self.batch_size, None), name='inputs')
# The target is the character that follows each input character, so its length varies with the input as well
targets = tf.placeholder(tf.int32, shape=(self.batch_size, None), name='targets')

lstm_inputs = self.embedding_variable(inputs, self.rnn_size, self.word_len)

@staticmethod
def embedding_variable(inputs, rnn_size, word_len):
    with tf.variable_scope('embedding'):
        # run the embedding on the CPU here
        with tf.device("/cpu:0"):
            # 默认使用'glorot_uniform_initializer'初始化,来自源码说明:
            # If initializer is `None` (the default), the default initializer passed in
            # the variable scope will be used. If that one is `None` too, a
            # `glorot_uniform_initializer` will be used.
            # In effect this creates one vector of length rnn_size for each character in the vocabulary
            embedding = tf.get_variable('embedding', [word_len, rnn_size])
            # For each character index in `inputs`, look up the corresponding row of `embedding`, i.e. map each character to a dense vector: [字]==>[1]==>[0,1,0]
            lstm_inputs = tf.nn.embedding_lookup(embedding, inputs)
    return lstm_inputs
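To make the lookup concrete, here is a toy example, assuming a hypothetical 4-character vocabulary and rnn_size of 3 (all values are illustrative):

import tensorflow as tf

# Hypothetical toy vocabulary of 4 characters, embedding size 3.
embedding = tf.constant([[0.1, 0.2, 0.3],   # index 0
                         [0.4, 0.5, 0.6],   # index 1
                         [0.7, 0.8, 0.9],   # index 2
                         [1.0, 1.1, 1.2]])  # index 3
inputs = tf.constant([[1, 3], [2, 0]])      # batch of 2 sequences, length 2
looked_up = tf.nn.embedding_lookup(embedding, inputs)
# looked_up has shape [2, 2, 3]: each index is replaced by its row vector.
with tf.Session() as sess:
    print(sess.run(looked_up))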

2. Weight and bias

2.1 Random-normal initialization

w = weight_variable([rnn_size, out_size])
b = bias_variable([out_size])

def weight_variable(shape, w_alpha=0.01):
    '''
    Randomly generate the weights, adding a little noise
    :param shape:
    :param w_alpha: scale of the random noise
    :return:
    '''
    initial = w_alpha * tf.random_normal(shape)
    return tf.Variable(initial)

def bias_variable(shape, b_alpha=0.1):
    '''
    Randomly generate the bias term, adding a little noise
    :param shape:
    :param b_alpha: scale of the random noise
    :return:
    '''
    initial = b_alpha * tf.random_normal(shape)
    return tf.Variable(initial)

2.2 Shared soft-max variables

w, b = self.soft_max_variable(rnn_size, word_len)
logits = tf.matmul(x, w) + b

@staticmethod
def soft_max_variable(rnn_size, word_len):
    # Shared variables: get_variable inside a named scope lets later calls reuse w and b
    with tf.variable_scope('soft_max'):
        w = tf.get_variable("w", [rnn_size, word_len])
        b = tf.get_variable("b", [word_len])
    return w, b
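What the variable scope buys is reuse: a second call with reuse enabled hands back the same w and b instead of creating new variables. A small sketch of that behaviour (the reuse flag is illustrative; the method above always enters the scope fresh):

import tensorflow as tf

def soft_max_variable(rnn_size, word_len, reuse=False):
    # With reuse=True, get_variable returns the existing 'w' and 'b'
    # from the 'soft_max' scope instead of creating fresh ones.
    with tf.variable_scope('soft_max', reuse=reuse):
        w = tf.get_variable("w", [rnn_size, word_len])
        b = tf.get_variable("b", [word_len])
    return w, b

w1, b1 = soft_max_variable(128, 5000)              # creates the variables
w2, b2 = soft_max_variable(128, 5000, reuse=True)  # returns the same objects
assert w1 is w2 and b1 is b2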

3. Graph

3.1 Standard RNN (stacked LSTM + dynamic_rnn)

def lstm_dropout_cell():
    lstm = tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size)
    return tf.nn.rnn_cell.DropoutWrapper(lstm, output_keep_prob=keep_prob)

# Stack multiple cells: each layer's output feeds the next. Each layer needs
# its own cell object; [drop] * 2 would make both layers share one set of weights.
cell = tf.nn.rnn_cell.MultiRNNCell([lstm_dropout_cell() for _ in range(2)])

initial_state = cell.zero_state(batch_size, tf.float32)
lstm_outputs, final_state = tf.nn.dynamic_rnn(cell, lstm_inputs, initial_state=initial_state)
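For the character-level model every time step produces a prediction, so the outputs are usually flattened across time before the soft-max projection from section 2.2. A sketch, assuming w and b come from soft_max_variable:

# Flatten [batch, time, rnn_size] into [batch * time, rnn_size] so a single
# matmul projects every time step onto the vocabulary.
seq_output = tf.reshape(lstm_outputs, [-1, rnn_size])
logits = tf.matmul(seq_output, w) + b
prediction = tf.nn.softmax(logits, name='predictions')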

3.2 Static RNN (only the last output is used)

# The RNN emits one output per time step (as many outputs as input steps);
# for classification we only need the last one. Note that static_rnn expects
# a Python list of [batch, width] tensors, one per step (see the sketch below).
outputs, status = tf.nn.static_rnn(lstm_cell, x, dtype=tf.float32)
y_conv = tf.add(tf.matmul(outputs[-1], w), b)
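Unlike dynamic_rnn, static_rnn will not accept a single [batch, time, features] tensor, so the image placeholder from section 1.1 first has to be split into one tensor per time step. A sketch of the usual conversion:

# x: [batch, height, width] -- unstack along the time (height) axis to get
# a list of `height` tensors, each of shape [batch, width], which is the
# form static_rnn accepts.
x = tf.unstack(x, axis=1)

After this, outputs[-1] is the [batch, rnn_size] output of the final time step.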

3.3 Encoder with two decoders (previous and next sentence)

def build_encoder(self, encode_emb, length, train=True):
    batch_size = self.batch_size if train else 1
    with tf.variable_scope('encoder'):
        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.num_units)
        initial_state = cell.zero_state(batch_size, tf.float32)
        _, final_state = tf.nn.dynamic_rnn(cell, encode_emb, initial_state=initial_state, sequence_length=length)
    return initial_state, final_state

encode_emb, decode_pre_emb, decode_post_emb = self.build_word_embedding(encode, decode_pre_x, decode_post_x)

initial_state, final_state = self.build_encoder(encode_emb, encode_length)

# Decoder for the previous sentence
pre_logits, pre_prediction, pre_state = self.build_decoder(decode_pre_emb, decode_pre_length, final_state, scope='decoder_pre')
pre_loss = self.build_loss(pre_logits, decode_pre_y, scope='decoder_pre_loss')
pre_optimizer = self.build_optimizer(pre_loss, scope='decoder_pre_op')

# Decoder for the next sentence
post_logits, post_prediction, post_state = self.build_decoder(decode_post_emb, decode_post_length, final_state, scope='decoder_post', reuse=True)
post_loss = self.build_loss(post_logits, decode_post_y, scope='decoder_post_loss')
post_optimizer = self.build_optimizer(post_loss, scope='decoder_post_op')
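build_decoder, build_loss and build_optimizer are not shown in this excerpt. A minimal sketch of what the decoder could look like, assuming it mirrors the encoder's cell size and the shared projection pattern from section 2.2 (names such as self.word_len and the overall structure are guesses, not the author's exact code):

def build_decoder(self, decode_emb, length, state, scope='decoder', reuse=False):
    # The decoder starts from the encoder's final state, so the encoded
    # sentence conditions everything it generates.
    with tf.variable_scope(scope):
        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=self.num_units)
        outputs, final_state = tf.nn.dynamic_rnn(
            cell, decode_emb, initial_state=state, sequence_length=length)
    x = tf.reshape(outputs, [-1, self.num_units])
    # Both decoders share one output projection; the second call passes
    # reuse=True to pick up the existing variables.
    with tf.variable_scope('softmax', reuse=reuse):
        w = tf.get_variable('w', [self.num_units, self.word_len])
        b = tf.get_variable('b', [self.word_len])
    logits = tf.matmul(x, w) + b
    prediction = tf.nn.softmax(logits, name='predictions')
    return logits, prediction, final_state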

 
