如何理解shuffle
>>> indices=list(range(1000))
>>> random.shuffle(indices)
>>> print(indices)
[771, 505, 356, 548, 171, 304, 776, 143, 496, 52, 203, 811, 825, 102, 380, 560, 582, 355, 149, 574, 547, 704, 863, 425, 679, 385, 810, 190, 163, 859, 116, 29, 838, 67, 327, 75, 852, 157, 572, 223, 252, 814, 745, 517, 216, 909, 646, 347, 396, 509, 816, 152, 466, 100, 104, 900, 145, 696, 691, 315, 87, 237, 862, 737, 139, 82, 282, 168, 264, 552, 285, 873, 174, 891, 597, 378, 500, 131, 232, 842, 429, 418, 940, 193, 908, 345, 501, 487, 311, 518, 871, 178, 497, 271, 79, 549, 83, 342, 159, 747, 657, 671, 621, 363, 14, 883, 733, 440, 575, 279, 322, 239, 70, 981, 449, 204, 598, 244, 236, 120, 664, 59, 39, 133, 581, 532, 750, 445, 407, 328, 693, 32, 638, 866, 989, 326, 488, 158, 630, 478, 173, 931, 930, 42, 510, 371, 932, 895, 426, 861, 490, 793, 914, 73, 441, 868, 724, 694, 624, 551, 928, 23, 242, 527, 554, 901, 667, 179, 398, 783, 350, 844, 373, 462, 538, 938, 430, 22, 637, 307, 432, 963, 864, 199, 792, 364, 633, 886, 249, 990, 553, 555, 55, 77, 966, 666, 755, 305, 499, 984, 276, 740, 823, 588, 830, 103, 40, 113, 112, 725, 292, 748, 850, 534, 69, 484, 787, 763, 849, 557, 176, 583, 24, 815, 735, 105, 781, 374, 720, 62, 46, 952, 424, 848, 758, 786, 166, 253, 789, 872, 998, 150, 336, 35, 89, 471, 156, 743, 642, 26, 391, 246, 2, 584, 592, 529, 448, 700, 698, 535, 221, 947, 922, 38, 675, 49, 602, 710, 899, 718, 955, 162, 258, 413, 109, 757, 834, 147, 949, 524, 101, 57, 468, 267, 702, 225, 847, 135, 839, 613, 885, 813, 389, 84, 390, 806, 33, 738, 804, 799, 148, 138, 948, 915, 753, 719, 580, 522, 617, 980, 729, 699, 631, 226, 607, 970, 247, 36, 722, 302, 65, 332, 414, 526, 406, 180, 284, 979, 346, 770, 381, 600, 126, 809, 695, 917, 184, 353, 921, 372, 820, 794, 968, 896, 125, 450, 762, 469, 837, 681, 201, 85, 686, 741, 259, 661, 405, 775, 309, 463, 333, 827, 654, 394, 514, 576, 78, 983, 198, 939, 281, 408, 122, 11, 994, 58, 603, 995, 589, 687, 4, 412, 622, 296, 280, 383, 384, 516, 297, 379, 217, 28, 780, 459, 61, 808, 701, 88, 656, 520, 924, 401, 712, 563, 128, 566, 99, 428, 233, 916, 623, 92, 716, 146, 167, 419, 402, 286, 175, 668, 170, 185, 480, 420, 926, 492, 248, 996, 703, 976, 231, 508, 240, 903, 912, 790, 831, 472, 72, 354, 648, 546, 660, 136, 263, 937, 439, 835, 81, 182, 640, 382, 235, 210, 561, 93, 894, 521, 470, 41, 477, 802, 670, 443, 475, 387, 558, 329, 15, 310, 672, 541, 234, 933, 482, 66, 953, 870, 754, 485, 812, 80, 251, 8, 256, 189, 643, 228, 511, 388, 260, 303, 0, 6, 528, 997, 858, 376, 570, 777, 421, 422, 214, 172, 639, 556, 294, 609, 634, 774, 118, 913, 578, 306, 562, 321, 64, 635, 367, 713, 164, 206, 301, 669, 715, 512, 533, 818, 274, 63, 682, 19, 262, 893, 192, 21, 751, 312, 653, 507, 320, 962, 571, 74, 114, 652, 68, 544, 975, 676, 760, 692, 202, 483, 918, 465, 283, 129, 153, 987, 288, 734, 929, 822, 370, 130, 711, 56, 519, 934, 215, 295, 142, 330, 523, 265, 208, 887, 678, 565, 956, 207, 456, 452, 957, 359, 882, 106, 273, 659, 417, 362, 791, 982, 620, 177, 335, 313, 649, 386, 796, 605, 132, 502, 833, 464, 316, 766, 227, 498, 784, 395, 925, 397, 550, 444, 368, 991, 559, 936, 627, 94, 115, 427, 629, 458, 491, 689, 415, 404, 111, 493, 855, 48, 779, 973, 819, 257, 361, 334, 31, 594, 299, 795, 238, 119, 845, 587, 438, 393, 732, 945, 805, 999, 888, 960, 442, 782, 644, 344, 44, 360, 400, 619, 54, 278, 542, 615, 10, 596, 611, 788, 898, 636, 919, 662, 768, 460, 434, 7, 797, 944, 843, 461, 117, 121, 878, 677, 616, 530, 457, 433, 943, 399, 626, 683, 645, 706, 269, 890, 650, 884, 736, 800, 200, 772, 601, 911, 803, 726, 155, 151, 53, 951, 765, 291, 586, 218, 317, 194, 25, 314, 826, 832, 205, 12, 824, 340, 60, 964, 531, 503, 892, 86, 545, 245, 828, 349, 988, 220, 942, 673, 971, 684, 778, 161, 536, 451, 13, 709, 222, 568, 435, 785, 857, 197, 739, 881, 290, 165, 525, 277, 366, 641, 188, 604, 325, 224, 365, 230, 318, 323, 920, 181, 183, 759, 160, 377, 195, 941, 409, 961, 96, 319, 17, 515, 897, 714, 98, 907, 416, 658, 410, 107, 965, 935, 494, 985, 599, 690, 977, 272, 474, 331, 123, 761, 127, 606, 34, 807, 853, 593, 756, 585, 51, 403, 992, 298, 742, 923, 467, 169, 392, 974, 506, 697, 685, 612, 50, 358, 270, 767, 746, 727, 9, 877, 978, 241, 889, 902, 486, 18, 569, 144, 191, 579, 369, 266, 798, 705, 1, 841, 730, 731, 447, 187, 229, 293, 655, 874, 876, 446, 308, 154, 860, 453, 423, 959, 674, 573, 543, 665, 20, 90, 647, 352, 829, 769, 817, 47, 958, 614, 540, 137, 16, 954, 590, 91, 348, 721, 904, 255, 43, 76, 851, 504, 846, 454, 591, 880, 969, 539, 339, 495, 625, 324, 455, 5, 618, 910, 821, 338, 481, 479, 375, 108, 707, 773, 967, 663, 728, 879, 865, 906, 749, 723, 275, 45, 986, 905, 869, 196, 212, 854, 268, 97, 211, 351, 628, 337, 186, 95, 875, 357, 134, 993, 651, 946, 610, 436, 3, 476, 124, 71, 37, 261, 140, 717, 27, 972, 577, 209, 840, 300, 608, 867, 688, 708, 836, 473, 437, 411, 564, 343, 254, 801, 287, 489, 764, 567, 856, 30, 431, 752, 250, 341, 289, 537, 680, 927, 219, 513, 243, 744, 632, 213, 141, 950, 110, 595]
>>>
机器学习中数据打乱(Shuffling) 的典型用法
这种数据打乱技术是机器学习中数据预处理的标准步骤,特别是在批量训练和交叉验证中非常重要。
在PyTorch中,yield和return的区别主要体现在数据流控制和内存管理方面:
return
-
立即返回:执行到return语句时,函数立即结束并返回结果
-
一次性输出:返回完整的结果集
-
内存占用高:需要一次性保存所有结果
def get_batches_return(data, batch_size):
batches = []
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
batches.append(batch)
return batches # 一次性返回所有批次
# 使用
data = torch.randn(1000, 10)
all_batches = get_batches_return(data, 32) # 内存中保存所有批次
yield
-
生成器:创建迭代器,每次产生一个值后暂停
-
惰性计算:按需生成数据,不立即计算所有结果
-
内存友好:只保持当前批次在内存中
def get_batches_yield(data, batch_size):
for i in range(0, len(data), batch_size):
batch = data[i:i+batch_size]
yield batch # 每次产生一个批次
# 使用
data = torch.randn(1000, 10)
batch_generator = get_batches_yield(data, 32)
for batch in batch_generator:
# 训练模型
loss = model(batch)
loss.backward()
optimizer.step()
在PyTorch中的典型应用场景
yield在数据加载中的优势
class DataLoader:
def __iter__(self):
for i in range(0, len(self.dataset), self.batch_size):
indices = self.indices[i:i+self.batch_size]
batch = self.collate_fn([self.dataset[j] for j in indices])
yield batch # 逐批产生数据
# PyTorch DataLoader内部类似实现
自定义训练循环
def train_epoch(model, dataloader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(dataloader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 每100个batch返回一次状态
if batch_idx % 100 == 0:
yield batch_idx, loss.item() # 可以中途返回训练状态
# 使用
for batch_idx, current_loss in train_epoch(model, dataloader, optimizer):
print(f"Batch {batch_idx}, Loss: {current_loss:.4f}")
主要区别总结
| 特性 | return | yield |
|---|---|---|
| 执行方式 | 立即返回,函数结束 | 暂停执行,保持状态 |
| 内存使用 | 高(保存所有结果) | 低(只保存当前状态) |
| 适用场景 | 小数据集,需要完整结果 | 大数据集,流式处理 |
| 返回类型 | 具体数据 | 生成器对象 |
在PyTorch中,yield特别适合处理大规模数据集和实现自定义的训练流程控制。
若把书中yield改为return,则会报错,原因在于:
-
yield与return的区别:-
yield:生成器函数,可以多次产生值 -
return:普通函数,只返回一次值
-
-
你的代码逻辑:
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = indices[i:min(i + batch_size, num_examples)]
return features[batch_indices], labels[batch_indices] # 这里有问题!
问题:在第一次循环时就执行了 return,函数立即结束,只返回一个批次的数据。
正确的解决方案
方案1:使用 yield(推荐)
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
batch_indices = indices[i:min(i + batch_size, num_examples)]
yield features[batch_indices], labels[batch_indices] # 使用 yield
# 使用方式
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
print(X, '\n', y)
# 这里只要注释掉break就会遍历所有批次
# break
方案2:使用 return 但返回所有批次
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
batches = []
for i in range(0, num_examples, batch_size):
batch_indices = indices[i:min(i + batch_size, num_examples)]
batches.append((features[batch_indices], labels[batch_indices]))
return batches # 返回所有批次的列表
# 使用方式
batch_size = 10
batches = data_iter(batch_size, features, labels)
for X, y in batches:
print(X, '\n', y)
关于上面的append
假设我们有一个数据集,特征features和标签labels,假设有1000个样本,每个特征是一个向量,标签是一个标量。
我们想要将数据分成多个批次,每个批次包含batch_size个样本。
步骤:
-
我们首先生成一个索引列表[0,1,2,...,999],然后打乱这个列表。
-
我们按照打乱后的顺序,每次取batch_size个索引(最后一个批次可能不足batch_size)。
-
对于每个批次,我们根据这些索引从features和labels中取出对应的数据。
现在,我们来看这行代码:
batches.append((features[batch_indices], labels[batch_indices]))
这里,batch_indices是一个整数列表,表示当前批次的索引。
例如,假设batch_size=10,第一个批次的batch_indices可能是[5, 12, 36, ..., 88](打乱后的10个索引)。
那么:
features[batch_indices] 会从features中取出这些索引对应的特征,形成一个子数组。
labels[batch_indices] 会从labels中取出这些索引对应的标签,形成一个子数组。
然后,我们将这两个子数组组成一个元组 (特征子数组, 标签子数组),并将这个元组添加到batches列表中。
因此,batches列表最终会包含多个这样的元组,每个元组代表一个批次的特征和标签。
例如,如果有1000个样本,batch_size=10,那么batches将是一个长度为100的列表,每个元素是一个元组,元组的第一个元素是10个特征,第二个元素是10个标签。
这样,当我们遍历batches时,每次取出一个元组,就可以用两个变量(比如X和y)来接收这个元组中的两个元素。
注意:这里用两个括号是因为外面的括号是append的参数,里面的括号是元组的构造。
即:append( (元素1, 元素2) ),这样就将一个元组添加到了列表中。
希望这样解释清楚了。
batches.append((features[batch_indices], labels[batch_indices]))
这行代码可以分为几个部分来理解:
1. features[batch_indices] 和 labels[batch_indices]
-
batch_indices是一个索引列表,比如[3, 7, 2, 9, 66, 818, 938, 33, 69, 2] -
features[batch_indices]会从features数组中提取索引为3, 7, 2, 9, 66, 818, 938, 33, 69, 2的数据 -
labels[batch_indices]会从labels数组中提取相同索引的数据
2. (features[batch_indices], labels[batch_indices])
-
这里创建了一个元组(tuple)
-
元组包含两个元素:特征批次和标签批次
-
例如:
(batch_features, batch_labels)
3. batches.append(...)
-
将这个元组添加到
batches列表中
具体例子
假设我们有:
features = [f1, f2, f3, f4, f5, f6] # 6个样本的特征
labels = [l1, l2, l3, l4, l5, l6] # 6个样本的标签
batch_size = 2
batch_indices = [1, 3, 0] # 当前批次的索引
那么:
features[batch_indices] = [f1, f3, f0] # 提取索引1,3,0对应的特征
labels[batch_indices] = [l1, l3, l0] # 提取相同索引的标签
# 创建元组
batch_tuple = ( [f1, f3, f0], [l1, l3, l0] )
# 添加到批次列表
batches.append(batch_tuple)
最终结果
batches 列表会变成:
batches = [
( [第一批特征], [第一批标签] ),
( [第二批特征], [第二批标签] ),
( [第三批特征], [第三批标签] ),
# ...
]
为什么要这样做?
-
保持对应关系:确保每个批次的特征和标签是匹配的
-
便于访问:后续使用时可以这样:
for batch_features, batch_labels in batches:
# 训练模型
model.train(batch_features, batch_labels)
-
数据组织:将相关的数据组织在一起,便于管理和传递
简单来说,这行代码就是在创建一个包含特征和标签的配对,并保存到列表中,为后续的批量训练做准备。
为什么会出现 "too many values to unpack" 错误
当你使用 return 时:
-
函数返回的是
(features[batch_indices], labels[batch_indices])这个元组 -
在循环中,Python 尝试遍历这个元组
-
元组的第一个元素是
features[batch_indices](一个数组) -
当尝试将数组解包给
X, y两个变量时,就出现了错误
总结
使用 yield 是正确的做法,因为:
-
可以按需生成批次,节省内存
-
支持在循环中逐个处理批次
-
是标准的批量数据处理模式
把 return 改回 yield 就能解决你的问题!
break导致只打印(输出)第一个元素
>>> for number in range(0,6):
... print(number)
... break
...
0
>>> for number in range(0,6):
... print(number)
...
0
1
2
3
4
5
>>>
1773

被折叠的 条评论
为什么被折叠?



