有关python的iter,next,yield,和pytorch的dataloader

Python遍历与生成器:iter, next, yield与PyTorch DataLoader详解
本文探讨了Python中iter、next和yield在遍历数组、GeneratorFunction以及PyTorch DataLoader中的作用。重点讲解了它们之间的区别,以及PyTorch如何利用迭代器机制加载数据。
部署运行你感兴趣的模型镜像
问题

首先,对于一个python数组,我们可以用for,或者next来遍历。

其次,而对于pytorch的 torch.utils.data.DataLoader, 也可以使用类似 next(iter(DataLoader)) 的方式遍历地读取dataset的数据。

再次,使用yield关键字,也可以起到“遍历”的效果。

那么问题来了。这iter,next,yield几个东西之间有什么关联,又有什么区别呢?pytorch的DataLoader又是用的什么样的方式呢?

分析

先看一个简单例子:

fruit = ["apple", "banana", "cherry"]
# 1. 使用for循环
for f in fruit:
    print(f)
# 2. 使用iter创建迭代器,通过next遍历
fr = iter(fruit)
print(next(fr))
print(next(fr))
print(next(fr))

"""
上面两种方式,都会打印以下结果:
apple
banana
cherry
"""

不要被同样的结果迷惑了。python的iter方法将数组转为一个可遍历的对象,对这个对象遍历,和对原来的数组遍历,是不一样的。不信,可以接着再执行下面的代码:

for f in fr:
    print(f)

这次,没有任何输出。

其中的秘密,就在于fr是用iter()创建的对象。对这个对象遍历的时候,每个元素只能遍历到一次(可以理解为,其内部有当前读指针,每读一次,指针向前移动)。就是这么简单的过程。

再看yield,实际上干了差不多的事情:

def fruit_iter (fruits):
    count = len(fruits)
    for i in range(0, count):
        yield fruits[i]

it = fruit_iter(fruit)
for f in it:
  print(f)
"""
打印结果:
apple
banana
cherry
"""

上面fruit_iter被称为Generator Function,返回一个Generator对象。可以打印一下:

print(it)
"""
<generator object fruit_iter at 0x00000240C967C990>
"""

另外,虽然从上面的例子看出,next()和for对于Generator Function都能起到遍历的作用,但next()还有额外的解释(链接):

Note: When you use next(), Python calls ._next_() on the function
you pass in as a parameter.

也就是说,next()函数实际上调用了传入函数的.__next()__成员函数。所以,如果传入的函数没有这个成员,则会报错。

到这个地方,相信对iter()函数、next()函数,以及yield有了一定的认识。但还有一个问题:pytorch的torch.utils.data.DataLoader 是用哪种方式实现的?用一个简单的例子验证一下:

import torch
# 生成一些测试数据
X = torch.normal(0, 1, (1000, 2)) # x: sample size = 1000, feature_dim = 2
y = torch.normal(0, 1, (1000, 1)) # y: sample size = 1000, dim = 1

# 定义一个函数,返回dataloader
def load_array(data_and_label, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = data.TensorDataset(*data_and_label)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((X, y), batch_size)

#这句会报错:next(data_iter)
next(iter(data_iter))
"""
输出前10组数据
"""

这里,为什么 next(data_iter) 报错,而 next(iter(data_iter)) 可以返回数据呢?这是因为,pytorch的DataLoader函数没有 _next_ 成员,但有 _iter_ 成员(见源文件)。所以,需要首先通过 iter() 函数返回一个 _iter_ 成员,再找这个 _iter_ 的 _next_

到这里,我们对本文开始提到的问题(iter, next, yield 和 pytorch dataloader)应该有一个比较深入的了解了。

参考

https://www.guru99.com/python-yield-return-generator.html
https://realpython.com/introduction-to-python-generators/
https://anandology.com/python-practice-book/iterators.html
https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值