错误方法:
def data_format(sess, image_array, label_array, slice_array):
train_image_array = []
train_label_array = []
for b in range(len(image_array)):
for r in range(len(slice_array[b])):
img = sess.run(tf.image.resize_images(
image_array[b:b + 1, slice_array[b][r][0]:slice_array[b][r][2] + slice_array[b][r][0],
slice_array[b][r][1]:slice_array[b][r][3] + slice_array[b][r][1], :], [32, 32], method=1))
# print(img_slice)
train_image_array.append(img[0])
train_label_array.append(label_array[b][r])
return np.asarray(train_image_array), np.asarray(train_label_array)
正确方法:
def data_format(sess, image_array, label_array, slice_array):
train_image_array = []
train_label_array = []
i_x = tf.placeholder(tf.float32, shape=[1, None, None, 3], name="i_x")
resize = tf.image.resize_images(i_x, size=[32, 32], method=1)
for b in range(len(image_array)):
for r in range(len(slice_array[b])):
img = sess.run(resize, feed_dict={i_x:
image_array[b:b + 1, slice_array[b][r][0]:slice_array[b][r][2] + slice_array[b][r][0],
slice_array[b][r][1]:slice_array[b][r][3] + slice_array[b][r][1], :]})
# print(img_slice)
train_image_array.append(img[0])
train_label_array.append(label_array[b][r])
return np.asarray(train_image_array), np.asarray(train_label_array)
错误方法之所以是错误,是因为它定义了非常多的tensor(有多少次循环就定义了多少个tensor),定义出来的tensor是要加载到图里面去的,tensor节点多了,会导致图非常大(容易内存溢出),且图加载tensor节点非常耗时。
正确方法之所以正确,是因为它只定义了一个tensor,其他的只是不断重复执行这个tensor而已。所以图非常简单。
其次:
tensor里面放到的数据只能是多维数组(一维,二维,三维。。。),不能是其他数据结构类型
853

被折叠的 条评论
为什么被折叠?



