关于embedding的shape
之前读pointer-generator代码的时候一直对tensor的shape概念感到比较混沌,今天再读有了新的收获。
with tf.variable_scope('embedding'):
embedding = tf.get_variable('embedding', [vsize, hps.emb_dim], dtype = tf.float32, initializer = self.trunc_norm.init)
if hps.mode == "train": self._add_emb_vis(embedding) #add to tensorboard
emb_enc_input = tf.nn.embedding_lookup(embedding, self._enc_batch) #tensor with shape (batch_size, max_enc_steps, emb_size)
emb_dec_input = [tf.nn.embedding_lookup(embedding, x) for x in tf.unstack(self._dec_batch, axis=1)]#list length max_dec_steps containing shape (batch_size, emb_size)
其中涉及到的几个函数包括:
- tf.nn.embedding_lookup
tf.nn.embedding_lookup(
params,
ids,
partition_strategy='mod',
name=None,
validate_indices=True,
max_norm=None
)
作用是根据ids
在params
中查找对应的元素。相当于在np.array
中直接采用下标数组获取数据,需要注意的细节是返回的tensor的dtype
和传入的被查询的tensor的dtype
保持一致,和ids
的dtype
无关。
前面第一段代码中_enc_batch
的shape是(batch_size, max_enc_steps),经过此函数处理,得到的tensor的shape为(batch_size, max_enc_steps, emb_dim)。
- tf.unstack
tf.unstack(
value,
num=None,
axis=0,
name='unstack'
)
按照不同的维度进行矩阵分解,视axis=0
或axis=1
的不同而定。
前面第一段代码中将_dec_batch
按照axis=1
分解,则得到的list的长度为max_dec_steps,其中包含元素的shape为(batch_size, emb_dim)。