# With this as the core, plus various improvement tricks, you can build a powerful seq2seq model. Action!
import tensorflow as tf
class Seq2seq(object):
    """Minimal GRU encoder-decoder (seq2seq) graph built with the TensorFlow 1.x API.

    Constructing an instance adds the placeholders, the encoder, the decoder
    and the output projection to the current default graph.

    Attributes created on the instance:
        seq_inputs: int32 placeholder of shape (batch_size, None) holding
            source token ids.
        seq_inputs_length: int32 placeholder of shape (batch_size,) holding
            the length of each source sequence.
        seq_targets: int32 placeholder of shape (batch_size, None) holding
            target token ids.
        seq_targets_length: int32 placeholder of shape (batch_size,) holding
            the length of each target sequence.
        out: tensor of predicted target token ids, obtained by taking the
            argmax of the decoder logits over the vocabulary axis.
    """

    def __init__(self, config, w2i_target):
        """Build the seq2seq graph.

        Args:
            config: object providing the attributes ``batch_size``,
                ``source_vocab_size``, ``target_vocab_size``,
                ``embedding_dim`` and ``hidden_dim``.
            w2i_target: mapping from target-side token to integer id; it must
                contain the key ``"_GO"`` (the start-of-decoding token).
        """
        # Placeholders: batch size is fixed by the config, sequence length is dynamic.
        self.seq_inputs = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_inputs')
        self.seq_inputs_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_inputs_length')
        self.seq_targets = tf.placeholder(shape=(config.batch_size, None), dtype=tf.int32, name='seq_targets')
        self.seq_targets_length = tf.placeholder(shape=(config.batch_size,), dtype=tf.int32, name='seq_targets_length')

        with tf.variable_scope("encoder"):
            encoder_embedding = tf.Variable(
                tf.random_uniform([config.source_vocab_size, config.embedding_dim]),
                dtype=tf.float32, name='encoder_embedding')
            # embedding_lookup gathers the embedding rows addressed by the token ids.
            encoder_inputs_embedded = tf.nn.embedding_lookup(encoder_embedding, self.seq_inputs)
            # The first GRUCell argument (num_units) is the number of hidden units.
            encoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
            encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
                cell=encoder_cell, inputs=encoder_inputs_embedded,
                sequence_length=self.seq_inputs_length, dtype=tf.float32,
                time_major=False)

        # Decoder inputs are the targets shifted right by one step: a "_GO"
        # token is prepended and the last target column is dropped, so the
        # decoder input at step t is the expected output of step t-1.
        tokens_go = tf.ones([config.batch_size], dtype=tf.int32) * w2i_target["_GO"]
        decoder_inputs = tf.concat([tf.reshape(tokens_go, [-1, 1]), self.seq_targets[:, :-1]], 1)

        with tf.variable_scope("decoder"):
            decoder_embedding = tf.Variable(
                tf.random_uniform([config.target_vocab_size, config.embedding_dim]),
                dtype=tf.float32, name='decoder_embedding')
            # embedding_lookup gathers the embedding rows addressed by the token ids.
            decoder_inputs_embedded = tf.nn.embedding_lookup(decoder_embedding, decoder_inputs)
            decoder_cell = tf.nn.rnn_cell.GRUCell(config.hidden_dim)
            # The decoder starts from the encoder's final state.
            decoder_outputs, decoder_state = tf.nn.dynamic_rnn(
                cell=decoder_cell, inputs=decoder_inputs_embedded,
                initial_state=encoder_state,
                sequence_length=self.seq_targets_length, dtype=tf.float32,
                time_major=False)

        # Fully connected projection from decoder outputs to vocabulary logits.
        # tf.nn.dynamic_rnn returns the output tensor itself; `.rnn_output` only
        # exists on the result of tf.contrib.seq2seq.dynamic_decode, so the
        # tensor is passed to the dense layer directly.
        decoder_logits = tf.layers.dense(decoder_outputs, config.target_vocab_size)
        # Index of the highest-scoring vocabulary entry at every time step.
        self.out = tf.argmax(decoder_logits, 2)