tensorflow的算术操作:mul/add/sub等op都支持broadcast机制,该机制支持不同维度的数据进行计算,但是在从最后一维开始逆向逐维比较时,需要满足以下要求:
1)二者维度相同
2)二者维度有一个为1
3)如果二者的维数(rank)不一致,需要先在维数较少的数据的shape前面补1进行扩展,再进行上述判断;
如:a:[256,256,3]、b:[3]这样的维度,需要先将b扩展至与a一致,将b扩展至[1,1,3],再对a、b数据进行mul/add/sub等计算,最后输出维度[256,256,3]
如果需要自行实现broadcast,可以按以下步骤进行模拟:
1)对维度大小不一致的数组进行维度扩展
2)获取输出维度,即broadcast的维度
3)进行数据广播
粗略代码如下(这里以四维数据为例,进行扩展):
import tensorflow as tf
import numpy as np
def pad_shape(shape, rank):
    """Return a copy of *shape* left-padded with 1s up to *rank* dimensions.

    The input sequence is not modified. If *shape* already has *rank* or
    more dimensions, a plain copy is returned.
    """
    return [1] * (rank - len(shape)) + list(shape)


def compute_broadcast_shape(shape0, shape1):
    """Return the shape produced by broadcasting *shape0* against *shape1*.

    Both shapes are first left-padded with 1s to a common rank. Each pair of
    dimensions must then be equal, or one of them must be 1.

    Raises:
        ValueError: if some pair of dimensions differs and neither is 1.
    """
    rank = max(len(shape0), len(shape1))
    result = []
    for dim0, dim1 in zip(pad_shape(shape0, rank), pad_shape(shape1, rank)):
        if dim0 != dim1 and dim0 != 1 and dim1 != 1:
            raise ValueError(
                "shapes {} and {} cannot be broadcast together".format(
                    list(shape0), list(shape1)))
        # Pick the non-1 side rather than max(): a 0-sized dimension
        # broadcast against 1 must stay 0.
        result.append(dim1 if dim0 == 1 else dim0)
    return result


def broadcast_flat(flat_data, in_shape, out_shape):
    """Broadcast row-major *flat_data* of *in_shape* to *out_shape*.

    Args:
        flat_data: indexable 1-D sequence holding the input in row-major order.
        in_shape: shape of the input; must have the same rank as *out_shape*.
        out_shape: target shape; every input dimension must equal the
            matching output dimension or be 1.

    Returns:
        A flat list with the broadcast data in row-major order.

    Raises:
        ValueError: if the ranks differ or a dimension cannot be broadcast.
    """
    rank = len(out_shape)
    if len(in_shape) != rank:
        raise ValueError("in_shape and out_shape must have the same rank")

    # Row-major strides of the input; a broadcast axis (size 1) gets stride 0
    # so every output coordinate on that axis reads the same input element.
    in_strides = [0] * rank
    stride = 1
    for axis in reversed(range(rank)):
        if in_shape[axis] != 1:
            if in_shape[axis] != out_shape[axis]:
                raise ValueError(
                    "cannot broadcast shape {} to {}".format(
                        list(in_shape), list(out_shape)))
            in_strides[axis] = stride
        stride *= in_shape[axis]

    total = 1
    for dim in out_shape:
        total *= dim

    result = [0] * total
    for out_idx in range(total):
        # Decompose the flat output index into coordinates, last axis first,
        # and accumulate the matching flat input index.
        remainder = out_idx
        src_idx = 0
        for axis in reversed(range(rank)):
            coord = remainder % out_shape[axis]
            remainder //= out_shape[axis]
            src_idx += coord * in_strides[axis]
        result[out_idx] = flat_data[src_idx]
    return result


def write_values(path, values):
    """Write each value of *values* to *path*, one per line, 10 decimals."""
    with open(path, "w") as handle:
        for value in values:
            handle.write('{:.10f}'.format(float(value)) + '\n')


if __name__ == "__main__":
    input0_shape = [1, 1, 3, 1]
    input1_shape = [3]

    # Dimension expansion: pad both shapes to the same rank.
    rank = max(len(input0_shape), len(input1_shape))
    input0_shape = pad_shape(input0_shape, rank)
    input1_shape = pad_shape(input1_shape, rank)
    print(input1_shape)

    # Output (broadcast) shape.
    broadcast_shape = compute_broadcast_shape(input0_shape, input1_shape)
    print(broadcast_shape)

    data_a = np.random.random(input0_shape)  # original note: hwcn
    data_b = np.random.random(input1_shape)  # original note: h,w,c_out,c_in

    a = tf.placeholder("float")
    b = tf.placeholder("float")
    c = tf.add(a, b)
    with tf.Session() as sess:
        out = sess.run(c, feed_dict={a: data_a, b: data_b})
    print(data_b)
    print(out.shape)

    # Subtracting data_b from the TensorFlow sum recovers data_a broadcast to
    # the output shape; this is the reference for the manual broadcast below.
    # NOTE: the "float" placeholders are float32 in TensorFlow, so expect
    # small non-zero rounding differences in the final comparison.
    out_tf = (out - data_b).reshape(-1)
    print("out_tf")
    print(out_tf)
    write_values("pre_data.dat", out_tf)

    # Manual broadcast of input0 to the output shape.
    out_res = broadcast_flat(data_a.reshape(-1), input0_shape, broadcast_shape)
    write_values("aft_data.dat", out_res)

    # Compare the manual result with the TensorFlow-derived reference.
    print("compare")
    print(np.asarray(out_res) - out_tf)