def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices) # 样本的读取顺序是随机的。
for i in range(0, num_examples, batch_size):
j = nd.array(indices[i: min(i + batch_size, num_examples)])
yield features.take(j), labels.take(j) # take 函数根据索引返回对应元素。
使用:
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
print(X, y)
break

本文介绍了一种用于机器学习训练过程中的数据加载方法,通过定义一个数据迭代器函数实现小批量随机读取训练数据,确保每次训练都能获取到不同的数据组合,有助于提高模型的泛化能力。
1848

被折叠的 条评论
为什么被折叠?



