在线性回归的从零开始实现的代码中,由于本人python底子偏弱,一边学习写代码一边理解代码。对于加载数据这一块的indices的使用有了一点了解,因此对于这块内容做一个小笔记。
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
# 这些样本是随机读取的,没有特定的顺序
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = torch.tensor(
indices[i: min(i + batch_size, num_examples)])
yield features[batch_indices], labels[batch_indices]
在我们所谓的打乱顺序中,其实打乱的是indices,也就是下标,对于数组本身没有任何关系,indices存的就是下标。本身features数组与label数组是一一对应的,通过数组下标进行连接对应。下面的测试用例很好的表现了这类关系,
import torch
import random
# 示例数据
features = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
labels = torch.tensor([0, 1, 2, 3, 4])
# 获取样本数量
num_examples = len(features)
# 创建索引列表
indices = list(range(num_examples))#indices就是存了一个下标,然后打乱下标,对应的元素的那个数组没有变
# 打乱索引列表
random.shuffle(indices)
print(indices)
print(indices[0])
# 使用索引访问元素
batch_size = 2
for i in range(0, num_examples, batch_size):
batch_indices = indices[i:i + batch_size]
batch_features = features[batch_indices]
batch_labels = labels[batch_indices]
print(f"Batch {i//batch_size + 1}: Features = {batch_features}, Labels = {batch_labels}")
输出结果如下。