Three ways to consume data from a Dataset object: for / next / reduce

This article shows how to traverse and consume the data in a Dataset object using a for loop, the next function, and the reduce function. Examples demonstrate outputting elements one by one with for, creating a Python iterator for use with next, and consuming all elements at once with reduce. It also covers caveats when using next, including whether the data has been fully consumed and whether the dataset can still be read later in the code.


tensorflow ==2.2.0

Using for to output elements one by one

This works because a Dataset object is a Python iterable.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8,3,0,8,2,1])
print(dataset)
for element in dataset:
    print(element)
    print(element.numpy())

Output

<TensorSliceDataset shapes: (), types: tf.int32>
tf.Tensor(8, shape=(), dtype=int32)
8
tf.Tensor(3, shape=(), dtype=int32)
3
tf.Tensor(0, shape=(), dtype=int32)
0
tf.Tensor(8, shape=(), dtype=int32)
8
tf.Tensor(2, shape=(), dtype=int32)
2
tf.Tensor(1, shape=(), dtype=int32)
1

As the output shows, element.numpy() prints just the underlying number.
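A Dataset can also be iterated more than once: each new for loop gets a fresh iterator and starts again from the first element. A minimal sketch, reusing the same dataset as above:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

# First pass: collect every element as a plain number.
first_pass = [element.numpy() for element in dataset]

# Second pass: a new for loop starts from the beginning again.
second_pass = [element.numpy() for element in dataset]

print(first_pass)   # [8, 3, 0, 8, 2, 1]
print(second_pass)  # [8, 3, 0, 8, 2, 1]
```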

Creating a Python iterator and consuming elements one by one with next (pay close attention to its usage)

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8,3,0,8,2,1])
it  = iter(dataset)
print(next(it))
print(next(it).numpy())

Output

tf.Tensor(8, shape=(), dtype=int32)
3

As seen here, next yields the output one element at a time.

Using reduce to consume all the elements at once

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([8,3,0,8,2,1])
print(dataset.reduce(0,lambda state, value: state+value).numpy())
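reduce can compute any running aggregate, not just a sum. A minimal sketch on the same dataset, summing the elements and also counting them by ignoring each value:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

# Sum all elements: start from state 0, add each value.
total = dataset.reduce(0, lambda state, value: state + value).numpy()

# Count elements: start from state 0, add 1 per element, ignore the value.
count = dataset.reduce(0, lambda state, _: state + 1).numpy()

print(total)  # 22
print(count)  # 6
```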

Note

Even after a for loop or reduce has consumed all the elements, the dataset can still be read again later in the code.
For next, however, it depends on the situation:
Case 1: if next has not consumed all of the data, you can still use for or reduce on the dataset afterwards.
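Case 1 can be checked directly: after consuming two elements with next, a for loop over the same Dataset still starts fresh from the first element, while the original iterator keeps its position. A minimal sketch:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])

it = iter(dataset)
print(next(it).numpy())  # 8
print(next(it).numpy())  # 3

# A new for loop builds its own iterator, so it starts from the beginning.
full_pass = [element.numpy() for element in dataset]
print(full_pass)  # [8, 3, 0, 8, 2, 1]

# The original iterator keeps its position and continues with the third element.
print(next(it).numpy())  # 0
```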

Case 2:

import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([8,3,0,8,2,1])
it  = iter(dataset)
print(next(it).numpy()) #8
print(next(it).numpy()) #3
print(next(it).numpy()) #0
print(next(it).numpy()) #8
print(next(it).numpy()) #2
print(next(it).numpy()) #1
# print(next(it).numpy())  # calling next more times than there are elements raises an error
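The error raised by the extra next call is Python's standard StopIteration, so it can be caught like any other exception. A minimal sketch that drains the iterator safely:

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
it = iter(dataset)

# Drain the iterator: stop cleanly when StopIteration is raised.
values = []
while True:
    try:
        values.append(next(it).numpy())
    except StopIteration:
        # All six elements have been consumed; any further next would fail.
        break

print(values)  # [8, 3, 0, 8, 2, 1]
```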