tf.stack() 详解 —— 以理解为主

本文深入解析了TensorFlow中的矩阵拼接方法tf.stack(),通过实例演示了如何使用该函数进行不同维度的矩阵堆叠,并对比了tf.stack()与tf.concat()的区别,以及与tf.transpose()的关系。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow 中用于张量堆叠(拼接)的方法:tf.stack()

一篇个人觉得讲得不错的参考文章(tf.stack() 和 tf.concat() 的区别):https://blog.youkuaiyun.com/Gai_Nothing/article/details/88416782

 

def stack(values, axis=0, name="stack"):
    """Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor.

    Packs the list of tensors in `values` into a tensor with rank one higher
    than each tensor in `values`, by packing them along the `axis` dimension.
    Given a list of length `N` of tensors of shape `(A, B, C)`:

    if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
    if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
    Etc.

    Args:
      values: A list of `Tensor` objects with the same shape and type.
      axis: An `int`. The axis to stack along. Defaults to the first dimension.
        Negative values wrap around, so the valid range is `[-(R+1), R+1)`.
      name: A name for this operation (optional).

    Note:
      This is only an excerpt of TensorFlow's signature and docstring, quoted
      for reference. The implementation is omitted, so calling this stub
      simply returns None.
    """

个人理解与测试:

import tensorflow as tf
import numpy as np

# TF1-style graph session used by every demo below. This script creates no
# tf.Variable, so running the global initializer changes nothing; it is kept
# only to mirror the usual TF1 boilerplate.
init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

# Demo inputs for tf.stack: the integers 1..30 laid out as (2, 3, 5) and
# unpacked along the first axis into two (3, 5) arrays.
stack_data1, stack_data2 = np.arange(1, 31).reshape(2, 3, 5)
for caption, value in (
    ('stack_data1: \n', stack_data1),
    ('stack_data1.shape: \n', stack_data1.shape),
    ('stack_data2: \n', stack_data2),
    ('stack_data2.shape: \n', stack_data2.shape),
):
    print(caption, value)
# stack_data1:
#  [[ 1  2  3  4  5]
#  [ 6  7  8  9 10]
#  [11 12 13 14 15]]
# stack_data1.shape:
#  (3, 5)
# stack_data2:
#  [[16 17 18 19 20]
#  [21 22 23 24 25]
#  [26 27 28 29 30]]
# stack_data2.shape:
#  (3, 5)

# 理解:
#     举例:假设两个张量的维度均为 (维1, 维2, 维3, 维4)(即 R=4),则 axis 的取值范围为 [-(R+1), R+1) = [-5, 5)
#     输入 stacks = [stack_data1, stack_data2], st = tf.stack(stacks, axis=?)
#     此时:
#           可把 stacks 看作维度为 (2, 维1, 维2, 维3, 维4) 的整体,共 5 维;输出同样是 5 维,因此 axis 的取值在 [-5, 5) 内
#           当axis=0时, st维度为:(2, 维1, 维2, 维3, 维4)
#           当axis=1时, st维度为:(维1, 2,维2, 维3, 维4)
#           当axis=2时, st维度为:(维1, 维2, 2,维3, 维4)
#           当axis=3时, st维度为:(维1, 维2, 维3,2,维4)
#           当axis=4时, st维度为:(维1, 维2, 维3,维4,2)

#           当axis=-5时, st维度为:(2, 维1, 维2, 维3, 维4)
#           当axis=-4时, st维度为:(维1, 2,维2, 维3, 维4)
#           当axis=-3时, st维度为:(维1, 维2, 2,维3, 维4)
#           当axis=-2时, st维度为:(维1, 维2, 维3,2,维4)
#           当axis=-1时, st维度为:(维1, 维2, 维3,维4,2)

print('======================================')
# axis=0: the new length-2 axis becomes the leading one: 2 * (3, 5) ==> (2, 3, 5)
st_0 = sess.run(tf.stack([stack_data1, stack_data2], axis=0))
for caption, value in (('st_0: \n', st_0), ('st_0.shape: \n', st_0.shape)):
    print(caption, value)
# st_0:
#  [[[ 1  2  3  4  5]
#   [ 6  7  8  9 10]
#   [11 12 13 14 15]]
#
#  [[16 17 18 19 20]
#   [21 22 23 24 25]
#   [26 27 28 29 30]]]
# st_0.shape:
#  (2, 3, 5)

print('======================================')
# axis=1: the new length-2 axis is inserted second: 2 * (3, 5) ==> (3, 2, 5).
# Row i of each input ends up side by side in st_1[i].
st_1 = sess.run(tf.stack([stack_data1, stack_data2], axis=1))
for caption, value in (('st_1: \n', st_1), ('st_1.shape: \n', st_1.shape)):
    print(caption, value)
# st_1:
#  [[[ 1  2  3  4  5]
#   [16 17 18 19 20]]
#
#  [[ 6  7  8  9 10]
#   [21 22 23 24 25]]
#
#  [[11 12 13 14 15]
#   [26 27 28 29 30]]]
# st_1.shape:
#  (3, 2, 5)

print('======================================')
# axis=2: the new length-2 axis goes last: 2 * (3, 5) ==> (3, 5, 2).
# Each st_2[i, j] pairs stack_data1[i, j] with stack_data2[i, j].
st_2 = sess.run(tf.stack([stack_data1, stack_data2], axis=2))
for caption, value in (('st_2: \n', st_2), ('st_2.shape: \n', st_2.shape)):
    print(caption, value)
# st_2:
#  [[[ 1 16]
#   [ 2 17]
#   [ 3 18]
#   [ 4 19]
#   [ 5 20]]
#
#  [[ 6 21]
#   [ 7 22]
#   [ 8 23]
#   [ 9 24]
#   [10 25]]
#
#  [[11 26]
#   [12 27]
#   [13 28]
#   [14 29]
#   [15 30]]]
# st_2.shape:
#  (3, 5, 2)

print('======================================')
# axis=-1 ("st_1_" = axis minus one) wraps around to the last position, which
# for a 3-D output is axis=2, so the result equals st_2: 2 * (3, 5) ==> (3, 5, 2)
st_1_ = tf.stack([stack_data1, stack_data2], axis=-1)
st_1_ = sess.run(st_1_)
print('st_1_: \n', st_1_)
print('st_1_.shape: \n', st_1_.shape)
# st_1_:
#  [[[ 1 16]
#   [ 2 17]
#   [ 3 18]
#   [ 4 19]
#   [ 5 20]]
#
#  [[ 6 21]
#   [ 7 22]
#   [ 8 23]
#   [ 9 24]
#   [10 25]]
#
#  [[11 26]
#   [12 27]
#   [13 28]
#   [14 29]
#   [15 30]]]
# st_1_.shape:
#  (3, 5, 2)

print('=================比较st_1, 和 transpose=====================')
print('st_1: \n', st_1)
# stack(axis=1) should equal the axis=0 stack with its new axis moved to
# position 1, i.e. transpose(st_0, [1, 0, 2]).
transpose_test = sess.run(tf.transpose(st_0, [1, 0, 2]))
print('transpose_test: \n', transpose_test)
# np.array_equal returns a single verdict and also requires equal shapes.
# A bare `==` prints an elementwise bool array and may broadcast
# mismatched shapes.
print('transpose_test == st_1: \n', np.array_equal(transpose_test, st_1))

print('=================比较st_2, 和 transpose=====================')
print('st_2: \n', st_2)
# stack(axis=2) should equal the axis=0 stack with its new axis moved last.
transpose_test = sess.run(tf.transpose(st_0, [1, 2, 0]))
print('transpose_test: \n', transpose_test)
print('transpose_test == st_2: \n', np.array_equal(transpose_test, st_2))
# 总结:
#     设 stacks = tf.stack(values, axis=0),其维度为 (2, 维1, 维2, 维3, 维4),则 tf.stack(values, axis=k):
#     当axis=0时, 就相当于tf.transpose(stacks, [0, 1, 2, 3, 4])
#     当axis=1时, 就相当于tf.transpose(stacks, [1, 0, 2, 3, 4])
#     当axis=2时, 就相当于tf.transpose(stacks, [1, 2, 0, 3, 4])
#     当axis=3时, 就相当于tf.transpose(stacks, [1, 2, 3, 0, 4])
#     当axis=4时, 就相当于tf.transpose(stacks, [1, 2, 3, 4, 0])


# 4-D test: two (3, 4, 5) inputs. Stacking them on the last axis (axis=3)
# should equal transposing their axis=0 stack with permutation [1, 2, 3, 0].
stack_data1, stack_data2 = np.arange(1, 121).reshape([2, 3, 4, 5])  # each is (3, 4, 5)
st_ = tf.stack([stack_data1, stack_data2], axis=3)
st_0 = tf.stack([stack_data1, stack_data2], axis=0)
st_ = sess.run(st_)
st_0 = sess.run(st_0)

tr_ = tf.transpose(st_0, [1, 2, 3, 0])
tr_ = sess.run(tr_)

print('st_.shape: ', st_.shape)
print('st_: ', st_)

print('tr_.shape: ', tr_.shape)
print('tr_: ', tr_)

# A single shape-aware verdict instead of a 120-element boolean array.
print(np.array_equal(st_, tr_))

 

WARNING:tensorflow:From /root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/compat/v2_compat.py:96: disable_resource_variables (from tensorflow.python.ops.variable_scope) is deprecated and will be removed in a future version. Instructions for updating: non-resource variables are not supported in the long term 2025-07-26 19:47:00.316548: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1 2025-07-26 19:47:00.379323: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1561] Found device 0 with properties: pciBusID: 0000:39:00.0 name: NVIDIA GeForce RTX 4090 computeCapability: 8.9 coreClock: 2.52GHz coreCount: 128 deviceMemorySize: 23.55GiB deviceMemoryBandwidth: 938.86GiB/s 2025-07-26 19:47:00.379583: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory 2025-07-26 19:47:00.379632: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcublas.so.10'; dlerror: libcublas.so.10: cannot open shared object file: No such file or directory 2025-07-26 19:47:00.380958: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10 2025-07-26 19:47:00.381316: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10 2025-07-26 19:47:00.381386: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory 2025-07-26 19:47:00.381440: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcusparse.so.10'; dlerror: libcusparse.so.10: cannot open shared object file: No such file or directory 2025-07-26 19:47:00.381492: W 
tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcudnn.so.7'; dlerror: libcudnn.so.7: cannot open shared object file: No such file or directory 2025-07-26 19:47:00.381501: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1598] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices... 2025-07-26 19:47:00.381919: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA 2025-07-26 19:47:00.396214: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 2000000000 Hz 2025-07-26 19:47:00.405365: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7f45c0000b60 initialized for platform Host (this does not guarantee that XLA will be used). Devices: 2025-07-26 19:47:00.405415: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version 2025-07-26 19:47:00.409166: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1102] Device interconnect StreamExecutor with strength 1 edge matrix: 2025-07-26 19:47:00.409199: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1108] WARNING:tensorflow:From /root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/layers/layers.py:1089: Layer.apply (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version. Instructions for updating: Please use `layer.__call__` method instead. loaded ./checkpoint/decom_net_train/model.ckpt loaded ./checkpoint/illumination_adjust_net_train/model.ckpt No restoration pre model! (480, 640, 3) (680, 720, 3) (415, 370, 3) Start evalating! 
0 Traceback (most recent call last): File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1365, in _do_call return fn(*args) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1350, in _run_fn target_list, run_metadata) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1443, in _call_tf_sessionrun run_metadata) tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value Restoration_net/de_conv6_1/biases [[{{node Restoration_net/de_conv6_1/biases/read}}]] During handling of the above exception, another exception occurred: Traceback (most recent call last): File "evaluate.py", line 92, in <module> restoration_r = sess.run(output_r, feed_dict={input_low_r: decom_r_low, input_low_i: decom_i_low}) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 958, in run run_metadata_ptr) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1181, in _run feed_dict_tensor, options, run_metadata) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1359, in _do_run run_metadata) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1384, in _do_call raise type(e)(node_def, op, message) tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value Restoration_net/de_conv6_1/biases [[node Restoration_net/de_conv6_1/biases/read (defined at /root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/variables.py:256) ]] Original stack trace for 'Restoration_net/de_conv6_1/biases/read': File "evaluate.py", line 28, in <module> output_r = Restoration_net(input_low_r, input_low_i) File "/root/Python/KinD-master/KinD-master/model.py", 
line 70, in Restoration_net conv6=slim.conv2d(up6, 256,[3,3], rate=1, activation_fn=lrelu,scope='de_conv6_1') File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/arg_scope.py", line 184, in func_with_args return func(*args, **current_args) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/layers/layers.py", line 1191, in convolution2d conv_dims=2) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/arg_scope.py", line 184, in func_with_args return func(*args, **current_args) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/layers/layers.py", line 1089, in convolution outputs = layer.apply(inputs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 324, in new_func return func(*args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 1695, in apply return self.__call__(inputs, *args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 547, in __call__ outputs = super(Layer, self).__call__(inputs, *args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 758, in __call__ self._maybe_build(inputs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 2131, in _maybe_build self.build(input_shapes) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/keras/layers/convolutional.py", line 172, in build dtype=self.dtype) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/layers/base.py", line 460, in add_weight **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer_v1.py", line 447, in add_weight caching_device=caching_device) File 
"/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 743, in _add_variable_with_custom_getter **kwargs_for_getter) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1573, in get_variable aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 1316, in get_variable aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 551, in get_variable return custom_getter(**custom_getter_kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/layers/layers.py", line 1793, in layer_variable_getter return _model_variable_getter(getter, *args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/layers/layers.py", line 1784, in _model_variable_getter aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/arg_scope.py", line 184, in func_with_args return func(*args, **current_args) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/variables.py", line 328, in model_variable aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/arg_scope.py", line 184, in func_with_args return func(*args, **current_args) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tf_slim/ops/variables.py", line 256, in variable aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 520, in _true_getter aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 939, in _get_single_variable aggregation=aggregation) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 259, in 
__call__ return cls._variable_v1_call(*args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 220, in _variable_v1_call shape=shape) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 198, in <lambda> previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variable_scope.py", line 2614, in default_variable_creator shape=shape) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 263, in __call__ return super(VariableMetaclass, cls).__call__(*args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1666, in __init__ shape=shape) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/variables.py", line 1854, in _init_from_args self._snapshot = array_ops.identity(self._variable, name="read") File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper return target(*args, **kwargs) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 282, in identity ret = gen_array_ops.identity(input, name=name) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3901, in identity "Identity", input=input, name=name) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 744, in _apply_op_helper attrs=attr_protos, op_def=op_def) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3327, in _create_op_internal op_def=op_def) File "/root/Python/conda_lit/kind/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1791, in __init__ 
self._traceback = tf_stack.extract_stack()
最新发布
07-27
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值