tensorflow API:tf.map_fn

本文深入解析了 TensorFlow 中的 tf.map_fn 函数，详细介绍了其参数设置与功能特性，包括并行迭代次数、反向传播支持及 GPU-CPU 内存交换等选项，并通过实例展示了如何使用 tf.map_fn 对沿第 0 维拆分出的各个张量逐一进行操作。该函数适用于需要对张量元素进行逐个处理的场景。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# Signature of tf.map_fn as presented in this article; each parameter is
# described in the "参数" (parameters) list below.
# NOTE(review): this matches the TensorFlow 1.x documentation. TF 2.x releases
# changed it (e.g. parallel_iterations defaults to None and a
# fn_output_signature argument supersedes dtype) -- confirm against the API
# docs of the installed version.
tf.map_fn(
    fn,  # callable applied to each slice of elems
    elems,  # tensor(s) unpacked along dimension 0 before fn is applied
    dtype=None,  # output type(s) of fn; required when fn's output structure differs from elems
    parallel_iterations=10,  # number of iterations allowed to run in parallel
    back_prop=True,  # True enables support for back propagation
    swap_memory=False,  # True enables GPU-CPU memory swapping
    infer_shape=True,  # False disables tests for consistent output shapes
    name=None  # name prefix for the returned tensors
)

作用：先将 elems 沿第 0 维拆分（unpack）成张量列表，再对其中每个张量依次调用 fn（map on the list of tensors unpacked from elems on dimension 0）。

参数
fn: The callable to be performed. It accepts one argument, which will have the same (possibly nested) structure as elems. Its output must have the same structure as dtype if one is provided, otherwise it must have the same structure as elems.

elems: A tensor or (possibly nested) sequence of tensors, each of which will be unpacked along their first dimension. The nested sequence of the resulting slices will be applied to fn.

dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn.

parallel_iterations: (optional) The number of iterations allowed to run in parallel.

back_prop: (optional) True enables support for back propagation.

swap_memory: (optional) True enables GPU-CPU memory swapping.

infer_shape: (optional) False disables tests for consistent output shapes.

name: (optional) Name prefix for the returned tensors.

官网例子:
1.

import numpy as np
import tensorflow as tf

# Example 1: square every element of a 1-D tensor. fn receives one scalar per
# step, and its output has the same structure and dtype as elems, so no dtype
# argument is needed.
# np.int64 is spelled out in all examples because NumPy's default integer type
# is platform dependent (int32 on Windows), which would clash with the
# dtype=tf.int64 declared in the examples below.
elems = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
squares = tf.map_fn(lambda x: x * x, elems)
# squares evaluates to [1, 4, 9, 16, 25, 36]

# Example 2: elems is a tuple of two tensors, so fn receives the tuple
# (x[0], x[1]) of matching slices. fn returns a single tensor, which is a
# structure different from elems, so dtype must be given.
# NOTE: on TF >= 2.3, fn_output_signature= is the preferred spelling of dtype=;
# dtype= is kept here to match the signature documented above.
elems = (np.array([1, 2, 3], dtype=np.int64), np.array([-1, 1, -1], dtype=np.int64))
alternate = tf.map_fn(lambda x: x[0] * x[1], elems, dtype=tf.int64)
# alternate evaluates to [-1, 2, -3]

# Example 3: fn returns a tuple while elems is a single tensor, so dtype is
# again required and must mirror fn's output structure.
elems = np.array([1, 2, 3], dtype=np.int64)
alternates = tf.map_fn(lambda x: (x, -x), elems, dtype=(tf.int64, tf.int64))
# alternates[0] evaluates to [1, 2, 3]
# alternates[1] evaluates to [-1, -2, -3]
# Error report and script from the original question.
#
# "This is the error message": running the script below on TensorFlow 2.x
# failed with
#
# Traceback (most recent call last):
#   File "main_SemanticKITTI.py", line 206, in <module>
#     dataset.init_input_pipeline()
#   File "main_SemanticKITTI.py", line 176, in init_input_pipeline
#     self.batch_train_data = self.batch_train_data.map(map_func=map_func)
#   File "$TF/python/data/ops/dataset_ops.py", line 1861, in map
#     return MapDataset(self, map_func, preserve_cardinality=True)
#   File "$TF/python/data/ops/dataset_ops.py", line 4985, in __init__
#     use_legacy_function=use_legacy_function)
#   File "$TF/python/data/ops/dataset_ops.py", line 4218, in __init__
#     self._function = fn_factory()
#   File "$TF/python/eager/function.py", line 3151, in get_concrete_function
#     *args, **kwargs)
#   File "$TF/python/eager/function.py", line 3116, in _get_concrete_function_garbage_collected
#     graph_function, _ = self._maybe_define_function(args, kwargs)
#   File "$TF/python/eager/function.py", line 3463, in _maybe_define_function
#     graph_function = self._create_graph_function(args, kwargs)
#   File "$TF/python/eager/function.py", line 3308, in _create_graph_function
#     capture_by_value=self._capture_by_value),
#   File "$TF/python/framework/func_graph.py", line 1007, in func_graph_from_py_func
#     func_outputs = python_func(*func_args, **func_kwargs)
#   File "$TF/python/data/ops/dataset_ops.py", line 4195, in wrapped_fn
#     ret = wrapper_helper(*args)
#   File "$TF/python/data/ops/dataset_ops.py", line 4125, in wrapper_helper
#     ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
#   File "$TF/python/autograph/impl/api.py", line 695, in wrapper
#     raise e.ag_error_metadata.to_exception(e)
# AttributeError: in user code:
#     main_SemanticKITTI.py:145 tf_map  *
#         up_i = tf.py_func(DP.knn_search, [sub_points, batch_pc, 1], tf.int32)
#     AttributeError: module 'tensorflow' has no attribute 'py_func'
#
# ($TF = /root/miniconda3/envs/randlanet/lib/python3.6/site-packages/tensorflow)
#
# "This is the code": the script follows, fixed for TF 2.x.
#
# Fix: tf.py_func, tf.Session, tf.global_variables_initializer,
# tf.data.Iterator.from_structure and Dataset.output_types/output_shapes are
# TF1 graph-mode APIs that the TF2 top-level namespace does not provide in this
# form. Importing tensorflow.compat.v1 as tf and calling
# tf.disable_v2_behavior() keeps the whole graph-mode pipeline working on
# TF 2.x (and on TF >= 1.14, where tf.compat.v1 also exists).
import argparse
import os
import pickle
from os.path import join

import numpy as np
import tensorflow.compat.v1 as tf

from helper_tool import ConfigSemanticKITTI as cfg
from helper_tool import DataProcessing as DP
from helper_tool import Plot
from RandLANet import Network
from tester_SemanticKITTI import ModelTester

# Must run before any graph is built.
# NOTE(review): RandLANet.py, tester_SemanticKITTI.py and helper_tool.py are not
# visible here; if they also use TF1-only names they presumably need the same
# `import tensorflow.compat.v1 as tf` change.
tf.disable_v2_behavior()


class SemanticKITTI:
    """SemanticKITTI dataset wrapper that builds the TF1 graph-mode input pipeline."""

    def __init__(self, test_id):
        """Collect the train/val/test scan lists for the given test sequence id."""
        self.name = 'SemanticKITTI'
        self.dataset_path = '/root/autodl-tmp/RandLA-Net-master/data/semantic_kitti/dataset/sequences_0.06'
        self.label_to_names = {0: 'unlabeled',
                               1: 'car',
                               2: 'bicycle',
                               3: 'motorcycle',
                               4: 'truck',
                               5: 'other-vehicle',
                               6: 'person',
                               7: 'bicyclist',
                               8: 'motorcyclist',
                               9: 'road',
                               10: 'parking',
                               11: 'sidewalk',
                               12: 'other-ground',
                               13: 'building',
                               14: 'fence',
                               15: 'vegetation',
                               16: 'trunk',
                               17: 'terrain',
                               18: 'pole',
                               19: 'traffic-sign'}
        self.num_classes = len(self.label_to_names)
        self.label_values = np.sort(list(self.label_to_names))
        self.label_to_idx = {l: i for i, l in enumerate(self.label_values)}
        self.ignored_labels = np.sort([0])

        self.val_split = '08'

        self.seq_list = np.sort(os.listdir(self.dataset_path))
        self.test_scan_number = str(test_id)
        self.train_list, self.val_list, self.test_list = DP.get_file_list(self.dataset_path,
                                                                          self.test_scan_number)
        self.train_list = DP.shuffle_list(self.train_list)
        self.val_list = DP.shuffle_list(self.val_list)

        # Per-point sampling scores, filled only for the 'test' split.
        self.possibility = []
        self.min_possibility = []

    # Generate the input data flow
    def get_batch_gen(self, split):
        """Return ``(generator_function, output_types, output_shapes)`` for ``split``.

        ``split`` must be 'training', 'validation' or 'test'. For 'test' this
        also initialises ``self.possibility`` / ``self.min_possibility`` with
        small random scores for every point of every test scan.

        Raises:
            ValueError: if ``split`` is not one of the three names above.
        """
        if split == 'training':
            num_per_epoch = int(len(self.train_list) / cfg.batch_size) * cfg.batch_size
            path_list = self.train_list
        elif split == 'validation':
            num_per_epoch = int(len(self.val_list) / cfg.val_batch_size) * cfg.val_batch_size
            # NOTE(review): divides by cfg.batch_size rather than
            # cfg.val_batch_size; kept as-is -- confirm this is intended.
            cfg.val_steps = int(len(self.val_list) / cfg.batch_size)
            path_list = self.val_list
        elif split == 'test':
            num_per_epoch = int(len(self.test_list) / cfg.val_batch_size) * cfg.val_batch_size * 4
            path_list = self.test_list
            for test_file_name in path_list:
                points = np.load(test_file_name)
                self.possibility += [np.random.rand(points.shape[0]) * 1e-3]
                self.min_possibility += [float(np.min(self.possibility[-1]))]
        else:
            raise ValueError("unknown split: {!r}".format(split))

        def spatially_regular_gen():
            # Generator loop
            for i in range(num_per_epoch):
                if split != 'test':
                    cloud_ind = i
                    pc_path = path_list[cloud_ind]
                    pc, tree, labels = self.get_data(pc_path)
                    # crop a small point cloud around a random point
                    pick_idx = np.random.choice(len(pc), 1)
                    selected_pc, selected_labels, selected_idx = self.crop_pc(pc, labels, tree, pick_idx)
                else:
                    # Visit the scan / point with the lowest score so far.
                    cloud_ind = int(np.argmin(self.min_possibility))
                    pick_idx = np.argmin(self.possibility[cloud_ind])
                    pc_path = path_list[cloud_ind]
                    pc, tree, labels = self.get_data(pc_path)
                    selected_pc, selected_labels, selected_idx = self.crop_pc(pc, labels, tree, pick_idx)

                    # update the possibility of the selected pc: points closer
                    # to the crop centre get a larger increment
                    dists = np.sum(np.square((selected_pc - pc[pick_idx]).astype(np.float32)), axis=1)
                    delta = np.square(1 - dists / np.max(dists))
                    self.possibility[cloud_ind][selected_idx] += delta
                    self.min_possibility[cloud_ind] = np.min(self.possibility[cloud_ind])

                yield (selected_pc.astype(np.float32),
                       selected_labels.astype(np.int32),
                       selected_idx.astype(np.int32),
                       np.array([cloud_ind], dtype=np.int32))

        gen_func = spatially_regular_gen
        gen_types = (tf.float32, tf.int32, tf.int32, tf.int32)
        gen_shapes = ([None, 3], [None], [None], [None])

        return gen_func, gen_types, gen_shapes

    def get_data(self, file_path):
        """Load the pickled search tree, its points and the labels of one scan.

        Scans of sequences >= 11 have no label files and get all-zero labels.
        """
        seq_id = file_path.split('/')[-3]
        frame_id = file_path.split('/')[-1][:-4]
        kd_tree_path = join(self.dataset_path, seq_id, 'KDTree', frame_id + '.pkl')
        # Read pkl with search tree (local dataset file; pickle must not be
        # used on untrusted input)
        with open(kd_tree_path, 'rb') as f:
            search_tree = pickle.load(f)
        points = np.array(search_tree.data, copy=False)
        # Load labels
        if int(seq_id) >= 11:
            labels = np.zeros(np.shape(points)[0], dtype=np.uint8)
        else:
            label_path = join(self.dataset_path, seq_id, 'labels', frame_id + '.npy')
            labels = np.squeeze(np.load(label_path))
        return points, search_tree, labels

    @staticmethod
    def crop_pc(points, labels, search_tree, pick_idx):
        """Crop the ``cfg.num_points`` nearest neighbours of ``points[pick_idx]``.

        Returns the shuffled cropped points, their labels and their indices.
        """
        center_point = points[pick_idx, :].reshape(1, -1)
        select_idx = search_tree.query(center_point, k=cfg.num_points)[1][0]
        select_idx = DP.shuffle_idx(select_idx)
        select_points = points[select_idx]
        select_labels = labels[select_idx]
        return select_points, select_labels, select_idx

    @staticmethod
    def get_tf_mapping2():
        """Return the ``Dataset.map`` function that precomputes per-layer kNN inputs."""

        def tf_map(batch_pc, batch_label, batch_pc_idx, batch_cloud_idx):
            features = batch_pc
            input_points = []
            input_neighbors = []
            input_pools = []
            input_up_samples = []

            for i in range(cfg.num_layers):
                # tf.py_func resolves through tensorflow.compat.v1 (see imports).
                neighbour_idx = tf.py_func(DP.knn_search, [batch_pc, batch_pc, cfg.k_n], tf.int32)
                sub_points = batch_pc[:, :tf.shape(batch_pc)[1] // cfg.sub_sampling_ratio[i], :]
                pool_i = neighbour_idx[:, :tf.shape(batch_pc)[1] // cfg.sub_sampling_ratio[i], :]
                up_i = tf.py_func(DP.knn_search, [sub_points, batch_pc, 1], tf.int32)
                input_points.append(batch_pc)
                input_neighbors.append(neighbour_idx)
                input_pools.append(pool_i)
                input_up_samples.append(up_i)
                batch_pc = sub_points

            # Layout: num_layers entries each of points, neighbours, pools and
            # up-samples, then features, labels, point indices, cloud index.
            input_list = input_points + input_neighbors + input_pools + input_up_samples
            input_list += [features, batch_label, batch_pc_idx, batch_cloud_idx]

            return input_list

        return tf_map

    def init_input_pipeline(self):
        """Build train/val/test pipelines that share one re-initialisable iterator.

        Sets ``self.flat_inputs`` (the iterator's ``get_next()`` tensors) and
        ``self.train_init_op`` / ``self.val_init_op`` / ``self.test_init_op``,
        which point the iterator at the corresponding dataset.
        """
        print('Initiating input pipelines')
        cfg.ignored_label_inds = [self.label_to_idx[ign_label] for ign_label in self.ignored_labels]
        gen_function, gen_types, gen_shapes = self.get_batch_gen('training')
        gen_function_val, _, _ = self.get_batch_gen('validation')
        gen_function_test, _, _ = self.get_batch_gen('test')

        self.train_data = tf.data.Dataset.from_generator(gen_function, gen_types, gen_shapes)
        self.val_data = tf.data.Dataset.from_generator(gen_function_val, gen_types, gen_shapes)
        self.test_data = tf.data.Dataset.from_generator(gen_function_test, gen_types, gen_shapes)

        self.batch_train_data = self.train_data.batch(cfg.batch_size)
        self.batch_val_data = self.val_data.batch(cfg.val_batch_size)
        self.batch_test_data = self.test_data.batch(cfg.val_batch_size)

        map_func = self.get_tf_mapping2()

        self.batch_train_data = self.batch_train_data.map(map_func=map_func)
        self.batch_val_data = self.batch_val_data.map(map_func=map_func)
        self.batch_test_data = self.batch_test_data.map(map_func=map_func)

        self.batch_train_data = self.batch_train_data.prefetch(cfg.batch_size)
        self.batch_val_data = self.batch_val_data.prefetch(cfg.val_batch_size)
        self.batch_test_data = self.batch_test_data.prefetch(cfg.val_batch_size)

        # TF1 re-initialisable iterator (available through tensorflow.compat.v1).
        iterator = tf.data.Iterator.from_structure(self.batch_train_data.output_types,
                                                   self.batch_train_data.output_shapes)
        self.flat_inputs = iterator.get_next()
        self.train_init_op = iterator.make_initializer(self.batch_train_data)
        self.val_init_op = iterator.make_initializer(self.batch_val_data)
        self.test_init_op = iterator.make_initializer(self.batch_test_data)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='index of the GPU to use [default: 0]')
    parser.add_argument('--mode', type=str, default='train', help='options: train, test, vis')
    parser.add_argument('--test_area', type=str, default='14',
                        help='options: 08, 11,12,13,14,15,16,17,18,19,20,21')
    parser.add_argument('--model_path', type=str, default='None', help='pretrained model path')
    FLAGS = parser.parse_args()

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str(FLAGS.gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    Mode = FLAGS.mode

    test_area = FLAGS.test_area
    dataset = SemanticKITTI(test_area)
    dataset.init_input_pipeline()

    if Mode == 'train':
        model = Network(dataset, cfg)
        model.train(dataset)
    elif Mode == 'test':
        cfg.saving = False
        model = Network(dataset, cfg)
        # --model_path defaults to the *string* 'None'; compare by value. The
        # original identity test (`is not 'None'`) relied on CPython reusing the
        # same string object and misbehaved for an explicit `--model_path None`.
        if FLAGS.model_path != 'None':
            chosen_snap = FLAGS.model_path
        else:
            # Use the latest snapshot of the most recent results/Log* run.
            logs = np.sort([os.path.join('results', f) for f in os.listdir('results') if f.startswith('Log')])
            chosen_folder = logs[-1]
            snap_path = join(chosen_folder, 'snapshots')
            snap_steps = [int(f[:-5].split('-')[-1]) for f in os.listdir(snap_path) if f[-5:] == '.meta']
            chosen_step = np.sort(snap_steps)[-1]
            chosen_snap = os.path.join(snap_path, 'snap-{:d}'.format(chosen_step))
        tester = ModelTester(model, dataset, restore_snap=chosen_snap)
        tester.test(model, dataset)
    else:
        ##################
        # Visualize data #
        ##################

        # batch_label sits right after the 4 * num_layers per-layer tensors and
        # the features in tf_map's output list.
        label_index = 4 * cfg.num_layers + 1
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(dataset.train_init_op)
            while True:
                flat_inputs = sess.run(dataset.flat_inputs)
                pc_xyz = flat_inputs[0]
                sub_pc_xyz = flat_inputs[1]
                labels = flat_inputs[label_index]
                Plot.draw_pc_sem_ins(pc_xyz[0, :, :], labels[0, :])
                Plot.draw_pc_sem_ins(sub_pc_xyz[0, :, :], labels[0, 0:np.shape(sub_pc_xyz)[1]])
最新发布
07-30
# Request (translated): convert the following code so that it is compatible
# with TensorFlow 2.3.0.
#
# Changes made for TF 2.3:
# * tf.math.bessel_j0 does not exist, and TF 2.3 provides no J0 op (only the
#   modified Bessel functions such as tf.math.bessel_i0/i1 and their scaled
#   variants), so J0 is evaluated by the helper _bessel_j0 using basic ops.
# * coeffs had shape (num_basis,) while knots has num_basis + degree + 1
#   entries, so coeffs * exp(-(x - knots)**2) could not broadcast; coeffs now
#   has one entry per knot.
# * The per-row tf.map_fn is replaced by an element-wise vectorised computation
#   that keeps the input shape, so KANBlock's output lines up with its
#   (num_features,) bias. The unused tf.linspace grid was removed.
import math

import tensorflow as tf
from tensorflow.keras.layers import Layer


def _bessel_j0(x, num_points=32):
    """Approximate the Bessel function J0 element-wise on ``x``.

    Evaluates Bessel's integral J0(x) = (1/pi) * integral_0^pi cos(x*sin(t)) dt
    with the midpoint rule on ``num_points`` nodes. The integrand is smooth and
    pi-periodic in t, so the rule converges very quickly for moderate |x|.
    The result has the same shape and dtype as ``x`` (a floating tensor).
    """
    x = tf.convert_to_tensor(x)
    theta = (tf.range(num_points, dtype=x.dtype) + 0.5) * (math.pi / num_points)
    return tf.reduce_mean(tf.cos(tf.expand_dims(x, -1) * tf.sin(theta)), axis=-1)


class BSplineLayer(Layer):
    """Custom learnable basis-function layer (B-spline-inspired).

    Every input element x is mapped to
    J0(sum_k coeffs[k] * exp(-(x - knots[k]) ** 2)),
    i.e. a weighted sum of Gaussian bumps centred on learnable knots, passed
    through the Bessel function J0. The mapping is element-wise, so the output
    has the same shape as the input.

    Args:
        num_basis: number of basis functions; the layer keeps
            ``num_basis + degree + 1`` knots, as a B-spline of that degree would.
        degree: spline degree; only affects the number of knots.
    """

    def __init__(self, num_basis, degree=3, **kwargs):
        super(BSplineLayer, self).__init__(**kwargs)
        self.num_basis = num_basis
        self.degree = degree

    def build(self, input_shape):
        # Initialise the knots and one coefficient per knot.
        num_knots = self.num_basis + self.degree + 1
        self.knots = self.add_weight(
            name='knots',
            shape=(num_knots,),
            initializer='glorot_uniform',
            trainable=True)
        self.coeffs = self.add_weight(
            name='coeffs',
            shape=(num_knots,),
            initializer='glorot_uniform',
            trainable=True)
        super(BSplineLayer, self).build(input_shape)

    def call(self, inputs):
        # (..., F, 1) - (K,) -> (..., F, K): offset of every element to every knot.
        diffs = tf.expand_dims(inputs, -1) - self.knots
        # Weighted sum of Gaussian bumps -> (..., F).
        weighted = tf.reduce_sum(self.coeffs * tf.exp(-tf.square(diffs)), axis=-1)
        # Bessel J0 as the approximating nonlinearity.
        return _bessel_j0(weighted)

    def get_config(self):
        config = super(BSplineLayer, self).get_config()
        config.update({'num_basis': self.num_basis, 'degree': self.degree})
        return config


class KANBlock(Layer):
    """KAN block: a shared BSplineLayer applied to every feature, plus a bias.

    Args:
        num_basis: forwarded to the inner BSplineLayer.
    """

    def __init__(self, num_basis, **kwargs):
        super(KANBlock, self).__init__(**kwargs)
        self.num_basis = num_basis
        self.bspline_layer = BSplineLayer(num_basis=num_basis)

    def build(self, input_shape):
        self.bias = self.add_weight(
            name='bias',
            shape=(input_shape[-1],),
            initializer='zeros',
            trainable=True)
        super(KANBlock, self).build(input_shape)

    def call(self, inputs):
        # The spline layer is shared and element-wise, so applying it to the
        # whole tensor equals processing each feature separately and
        # concatenating the results along the last axis.
        return self.bspline_layer(inputs) + self.bias

    def get_config(self):
        config = super(KANBlock, self).get_config()
        config.update({'num_basis': self.num_basis})
        return config
04-01
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值