API
embedding_lookup_sparse
tf.nn.embedding_lookup_sparse(
params,
sp_ids,
sp_weights,
partition_strategy='mod',
name=None,
combiner=None,
max_norm=None
)
params: the lookup table(s) used for the embedding; either a single tensor, or a list of tensors treated as partitions of one table.
sp_ids: a SparseTensor holding the ids to look up in the table.
combiner: how the embeddings within one row are combined ("mean", "sum", etc.).
Enough preamble; here is the code:
import numpy as np
import tensorflow as tf

a = np.arange(8).reshape(2, 4)
b = np.arange(8, 16).reshape(2, 4)
c = np.arange(12, 20).reshape(2, 4)
print("a :")
print(a)
print("b :")
print(b)
print("c :")
print(c)

a = tf.Variable(a, dtype=tf.float32)
b = tf.Variable(b, dtype=tf.float32)
c = tf.Variable(c, dtype=tf.float32)
idx = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 0], [1, 1]], values=[1, 2, 2, 0], dense_shape=(2, 3))
result = tf.nn.embedding_lookup_sparse([a, b, c], idx, None, partition_strategy='div', combiner="sum")

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    res = sess.run(result)
    print("\n# result here")
    print(res)
Analysis
Depending on the params and partition_strategy arguments of embedding_lookup_sparse, there are 4 scenarios:
mod1: params=[a, c, b], partition_strategy='mod'
div1: params=[a, c, b], partition_strategy='div'
mod2: params=[a, b, c], partition_strategy='mod'
div2: params=[a, b, c], partition_strategy='div'
In these 4 scenarios, the global id assigned to each row of params is as follows:
#                            mod1  div1  mod2  div2
# a : [[ 0  1  2  3]    id     0     0     0     0
#      [ 4  5  6  7]]   id     3     1     3     1
# b : [[ 8  9 10 11]    id     2     4     1     2
#      [12 13 14 15]]   id     5     5     4     3
# c : [[12 13 14 15]    id     1     2     2     4
#      [16 17 18 19]]   id     4     3     5     5
How the ids are assigned (worked out in detail below):
Note: in the mod scenarios the tensor is chosen by id % 3 (len(params) = 3), so the ids within one tensor differ by 3; in the div scenarios the ids are assigned contiguously, so each tensor holds 2 consecutive ids.
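The two mappings can be sketched in plain Python (a hand-rolled illustration, not TensorFlow internals; the function names are made up, and the 'div' rule below assumes the 6 ids split evenly over the 3 tensors, as they do here):

```python
num_partitions = 3    # len(params)
rows_per_partition = 2

def mod_lookup(gid):
    # 'mod': tensor chosen by gid % num_partitions,
    # row within that tensor by integer division
    return gid % num_partitions, gid // num_partitions

def div_lookup(gid):
    # 'div': ids assigned contiguously, 2 per tensor here
    return gid // rows_per_partition, gid % rows_per_partition

# Reproduce the two index matrices from the analysis:
# mod -> partitions hold ids [[0, 3], [1, 4], [2, 5]]
# div -> partitions hold ids [[0, 1], [2, 3], [4, 5]]
print([mod_lookup(i) for i in range(6)])
print([div_lookup(i) for i in range(6)])
```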
sp_ids viewed as a dense matrix (empty slots marked -):
# [[1, -, 2],
#  [2, 0, -]]
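Recovering those per-row id lists from the SparseTensor above can be sketched like this (plain-Python illustration, not a TensorFlow call):

```python
# indices and values copied from the tf.SparseTensor in the script
indices = [[0, 0], [0, 2], [1, 0], [1, 1]]
values = [1, 2, 2, 0]

# group the ids by the row component of each index
rows = {}
for (r, _), v in zip(indices, values):
    rows.setdefault(r, []).append(v)
print(rows)  # row 0 holds ids [1, 2], row 1 holds ids [2, 0]
```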
# Results
mod1:
tf.nn.embedding_lookup_sparse([a,c,b], idx, None, partition_strategy='mod', combiner="sum")
With partition_strategy 'mod' (id % 3), the ids held by each of the 3 tensors are:
[[0, 3],
[1, 4],
[2, 5]]
# result here
[[20. 22. 24. 26.] = mod1[1] + mod1[2] = [12, 13, 14, 15] + [8, 9, 10, 11]
[ 8. 10. 12. 14.]] = mod1[2] + mod1[0] = [8, 9, 10, 11] + [0, 1, 2, 3]
div1:
tf.nn.embedding_lookup_sparse([a,c,b], idx, None, partition_strategy='div', combiner="sum")
With partition_strategy 'div' (contiguous ids), the ids held by each of the 3 tensors are:
#div1
[[0, 1],
[2, 3],
[4, 5]]
# result here
[[16. 18. 20. 22.] = div1[1] + div1[2] = [4, 5, 6, 7] + [12, 13, 14, 15]
[12. 14. 16. 18.]] = div1[2] + div1[0] = [12, 13, 14, 15] + [0, 1, 2, 3]
mod2:
tf.nn.embedding_lookup_sparse([a,b,c], idx, None, partition_strategy='mod', combiner="sum")
With partition_strategy 'mod' (id % 3), the ids held by each of the 3 tensors are:
[[0, 3],
[1, 4],
[2, 5]]
# result here
[[20. 22. 24. 26.] = mod2[1] + mod2[2] = [8, 9, 10, 11] + [12, 13, 14, 15]
[12. 14. 16. 18.]] = mod2[2] + mod2[0] = [12, 13, 14, 15] + [0, 1, 2, 3]
div2:
tf.nn.embedding_lookup_sparse([a,b,c], idx, None, partition_strategy='div', combiner="sum")
With partition_strategy 'div' (contiguous ids), the ids held by each of the 3 tensors are:
[[0, 1],
[2, 3],
[4, 5]]
# result here
[[12. 14. 16. 18.] = div2[1] + div2[2] = [4, 5, 6, 7] + [8, 9, 10, 11]
[ 8. 10. 12. 14.]] = div2[2] + div2[0] = [8, 9, 10, 11] + [0, 1, 2, 3]
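All four results can be cross-checked with a small NumPy sketch that mimics embedding_lookup_sparse with combiner="sum" (a hand-rolled illustration, not TensorFlow code; lookup_sparse_sum is a made-up name, and its 'div' branch assumes the ids split evenly over the partitions, as they do here):

```python
import numpy as np

a = np.arange(8).reshape(2, 4)
b = np.arange(8, 16).reshape(2, 4)
c = np.arange(12, 20).reshape(2, 4)

def lookup_sparse_sum(params, row_ids, strategy):
    # Rebuild one big table in global-id order from the partitions.
    n = len(params)
    total = sum(p.shape[0] for p in params)
    table = np.zeros((total, params[0].shape[1]))
    for part, p in enumerate(params):
        for row in range(p.shape[0]):
            if strategy == 'mod':
                gid = row * n + part          # partition holds ids part, part+n, ...
            else:  # 'div': contiguous ids per partition (even split assumed)
                gid = part * p.shape[0] + row
            table[gid] = p[row]
    # combiner="sum": sum the embeddings of the ids present in each sparse row
    return np.array([table[ids].sum(axis=0) for ids in row_ids])

row_ids = [[1, 2], [2, 0]]  # ids in rows 0 and 1 of sp_ids
print(lookup_sparse_sum([a, c, b], row_ids, 'mod'))  # mod1
print(lookup_sparse_sum([a, c, b], row_ids, 'div'))  # div1
print(lookup_sparse_sum([a, b, c], row_ids, 'mod'))  # mod2
print(lookup_sparse_sum([a, b, c], row_ids, 'div'))  # div2
```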