读动手学深度学习,P63的代码

部署运行你感兴趣的模型镜像

如何理解shuffle

>>> indices=list(range(1000))
>>> random.shuffle(indices)
>>> print(indices)
[771, 505, 356, 548, 171, 304, 776, 143, 496, 52, 203, 811, 825, 102, 380, 560, 582, 355, 149, 574, 547, 704, 863, 425, 679, 385, 810, 190, 163, 859, 116, 29, 838, 67, 327, 75, 852, 157, 572, 223, 252, 814, 745, 517, 216, 909, 646, 347, 396, 509, 816, 152, 466, 100, 104, 900, 145, 696, 691, 315, 87, 237, 862, 737, 139, 82, 282, 168, 264, 552, 285, 873, 174, 891, 597, 378, 500, 131, 232, 842, 429, 418, 940, 193, 908, 345, 501, 487, 311, 518, 871, 178, 497, 271, 79, 549, 83, 342, 159, 747, 657, 671, 621, 363, 14, 883, 733, 440, 575, 279, 322, 239, 70, 981, 449, 204, 598, 244, 236, 120, 664, 59, 39, 133, 581, 532, 750, 445, 407, 328, 693, 32, 638, 866, 989, 326, 488, 158, 630, 478, 173, 931, 930, 42, 510, 371, 932, 895, 426, 861, 490, 793, 914, 73, 441, 868, 724, 694, 624, 551, 928, 23, 242, 527, 554, 901, 667, 179, 398, 783, 350, 844, 373, 462, 538, 938, 430, 22, 637, 307, 432, 963, 864, 199, 792, 364, 633, 886, 249, 990, 553, 555, 55, 77, 966, 666, 755, 305, 499, 984, 276, 740, 823, 588, 830, 103, 40, 113, 112, 725, 292, 748, 850, 534, 69, 484, 787, 763, 849, 557, 176, 583, 24, 815, 735, 105, 781, 374, 720, 62, 46, 952, 424, 848, 758, 786, 166, 253, 789, 872, 998, 150, 336, 35, 89, 471, 156, 743, 642, 26, 391, 246, 2, 584, 592, 529, 448, 700, 698, 535, 221, 947, 922, 38, 675, 49, 602, 710, 899, 718, 955, 162, 258, 413, 109, 757, 834, 147, 949, 524, 101, 57, 468, 267, 702, 225, 847, 135, 839, 613, 885, 813, 389, 84, 390, 806, 33, 738, 804, 799, 148, 138, 948, 915, 753, 719, 580, 522, 617, 980, 729, 699, 631, 226, 607, 970, 247, 36, 722, 302, 65, 332, 414, 526, 406, 180, 284, 979, 346, 770, 381, 600, 126, 809, 695, 917, 184, 353, 921, 372, 820, 794, 968, 896, 125, 450, 762, 469, 837, 681, 201, 85, 686, 741, 259, 661, 405, 775, 309, 463, 333, 827, 654, 394, 514, 576, 78, 983, 198, 939, 281, 408, 122, 11, 994, 58, 603, 995, 589, 687, 4, 412, 622, 296, 280, 383, 384, 516, 297, 379, 217, 28, 780, 459, 61, 808, 701, 88, 656, 520, 924, 401, 712, 563, 128, 566, 99, 428, 233, 916, 623, 92, 716, 146, 167, 419, 402, 286, 175, 668, 170, 185, 480, 420, 926, 492, 248, 996, 703, 976, 231, 508, 240, 903, 912, 790, 831, 472, 72, 354, 648, 546, 660, 136, 263, 937, 439, 835, 81, 182, 640, 382, 235, 210, 561, 93, 894, 521, 470, 41, 477, 802, 670, 443, 475, 387, 558, 329, 15, 310, 672, 541, 234, 933, 482, 66, 953, 870, 754, 485, 812, 80, 251, 8, 256, 189, 643, 228, 511, 388, 260, 303, 0, 6, 528, 997, 858, 376, 570, 777, 421, 422, 214, 172, 639, 556, 294, 609, 634, 774, 118, 913, 578, 306, 562, 321, 64, 635, 367, 713, 164, 206, 301, 669, 715, 512, 533, 818, 274, 63, 682, 19, 262, 893, 192, 21, 751, 312, 653, 507, 320, 962, 571, 74, 114, 652, 68, 544, 975, 676, 760, 692, 202, 483, 918, 465, 283, 129, 153, 987, 288, 734, 929, 822, 370, 130, 711, 56, 519, 934, 215, 295, 142, 330, 523, 265, 208, 887, 678, 565, 956, 207, 456, 452, 957, 359, 882, 106, 273, 659, 417, 362, 791, 982, 620, 177, 335, 313, 649, 386, 796, 605, 132, 502, 833, 464, 316, 766, 227, 498, 784, 395, 925, 397, 550, 444, 368, 991, 559, 936, 627, 94, 115, 427, 629, 458, 491, 689, 415, 404, 111, 493, 855, 48, 779, 973, 819, 257, 361, 334, 31, 594, 299, 795, 238, 119, 845, 587, 438, 393, 732, 945, 805, 999, 888, 960, 442, 782, 644, 344, 44, 360, 400, 619, 54, 278, 542, 615, 10, 596, 611, 788, 898, 636, 919, 662, 768, 460, 434, 7, 797, 944, 843, 461, 117, 121, 878, 677, 616, 530, 457, 433, 943, 399, 626, 683, 645, 706, 269, 890, 650, 884, 736, 800, 200, 772, 601, 911, 803, 726, 155, 151, 53, 951, 765, 291, 586, 218, 317, 194, 25, 314, 826, 832, 205, 12, 824, 340, 60, 964, 531, 503, 892, 86, 545, 245, 828, 349, 988, 220, 942, 673, 971, 684, 778, 161, 536, 451, 13, 709, 222, 568, 435, 785, 857, 197, 739, 881, 290, 165, 525, 277, 366, 641, 188, 604, 325, 224, 365, 230, 318, 323, 920, 181, 183, 759, 160, 377, 195, 941, 409, 961, 96, 319, 17, 515, 897, 714, 98, 907, 416, 658, 410, 107, 965, 935, 494, 985, 599, 690, 977, 272, 474, 331, 123, 761, 127, 606, 34, 807, 853, 593, 756, 585, 51, 403, 992, 298, 742, 923, 467, 169, 392, 974, 506, 697, 685, 612, 50, 358, 270, 767, 746, 727, 9, 877, 978, 241, 889, 902, 486, 18, 569, 144, 191, 579, 369, 266, 798, 705, 1, 841, 730, 731, 447, 187, 229, 293, 655, 874, 876, 446, 308, 154, 860, 453, 423, 959, 674, 573, 543, 665, 20, 90, 647, 352, 829, 769, 817, 47, 958, 614, 540, 137, 16, 954, 590, 91, 348, 721, 904, 255, 43, 76, 851, 504, 846, 454, 591, 880, 969, 539, 339, 495, 625, 324, 455, 5, 618, 910, 821, 338, 481, 479, 375, 108, 707, 773, 967, 663, 728, 879, 865, 906, 749, 723, 275, 45, 986, 905, 869, 196, 212, 854, 268, 97, 211, 351, 628, 337, 186, 95, 875, 357, 134, 993, 651, 946, 610, 436, 3, 476, 124, 71, 37, 261, 140, 717, 27, 972, 577, 209, 840, 300, 608, 867, 688, 708, 836, 473, 437, 411, 564, 343, 254, 801, 287, 489, 764, 567, 856, 30, 431, 752, 250, 341, 289, 537, 680, 927, 219, 513, 243, 744, 632, 213, 141, 950, 110, 595]
>>> 

机器学习中数据打乱(Shuffling) 的典型用法

这种数据打乱技术是机器学习中数据预处理的标准步骤,特别是在批量训练和交叉验证中非常重要。

在PyTorch中,yieldreturn的区别主要体现在数据流控制和内存管理方面:

return

  • 立即返回:执行到return语句时,函数立即结束并返回结果

  • 一次性输出:返回完整的结果集

  • 内存占用高:需要一次性保存所有结果

def get_batches_return(data, batch_size):
    batches = []
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        batches.append(batch)
    return batches  # 一次性返回所有批次

# 使用
data = torch.randn(1000, 10)
all_batches = get_batches_return(data, 32)  # 内存中保存所有批次

yield

  • 生成器:创建迭代器,每次产生一个值后暂停

  • 惰性计算:按需生成数据,不立即计算所有结果

  • 内存友好:只保持当前批次在内存中

def get_batches_yield(data, batch_size):
    for i in range(0, len(data), batch_size):
        batch = data[i:i+batch_size]
        yield batch  # 每次产生一个批次

# 使用
data = torch.randn(1000, 10)
batch_generator = get_batches_yield(data, 32)

for batch in batch_generator:
    # 训练模型
    loss = model(batch)
    loss.backward()
    optimizer.step()

在PyTorch中的典型应用场景

yield在数据加载中的优势

class DataLoader:
    def __iter__(self):
        for i in range(0, len(self.dataset), self.batch_size):
            indices = self.indices[i:i+self.batch_size]
            batch = self.collate_fn([self.dataset[j] for j in indices])
            yield batch  # 逐批产生数据
    
    # PyTorch DataLoader内部类似实现

自定义训练循环

def train_epoch(model, dataloader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        # 每100个batch返回一次状态
        if batch_idx % 100 == 0:
            yield batch_idx, loss.item()  # 可以中途返回训练状态

# 使用
for batch_idx, current_loss in train_epoch(model, dataloader, optimizer):
    print(f"Batch {batch_idx}, Loss: {current_loss:.4f}")

主要区别总结

特性returnyield
执行方式立即返回,函数结束暂停执行,保持状态
内存使用高(保存所有结果)低(只保存当前状态)
适用场景小数据集,需要完整结果大数据集,流式处理
返回类型具体数据生成器对象

在PyTorch中,yield特别适合处理大规模数据集和实现自定义的训练流程控制。

若把书中yield改为return,则会报错,原因在于:

  1. yield 与 return 的区别

    • yield:生成器函数,可以多次产生值

    • return:普通函数,只返回一次值

  2. 你的代码逻辑

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    
    for i in range(0, num_examples, batch_size):
        batch_indices = indices[i:min(i + batch_size, num_examples)]
        return features[batch_indices], labels[batch_indices]  # 这里有问题!

问题:在第一次循环时就执行了 return,函数立即结束,只返回一个批次的数据。

正确的解决方案

方案1:使用 yield(推荐)

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    
    for i in range(0, num_examples, batch_size):
        batch_indices = indices[i:min(i + batch_size, num_examples)]
        yield features[batch_indices], labels[batch_indices]  # 使用 yield

# 使用方式
batch_size = 10
for X, y in data_iter(batch_size, features, labels):
    print(X, '\n', y)
    # 这里只要注释掉break就会遍历所有批次
    # break

方案2:使用 return 但返回所有批次

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)
    
    batches = []
    for i in range(0, num_examples, batch_size):
        batch_indices = indices[i:min(i + batch_size, num_examples)]
        batches.append((features[batch_indices], labels[batch_indices]))
    
    return batches  # 返回所有批次的列表

# 使用方式
batch_size = 10
batches = data_iter(batch_size, features, labels)
for X, y in batches:
    print(X, '\n', y)

关于上面的append

假设我们有一个数据集,特征features和标签labels,假设有1000个样本,每个特征是一个向量,标签是一个标量。
我们想要将数据分成多个批次,每个批次包含batch_size个样本。

步骤:

  1. 我们首先生成一个索引列表[0,1,2,...,999],然后打乱这个列表。

  2. 我们按照打乱后的顺序,每次取batch_size个索引(最后一个批次可能不足batch_size)。

  3. 对于每个批次,我们根据这些索引从features和labels中取出对应的数据。

现在,我们来看这行代码:
batches.append((features[batch_indices], labels[batch_indices]))

这里,batch_indices是一个整数列表,表示当前批次的索引。
例如,假设batch_size=10,第一个批次的batch_indices可能是[5, 12, 36, ..., 88](打乱后的10个索引)。

那么:
features[batch_indices] 会从features中取出这些索引对应的特征,形成一个子数组。
labels[batch_indices] 会从labels中取出这些索引对应的标签,形成一个子数组。

然后,我们将这两个子数组组成一个元组 (特征子数组, 标签子数组),并将这个元组添加到batches列表中。

因此,batches列表最终会包含多个这样的元组,每个元组代表一个批次的特征和标签。

例如,如果有1000个样本,batch_size=10,那么batches将是一个长度为100的列表,每个元素是一个元组,元组的第一个元素是10个特征,第二个元素是10个标签。

这样,当我们遍历batches时,每次取出一个元组,就可以用两个变量(比如X和y)来接收这个元组中的两个元素。

注意:这里用两个括号是因为外面的括号是append的参数,里面的括号是元组的构造。
即:append( (元素1, 元素2) ),这样就将一个元组添加到了列表中。

希望这样解释清楚了。

batches.append((features[batch_indices], labels[batch_indices]))

这行代码可以分为几个部分来理解:

1. features[batch_indices] 和 labels[batch_indices]

  • batch_indices 是一个索引列表,比如 [3, 7, 2, 9, 66, 818, 938, 33, 69, 2]

  • features[batch_indices] 会从 features 数组中提取索引为 3, 7, 2, 9, 66, 818, 938, 33, 69, 2的数据

  • labels[batch_indices] 会从 labels 数组中提取相同索引的数据

2. (features[batch_indices], labels[batch_indices])

  • 这里创建了一个元组(tuple)

  • 元组包含两个元素:特征批次和标签批次

  • 例如:(batch_features, batch_labels)

3. batches.append(...)

  • 将这个元组添加到 batches 列表中

具体例子

假设我们有:

features = [f1, f2, f3, f4, f5, f6]  # 6个样本的特征
labels = [l1, l2, l3, l4, l5, l6]    # 6个样本的标签
batch_size = 2
batch_indices = [1, 3, 0]  # 当前批次的索引

那么:

features[batch_indices] = [f1, f3, f0]  # 提取索引1,3,0对应的特征
labels[batch_indices] = [l1, l3, l0]    # 提取相同索引的标签

# 创建元组
batch_tuple = ( [f1, f3, f0], [l1, l3, l0] )

# 添加到批次列表
batches.append(batch_tuple)

最终结果

batches 列表会变成:

batches = [
    ( [第一批特征], [第一批标签] ),
    ( [第二批特征], [第二批标签] ),
    ( [第三批特征], [第三批标签] ),
    # ...
]

为什么要这样做?

  1. 保持对应关系:确保每个批次的特征和标签是匹配的

  2. 便于访问:后续使用时可以这样:

for batch_features, batch_labels in batches:
    # 训练模型
    model.train(batch_features, batch_labels)
  1. 数据组织:将相关的数据组织在一起,便于管理和传递

简单来说,这行代码就是在创建一个包含特征和标签的配对,并保存到列表中,为后续的批量训练做准备。

为什么会出现 "too many values to unpack" 错误

当你使用 return 时:

  • 函数返回的是 (features[batch_indices], labels[batch_indices]) 这个元组

  • 在循环中,Python 尝试遍历这个元组

  • 元组的第一个元素是 features[batch_indices](一个数组)

  • 当尝试将数组解包给 X, y 两个变量时,就出现了错误

总结

使用 yield 是正确的做法,因为:

  • 可以按需生成批次,节省内存

  • 支持在循环中逐个处理批次

  • 是标准的批量数据处理模式

把 return 改回 yield 就能解决你的问题!

break导致只打印(输出)第一个元素

>>> for number in range(0,6):
...     print(number)
...     break
... 
0
>>> for number in range(0,6):
...     print(number)
... 
0
1
2
3
4
5
>>> 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

### 动手学深度学习第四章课后题答案解析 #### 关于随机梯度下降(SGD) 对于《动手学深度学习》中的第四章节,涉及到了随机梯度下降算法及其可能遇到的问题。SGD是一种常用的优化方法,在每次迭代过程中仅使用单一样本或者少量样本更新权重参数,这使得训练过程更加高效并能有效处理大规模数据集。然而,这种方法也可能带来一些挑战,比如收敛路径不稳定以及容易陷入局部最优等问题[^1]。 #### 马尔可夫假设与n-gram模型的数据稀疏性 书中提到马尔科夫假设指的是词语之间相互独立;当试图预测下一个单词时,则需考虑前面所有已知词汇组成的序列概率分布情况。采用n-gram模型来建模自然语言时会面临严重的数据稀疏现象,即大多数情况下所估计得到的概率值接近于零,这是因为实际语料库难以覆盖所有的n-gram组合实例[^2]。 #### 练习部分指导建议 针对第四章的具体练习题目解答: - 对于涉及到线性回归的内容,应该熟悉如何从头构建一个简单的线性回归模型,并理解其背后的原理和实现细节。同时也要掌握利用高级API快速搭建相同功能模块的方法。 - 在softmax回归方面,重点在于了解多类别分类任务下的损失函数定义方式——交叉熵损失,并能够手动编写代码完成整个前向传播到反向传播的过程。 - 此外,《动手学深度学习》还强调了对图像分类数据集的操作实践,包括加载预处理步骤在内的全流程操作指南都非常重要。 为了更好地理解和解决这些练习题,推荐者仔细阅每节后面的小结部分,它总结了核心知识点并对后续的学习提供了方向指引[^3]。 ```python import torch from d2l import torch as d2l def train_epoch_ch3(net, train_iter, loss, updater): #@save """训练模型一个迭代周期(定义见第3章)。""" # 将模型设置为训练模式 if isinstance(net, torch.nn.Module): net.train() # 训练损失总和、训练准确率总和、样本数 metric = Accumulator(3) for X, y in train_iter: # 计算梯度并更新参数 l = loss(net(X), y) if isinstance(updater, torch.optim.Optimizer): # 使用PyTorch内置的优化器和损失函数 updater.zero_grad() l.mean().backward() updater.step() else: # 使用定制的优化器和损失函数 l.sum().backward() updater(X.shape[0]) metric.add(float(l.sum()), accuracy(net(X), y), y.numel()) return metric[0] / metric[2], metric[1] / metric[2] class Accumulator: #@save """在`n`个变量上累加。""" def __init__(self, n): self.data = [0.0] * n def add(self, *args): self.data = [a + float(b) for a, b in zip(self.data, args)] def reset(self): self.data = [0.0] * len(self.data) def __getitem__(self, idx): return self.data[idx] ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值