有关python的iter,next,yield,和pytorch的dataloader

本文探讨了Python中iter、next和yield在遍历数组、GeneratorFunction以及PyTorch DataLoader中的作用。重点讲解了它们之间的区别,以及PyTorch如何利用迭代器机制加载数据。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

问题

首先,对于一个python数组,我们可以用for,或者next来遍历。

其次,而对于pytorch的 torch.utils.data.DataLoader, 也可以使用类似 next(iter(DataLoader)) 的方式遍历地读取dataset的数据。

再次,使用yield关键字,也可以起到“遍历”的效果。

那么问题来了。这iter,next,yield几个东西之间有什么关联,又有什么区别呢?pytorch的DataLoader又是用的什么样的方式呢?

分析

先看一个简单例子:

fruit = ["apple", "banana", "cherry"]
# 1. 使用for循环
for f in fruit:
    print(f)
# 2. 使用iter创建迭代器,通过next遍历
fr = iter(fruit)
print(next(fr))
print(next(fr))
print(next(fr))

"""
上面两种方式,都会打印以下结果:
apple
banana
cherry
"""

不要被同样的结果迷惑了。python的iter方法将数组转为一个可遍历的对象,对这个对象遍历,和对原来的数组遍历,是不一样的。不信,可以接着再执行下面的代码:

for f in fr:
    print(f)

这次,没有任何输出。

其中的秘密,就在于fr是用iter()创建的对象。对这个对象遍历的时候,每个元素只能遍历到一次(可以理解为,其内部有当前读指针,每读一次,指针向前移动)。就是这么简单的过程。

再看yield,实际上干了差不多的事情:

def fruit_iter (fruits):
    count = len(fruits)
    for i in range(0, count):
        yield fruits[i]

it = fruit_iter(fruit)
for f in it:
  print(f)
"""
打印结果:
apple
banana
cherry
"""

上面fruit_iter被称为Generator Function,返回一个Generator对象。可以打印一下:

print(it)
"""
<generator object fruit_iter at 0x00000240C967C990>
"""

另外,虽然从上面的例子看出,next()和for对于Generator Function都能起到遍历的作用,但next()还有额外的解释(链接):

Note: When you use next(), Python calls ._next_() on the function
you pass in as a parameter.

也就是说,next()函数实际上调用了传入函数的.__next()__成员函数。所以,如果传入的函数没有这个成员,则会报错。

到这个地方,相信对iter()函数、next()函数,以及yield有了一定的认识。但还有一个问题:pytorch的torch.utils.data.DataLoader 是用哪种方式实现的?用一个简单的例子验证一下:

import torch
# 生成一些测试数据
X = torch.normal(0, 1, (1000, 2)) # x: sample size = 1000, feature_dim = 2
y = torch.normal(0, 1, (1000, 1)) # y: sample size = 1000, dim = 1

# 定义一个函数,返回dataloader
def load_array(data_and_label, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_and_label)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((X, y), batch_size)

#这句会报错:next(data_iter)
next(iter(data_iter))
"""
输出前10组数据
"""

这里,为什么 next(data_iter) 报错,而 next(iter(data_iter)) 可以返回数据呢?这是因为,pytorch的DataLoader函数没有 _next_ 成员,但有 _iter_ 成员(见源文件)。所以,需要首先通过 iter() 函数返回一个 _iter_ 成员,再找这个 _iter_ 的 _next_

到这里,我们对本文开始提到的问题(iter, next, yield 和 pytorch dataloader)应该有一个比较深入的了解了。

参考

https://www.guru99.com/python-yield-return-generator.html
https://realpython.com/introduction-to-python-generators/
https://anandology.com/python-practice-book/iterators.html
https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader

### 使用 `yield` 关键字实现 PyTorch DataLoader 自定义迭代器 在构建自定义的 PyTorch DataLoader 迭代器时,可以利用 Python 的 `yield` 关键字创建生成器函数。这有助于更高效地管理内存并按需提供数据批次。 #### 创建自定义 Dataset 类 为了更好地理解如何集成 `yield` DataLoader,先来看一个简单的自定义 Dataset 类: ```python from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data, labels): self.data = data self.labels = labels def __len__(self): return len(self.data) def __getitem__(self, idx): sample_data = self.data[idx] label = self.labels[idx] return sample_data, label ``` 此部分展示了如何初始化一个继承自 `torch.utils.data.Dataset` 的类,并实现了两个必要的方法:`__len__()` 返回数据集大小;`__getitem__()` 获取指定索引的数据项[^1]。 #### 构建带有 `yield` 的自定义 Sampler 或 BatchSampler 当希望控制批处理逻辑或采样方式时,可以通过扩展 `BatchSampler` 来实现这一点。下面是一个例子,展示如何使用 `yield` 提供批量索引: ```python import random from torch.utils.data.sampler import BatchSampler class CustomBatchSampler(BatchSampler): def __iter__(self): indices = list(range(len(self.sampler))) if self.shuffle: random.shuffle(indices) batch = [] for index in indices: batch.append(index) if len(batch) == self.batch_size: yield batch batch = [] # 如果最后一批次不为空,则也应返回它 if len(batch) > 0 and not self.drop_last: yield batch ``` 这段代码中,`CustomBatchSampler` 继承了 `BatchSampler` 并覆盖了其 `_iter_` 方法,在其中应用了 `yield` 表达式来逐个产出 mini-batches 的索引列表。这种方式允许更加灵活地调整每批次内的样本顺序以及处理最后一个可能不满额的小批次[^2]。 #### 定义自定义 DataLoader 有了上述组件之后,就可以轻松地将它们组合起来形成一个新的 DataLoader 实现: ```python from torch.utils.data import DataLoader def custom_dataloader(dataset, batch_sampler_class=CustomBatchSampler, **kwargs): return DataLoader( dataset, batch_sampler=batch_sampler_class(sampler=None, drop_last=False, **kwargs), num_workers=0 # 设置为0表示主线程执行I/O操作 ) ``` 这里的关键在于传递了一个定制化的 `batch_sampler_class` 参数给 `DataLoader`,从而使得整个流程能够按照预期工作。注意设置合适的参数值(如 `num_workers`)对于性能优化非常重要。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值