tensorflow根据输入更改tensor shape

TF中动态shape实践
本文介绍TensorFlow中如何使用tf.shape实现动态shape处理,适用于构建依赖于输入尺寸的随机数生成器或类RNN网络。文章通过具体实例展示如何根据输入的shape来决定随机数的shape。

涉及随机数以及类RNN的网络构建常常需要根据输入shape,决定中间变量的shape或步长。
tf.shape函数不同于tensor.shape.as_list()函数,后者返回的是常值list,而前者返回的是tensor。使用tf.shape函数可以使得中间变量的tensor形状随输入变化,不需要在构建Graph的时候指定。但对于tf.Variable,因为需要提前分配固定空间,其shape无法通过上诉方法设定。
实例代码如下:

a = tf.placeholder(tf.float32,[None,])
b = tf.random_normal(tf.concat([tf.shape(a),[2,]],axis=0))
<think>我们正在处理一个TensorFlow模型,并希望修改其输入形状。根据引用[2]中的代码示例,我们看到在保存为pb文件之前,模型图是通过`tf.graph_util.convert_variables_to_constants`转换的。但是,修改输入形状通常需要在构建模型时指定,或者在加载模型后重建输入层。在TensorFlow中,修改已保存模型的输入形状有几种方法:1.在模型构建时指定:如果我们有模型的源代码,可以直接修改输入占位符的形状,然后重新保存模型。2.使用图重写:对于已经保存的模型(如pb文件),我们可以通过TensorFlow的图操作来修改输入形状。这通常涉及以下步骤:-加载模型图。-找到输入节点,并修改其形状。-由于形状变化可能影响后续节点的形状,因此需要传播形状信息(使用`tf.graph_util.import_graph_def`和`tf.compat.v1.graph_util.remove_training_nodes`等工具)。-保存修改后的图。然而,直接修改pb文件中的输入形状并不简单,因为后续节点的计算可能依赖于原来的形状。因此,我们需要确保整个图的形状重新推断(shapeinference)正确。下面是一个示例方法,使用TensorFlow的图操作来修改输入形状:步骤:a.加载pb文件中的模型图。b.创建一个新的输入节点,其名称与原输入节点相同,但形状不同。c.将图中所有对原输入节点的引用替换为新的输入节点。d.进行形状推断,确保后续节点的形状正确。e.保存修改后的图。但是,这个过程较为复杂,且容易出错。另一种更简单的方法是在加载模型时使用`tf.keras`(如果模型是Keras模型)或者使用`tf.saved_model` API,并指定新的输入形状。然而,根据引用[1]和[2],我们使用的是GraphDef(pb文件),所以我们可以尝试以下方法:方法一:使用`tf.keras.models.load_model`(仅适用于Keras模型保存的格式,如SavedModel格式,不直接适用于pb文件)方法二:使用`tf.graph_util.import_graph_def`并配合自定义的输入映射。具体步骤(方法二):1.读取pb文件,解析为GraphDef对象。2.创建一个新的TensorFlow图。3.使用`tf.import_graph_def`导入GraphDef,但通过`input_map`参数替换输入节点为一个新的占位符(具有新的形状)。4.然后保存修改后的图。示例代码:假设原输入节点名为"Placeholder:0",新形状为[None,new_height, new_width,channels](对于图像输入)。注意:这种方法要求我们知道输入节点的确切名称,并且新形状必须与模型后续层兼容(例如,卷积层的步长和核大小不能导致非法形状)。代码示例:```pythonimport tensorflow astffrom tensorflow.python.framework importtensor_util#加载现有的pb文件withtf.io.gfile.GFile('model.pb', 'rb')as f:graph_def= tf.compat.v1.GraphDef()graph_def.ParseFromString(f.read())#创建一个新的图withtf.Graph().as_default() asgraph:#定义新的输入占位符,具有新的形状#假设原输入节点名为'Placeholder:0',新形状为[None,256,256,3]new_input =tf.compat.v1.placeholder(tf.float32,shape=[None,256,256,3], name='Placeholder')#导入图,并用新的输入替换原来的输入#注意:原输入节点在graph_def中,我们将其替换为new_inputtf.import_graph_def(graph_def,input_map={'Placeholder:0':new_input},#将原输入节点映射到新的输入name='')#获取输出节点(假设输出节点名为'output:0')output =graph.get_tensor_by_name('output:0')#保存修改后的图with tf.io.gfile.GFile('modified_model.pb','wb') asf:f.write(graph.as_graph_def().SerializeToString()) ```但是,这种方法可能不会总是成功,因为形状改变可能导致后续节点形状不匹配。因此,我们需要确保模型结构允许新的输入形状(例如,全连接层要求固定维度,所以如果改变输入形状,全连接层的输入维度也必须相应改变)。如果模型中有全连接层,那么修改输入形状通常需要同时修改全连接层的权重(因为输入维度变了),这变得非常复杂。因此,修改输入形状通常只适用于卷积网络等对输入大小不敏感的网络(通过全局池化等避免全连接层)。另一种思路:在模型前添加一个预处理层(如resize层),这样我们保持原模型输入形状不变,但在输入数据之前先resize到原模型要求的形状。这种方法更安全,但会引入额外的计算。总结:-如果模型是卷积网络,没有全连接层(或者全连接层之前的特征提取部分可以适应不同大小,最后通过全局池化固定),那么可以直接修改输入形状。-否则,建议在模型外部进行resize,或者修改模型结构(如果有源代码)并重新训练。根据引用[2]中的代码,我们也可以考虑在保存为pb之前修改输入形状(如果有原始会话和图的访问权限)。因此,具体方法取决于模型的结构和我们的需求。注意:以上代码基于TensorFlow1.x,因为引用[2]中的代码是1.x的。在TensorFlow2.x中,我们可以使用tf.compat.v1来保持兼容性。如果用户使用的是TensorFlow2.x,并且模型是SavedModel格式,则可以使用以下方法:```pythonimporttensorflowas tf#加载SavedModelmodel= tf.saved_model.load('path_to_savedmodel') #如果模型有具体的签名,我们可以通过构建一个新的模型来改变输入形状#但SavedModel的输入形状是固定的,除非我们使用tf.keras模型并重新构建#对于Keras模型,我们可以在加载时通过自定义对象和配置来改变输入形状#但前提是模型是用Keras保存的model= tf.keras.models.load_model('keras_model.h5') #然后我们可以通过重新构建模型来改变输入形状#例如,使用函数式APIinput_tensor =tf.keras.layers.Input(shape=(256,256,3)) output_tensor= model(input_tensor)new_model= tf.keras.models.Model(input_tensor,output_tensor)new_model.save('new_keras_model.h5') ```但是,用户的问题是关于pb文件的,所以我们主要关注GraphDef的修改。因此,我们提供两种主要方法:1.通过图重写(使用input_map)修改输入节点(适用于无全连接层或全连接层不受影响的模型)。2.在模型外部添加resize操作(更通用,不修改模型内部)。方法二的示例(外部resize):在输入数据之前,使用tf.image.resize将输入数据调整到模型要求的原始形状。这样,我们就不需要修改模型本身了。根据用户的具体需求,选择合适的方法。相关问题:1.修改TensorFlow模型输入形状后,为什么模型推理会出错?2.如何在不修改模型的情况下处理不同尺寸的输入?3.TensorFlow中如何查看模型的输入节点名称?4.修改模型输入形状后,如何验证模型是否正确?5.TensorFlow2.x中修改模型输入形状的方法与1.x有什么不同?</think>### 修改TensorFlow模型输入形状的方法 在TensorFlow中修改模型输入形状主要有以下三种方法,根据您的模型格式(pb文件或CheckPoint)选择合适方案: #### 方法1:直接修改GraphDef(适用于pb模型) ```python import tensorflow as tf from tensorflow.python.framework import tensor_util # 加载原始pb模型 with tf.io.gfile.GFile('model.pb', 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) # 查找并修改输入节点 for node in graph_def.node: if node.name == 'input_node_name': # 替换为实际输入节点名 # 创建新的TensorShapeProto new_shape = tf.TensorShape([None, 新高度, 新宽度, 通道数]).as_proto() # 清除旧shape属性 if 'shape' in node.attr: del node.attr['shape'] # 设置新shape node.attr['shape'].shape.CopyFrom(new_shape) # 保存修改后的模型 with tf.io.gfile.GFile('modified_model.pb', 'wb') as f: f.write(graph_def.SerializeToString()) ``` **注意**:需将`input_node_name`替换为实际输入节点名(可通过Netron查看),新形状需兼容后续层计算[^2]。 #### 方法2:通过输入映射重构图(推荐) ```python with tf.Graph().as_default() as graph: # 创建新输入张量 new_input = tf.compat.v1.placeholder( tf.float32, shape=[None, 新高度, 新宽度, 通道数], name='new_input' ) # 加载原图并映射输入 tf.graph_util.import_graph_def( graph_def, input_map={'original_input_name:0': new_input}, # 原输入节点名 name='' ) # 获取输出节点 output_tensor = graph.get_tensor_by_name('output_node_name:0') # 保存新模型 converted_graph_def = tf.graph_util.convert_variables_to_constants( tf.compat.v1.Session(), graph.as_graph_def(), [output_tensor.op.name] ) with tf.io.gfile.GFile('remapped_model.pb', 'wb') as f: f.write(converted_graph_def.SerializeToString()) ``` #### 方法3:从CheckPoint重建(适用于ckpt模型) ```python # 重建计算图时修改输入 input_ph = tf.compat.v1.placeholder( tf.float32, [None, 新高度, 新宽度, 通道数], name='new_input' ) # 构建模型网络(使用新输入) logits = your_model_architecture(input_ph) # 从ckpt加载权重 saver = tf.compat.v1.train.Saver() with tf.compat.v1.Session() as sess: saver.restore(sess, 'model.ckpt') # 转换并保存为pb output_graph_def = tf.graph_util.convert_variables_to_constants( sess, sess.graph.as_graph_def(), ['output_node_name'] ) with tf.io.gfile.GFile('new_model.pb', 'wb') as f: f.write(output_graph_def.SerializeToString()) ``` ### 关键注意事项 1. **形状兼容性**:新形状需满足: $$ \text{新宽度} \times \text{新高度} \times \text{通道数} = \text{原输入元素总数} $$ 否则全连接层会报错 2. **节点名确认**:使用Netron可视化pb文件获取准确输入/输出节点名 3. **动态维度**:`None`表示批处理维度可变,其他维度修改需谨慎 4. **后处理调整**:若涉及全连接层,需同步修改后续层权重矩阵维度: $$ W_{\text{new}} \in \mathbb{R}^{d_{\text{new}} \times d_{\text{out}}} $$ > **提示**:对于卷积网络,可通过全局平均池化替代全连接层避免形状约束[^1]。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值