import numpy as np
import tensorflow as tf

# Build a 5x5 lookup table: start from the identity matrix,
# then change row 0 to [0, 1, 0, 0, 0].
a = np.identity(5, dtype=np.int32)
a[0, 0] = 0
a[0, 1] = 1

tf.reset_default_graph()
ids = tf.placeholder(tf.int32, shape=None)
embedding = tf.Variable(a)
embedding_look = tf.nn.embedding_lookup(embedding, ids)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(embedding))
    print("\n")
    print(sess.run(embedding_look, feed_dict={ids: [1, 2, 3]}))
The result is:
[[0 1 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]]

[[0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]]
When the input ids is

[[0, 1],
 [1, 2]]

the result is:
[[0 1 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]]

[[[0 1 0 0 0]
  [0 1 0 0 0]]

 [[0 1 0 0 0]
  [0 0 1 0 0]]]
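To produce this second output, only the feed changes; a minimal sketch of the lookup call, reusing sess, ids, and embedding_look from the code above:

# Feeding a 2-D ids of shape [2, 2] returns a result of shape [2, 2, 5].
print(sess.run(embedding_look, feed_dict={ids: [[0, 1], [1, 2]]}))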
This function treats each element of ids as a row index into the embedding array (tensor), gathers the corresponding rows, and returns them assembled into a new tensor.

The shape of the result is ids.shape + [params.shape[1]], i.e. the embedding dimension is appended to the shape of ids. (params is the embedding lookup table used in the code above.)

The embedding matrix (params) is a tensor of shape [vocabulary_size, embedding_size].
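For a 2-D params, the lookup behaves like plain NumPy fancy indexing on the first axis; a minimal sketch (pure NumPy, independent of the TensorFlow code above) to illustrate the shape rule:

import numpy as np

params = np.identity(5, dtype=np.int32)   # shape [5, 5]: vocabulary_size=5, embedding_size=5
params[0, 0], params[0, 1] = 0, 1         # same table as in the example above

ids = np.array([[0, 1], [1, 2]])          # shape [2, 2]

# Row gathering: result.shape == ids.shape + params.shape[1:]
result = params[ids]
print(result.shape)                       # (2, 2, 5)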