TensorFlow: tf.scan and tf.map_fn

This article covers TensorFlow's tf.scan and tf.map_fn, two functions for processing a sequence of tensors along the first dimension. tf.scan applies a callable that threads an accumulator from one element to the next, while tf.map_fn applies a callable to each element independently. Both support parallel iteration and backpropagation, and both can optionally swap tensors between GPU and CPU memory to trade speed for memory. Callers specify the output type of the function, the number of parallel iterations, and a few other options described below.


scan(
    fn,
    elems,
    initializer=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)
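
As a quick illustration of how scan threads an accumulator through the sequence, here is a minimal sketch that computes a running sum, written against the signature above (tf.scan is available in both TensorFlow 1.x and 2.x):

import tensorflow as tf

elems = tf.constant([1, 2, 3, 4, 5, 6])

# fn receives the running accumulator `a` and the current element `x`.
# With no initializer, the first element of elems seeds the accumulator.
running_sum = tf.scan(lambda a, x: a + x, elems)
# running_sum == [1, 3, 6, 10, 15, 21]
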
map_fn(
    fn,
    elems,
    dtype=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    infer_shape=True,
    name=None
)
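
For comparison, a minimal map_fn sketch: fn is applied to each slice independently and the results are stacked back along the first dimension, with no accumulator involved:

import tensorflow as tf

elems = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Each call to fn sees one row of shape [2]; the stacked result has shape [3].
squared_norms = tf.map_fn(lambda x: tf.reduce_sum(x * x), elems)
# squared_norms == [5.0, 25.0, 61.0]
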
Parameters of scan and map_fn:
fn:

The callable to be performed. For map_fn it accepts a single argument, which has the same (possibly nested) structure as elems; for scan it accepts two arguments, the running accumulator and the current element. Its output must have the same structure as dtype if one is provided; otherwise it must have the same structure as elems.

elems:

A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along its first dimension. fn is applied to the resulting nested sequence of slices.
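
A small sketch of that unpacking with a nested elems: a tuple of two tensors is unpacked in lockstep, so fn sees one pair of slices per step (dtype is needed here because fn's output structure differs from elems):

import tensorflow as tf

xs = tf.constant([1.0, 2.0, 3.0])
ys = tf.constant([10.0, 20.0, 30.0])

# fn receives an (x_slice, y_slice) tuple at each step and returns a single
# scalar, so dtype must describe that output structure.
pair_sums = tf.map_fn(lambda xy: xy[0] + xy[1], (xs, ys), dtype=tf.float32)
# pair_sums == [11.0, 22.0, 33.0]
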

dtype (map_fn) / initializer (scan):

(Optional.) For map_fn, the output type(s) of fn. If fn returns a structure of tensors that differs from the structure of elems, then dtype is not optional and must match the structure of fn's output. For scan, the corresponding argument is initializer, the initial value of the accumulator; if it is omitted, the first element of elems is used instead.
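
For instance, a sketch of the non-optional case for map_fn: fn returns a tuple while elems is a single tensor, so dtype has to spell out the output structure:

import tensorflow as tf

elems = tf.constant([1, 2, 3])

# fn returns an (int32, float32) pair, which differs from the structure of
# elems, so dtype is mandatory and must mirror fn's output.
doubled, halved = tf.map_fn(
    lambda x: (x * 2, tf.cast(x, tf.float32) / 2.0),
    elems,
    dtype=(tf.int32, tf.float32))
# doubled == [2, 4, 6], halved == [0.5, 1.0, 1.5]
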

parallel_iterations:

(Optional.) The number of iterations allowed to run in parallel.

back_prop:

(Optional.) True enables support for backpropagation.

swap_memory:

(Optional.) True enables GPU-CPU memory swapping.

infer_shape:

(Optional.) False disables tests for consistent output shapes.

name:

(Optional.) A name prefix for the returned tensors.
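
Finally, a sketch that passes the optional arguments above explicitly in one call; the specific values are illustrative only, not tuned recommendations:

import tensorflow as tf

elems = tf.ones([100, 8])

result = tf.map_fn(
    tf.nn.relu,               # fn applied to each [8] slice
    elems,
    parallel_iterations=32,   # allow up to 32 slices to run in parallel
    back_prop=True,           # keep gradients flowing through fn
    swap_memory=False,        # do not swap tensors between GPU and CPU
    infer_shape=True,         # check that all outputs share one shape
    name='relu_map')          # name prefix for the returned tensor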
