Simulating the broadcast mechanism in Python

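This article explains how broadcasting works for TensorFlow arithmetic ops such as mul, add and sub: how operands with different numbers of dimensions are combined, and how broadcasting is achieved by expanding dimensions. A code example shows how to pad arrays whose ranks differ, compute the output shape, and broadcast the data.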


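TensorFlow's arithmetic ops (mul, add, sub, etc.) all support broadcasting, which lets them operate on inputs of different shapes. The two shapes are compared dimension by dimension, starting from the last (trailing) one, under the following rules:

1) a pair of dimensions is compatible if the two are equal;

2) or if one of them is 1;

3) if the inputs have different numbers of dimensions, the lower-rank shape is first padded with leading 1s, and rules 1) and 2) are then applied.

For example, given a with shape [256,256,3] and b with shape [3], b is first padded to [1,1,3] so that its rank matches a; mul/add/sub is then computed on a and b, and the output shape is [256,256,3].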
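The rules above can be captured in a small shape-check function. This is a sketch; the name `broadcast_shape` is our own, not a TensorFlow API. Unlike simply taking the per-axis max, it raises an error for incompatible pairs such as 3 vs 4, and it is cross-checked against NumPy's `np.broadcast_shapes` (available in NumPy 1.20 and later).

```python
import numpy as np


def broadcast_shape(shape0, shape1):
    """Compute the broadcast output shape of two shapes.

    Pads the shorter shape with leading 1s (rule 3), then requires each
    pair of dimensions to be equal (rule 1) or to contain a 1 (rule 2).
    """
    shape0, shape1 = list(shape0), list(shape1)
    # Rule 3: left-pad the lower-rank shape with 1s so the ranks match.
    rank = max(len(shape0), len(shape1))
    shape0 = [1] * (rank - len(shape0)) + shape0
    shape1 = [1] * (rank - len(shape1)) + shape1

    out = []
    for d0, d1 in zip(shape0, shape1):
        if d0 == d1 or d1 == 1:      # rule 1, or shape1 is repeated
            out.append(d0)
        elif d0 == 1:                # shape0 is repeated
            out.append(d1)
        else:
            raise ValueError(f"cannot broadcast {d0} against {d1}")
    return out


print(broadcast_shape([256, 256, 3], [3]))        # [256, 256, 3]
print(broadcast_shape([1, 1, 3, 1], [3]))         # [1, 1, 3, 3]
# Cross-check with NumPy's own rule (np.broadcast_shapes, NumPy >= 1.20).
print(np.broadcast_shapes((256, 256, 3), (3,)))   # (256, 256, 3)
```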

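Broadcasting can be simulated by hand with the following steps:

1) pad the shape of the lower-rank array with leading 1s;

2) compute the output shape, i.e. the broadcast shape;

3) broadcast (replicate) the data to that shape.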
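The three steps can also be written for inputs of any rank, not just 4-D. A minimal sketch, assuming NumPy; `broadcast_data` is a hypothetical helper name, step 2 (the output shape) is passed in as an argument, and the function assumes that shape has already been validated against the rules above. It is checked against `np.broadcast_to`.

```python
import numpy as np


def broadcast_data(data, out_shape):
    """Replicate `data` to `out_shape` by explicit index mapping.

    Assumes out_shape is already a valid broadcast shape for data.
    """
    data = np.asarray(data)
    out_shape = tuple(out_shape)
    # Step 1: left-pad the input shape with 1s to match the output rank.
    padded = (1,) * (len(out_shape) - data.ndim) + data.shape
    data = data.reshape(padded)

    # Step 3: fill every output position from the matching input position.
    out = np.empty(out_shape, dtype=data.dtype)
    for out_idx in np.ndindex(*out_shape):
        # A size-1 input axis is repeated, so its index is always 0;
        # on every other axis the input and output indices coincide.
        in_idx = tuple(0 if n == 1 else i for i, n in zip(out_idx, padded))
        out[out_idx] = data[in_idx]
    return out


b = np.arange(3.0)                        # shape [3]
expanded = broadcast_data(b, [1, 1, 3, 3])
print(expanded.shape)                     # (1, 1, 3, 3)
print(np.array_equal(expanded, np.broadcast_to(b, (1, 1, 3, 3))))  # True
```

Mapping each output index back to an input index, rather than copying blocks, keeps the logic identical for every rank; the cost is one Python-level iteration per output element, so this is only suitable for small verification cases.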

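A rough implementation follows, using 4-D data as an example; it checks the manual expansion against TensorFlow's own tf.add: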

```python
import numpy as np
import tensorflow as tf  # TensorFlow 2.x (eager mode)

if __name__ == "__main__":
    input0_shape = [1, 1, 3, 1]
    input1_shape = [3]

    # Step 1: pad the lower-rank shape (here input1) with leading 1s
    input_len = len(input0_shape) - len(input1_shape)
    for _ in range(input_len):
        input1_shape.insert(0, 1)
    print(input1_shape)  # [1, 1, 1, 3]

    # Step 2: broadcast shape = per-axis max (only valid when every pair
    # of dimensions is equal or contains a 1)
    broadcast_shape = [max(d0, d1) for d0, d1 in zip(input0_shape, input1_shape)]
    print(broadcast_shape)  # [1, 1, 3, 3]

    data_a = np.random.random(input0_shape)
    data_b = np.random.random(input1_shape)

    # Reference result from TensorFlow (replaces the TF1 placeholder/Session code)
    out = tf.add(data_a, data_b).numpy()
    print(data_b)
    print(out.shape)  # (1, 1, 3, 3)

    # out - data_b recovers data_a broadcast to broadcast_shape;
    # it serves as the expected value for the manual expansion below
    res_pre = out - data_b
    out_tf = res_pre.reshape(-1)
    data_a_flat = data_a.reshape(-1)
    print("out_tf")
    print(out_tf)
    with open("pre_data.dat", "w") as f_dets:
        for v in out_tf:
            f_dets.write("{:.10f}\n".format(float(v)))

    n0, n1, n2, n3 = broadcast_shape
    s0, s1, s2, s3 = input0_shape
    out_res = np.zeros(n0 * n1 * n2 * n3)

    # Step 3: broadcast data_a by hand
    for i in range(n0):
        for j in range(n1):
            for k in range(n2):
                for m in range(n3):
                    # flat (row-major) index into the output
                    tmp_idx0 = i * n1 * n2 * n3 + j * n2 * n3 + k * n3 + m
                    # an axis of size 1 is repeated, so its index is clamped to 0;
                    # otherwise the index is used unchanged
                    ii = min(i, s0 - 1)
                    jj = min(j, s1 - 1)
                    kk = min(k, s2 - 1)
                    mm = min(m, s3 - 1)
                    # flat (row-major) index into data_a
                    tmp_idx1 = ii * s1 * s2 * s3 + jj * s2 * s3 + kk * s3 + mm
                    out_res[tmp_idx0] = data_a_flat[tmp_idx1]

    with open("aft_data.dat", "w") as f_dets:
        for v in out_res:
            f_dets.write("{:.10f}\n".format(float(v)))

    # Compare: differences should be ~0 (floating-point rounding only)
    print("compare")
    print(out_res - out_tf)
    print(np.allclose(out_res, out_tf))
```
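The quadruple loop copies every element into a new buffer. NumPy's `np.broadcast_to` instead returns a read-only view in which each repeated axis has stride 0, so stepping along it never moves in memory. The sketch below imitates that zero-stride idea with `numpy.lib.stride_tricks.as_strided`; `broadcast_view` is a hypothetical name, this is an illustration of the technique rather than NumPy's actual source, and it assumes `out_shape` has already been validated.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def broadcast_view(data, out_shape):
    """Broadcast `data` to `out_shape` without copying, using stride 0.

    Assumes out_shape is already a valid broadcast shape for data.
    """
    data = np.asarray(data)
    out_shape = tuple(out_shape)
    pad = len(out_shape) - data.ndim
    # Step 1 again: prepend size-1 axes (their stride is irrelevant, use 0).
    in_shape = (1,) * pad + data.shape
    in_strides = (0,) * pad + data.strides
    # A repeated (size-1) axis gets stride 0; other axes keep their stride.
    strides = tuple(0 if n == 1 else s for n, s in zip(in_shape, in_strides))
    # writeable=False: many view elements alias the same memory location.
    return as_strided(data, shape=out_shape, strides=strides, writeable=False)


a = np.random.random((1, 1, 3, 1))
v = broadcast_view(a, (1, 1, 3, 3))
print(v.shape)                                                # (1, 1, 3, 3)
print(np.array_equal(v, np.broadcast_to(a, (1, 1, 3, 3))))    # True
```

Because no data is copied, the view costs O(1) memory regardless of the output size; the trade-off is that it must stay read-only.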

 

 
