动手学深度学习 -- 3.5~3.7

3.5-3.7 算法逻辑总结

1. 数据的来源与预处理

数据来源

  • 数据集名称:Fashion‑MNIST
  • 数据来源:Fashion‑MNIST 数据集由 Xiao 等人于 2017 年发布,作为 MNIST 数据集的更具挑战性的替代品。
  • 下载方式
    • 通过深度学习框架内置的接口自动下载。例如,在 PyTorch 中使用 torchvision.datasets.FashionMNIST,并指定参数 download=True 后,如果本地不存在数据,则自动从官网或镜像站点下载数据文件。

数据原始内容

  • 样本:每个样本是一张灰度图像,原始尺寸为 28×28 像素,代表不同服装类别(共 10 类,如 T 恤、裤子、外套等)。
  • 标签:每张图像对应一个数字标签(0~9),表示图像所属的类别。

数据转换过程

  1. 转换为张量

    • 函数transforms.ToTensor()
    • 作用:将 PIL 图像或 NumPy 数组转换为张量,同时将像素值从 0~255 归一化到 [0, 1]。
    • 结果:图像数据从 28×28 的格式转换为张量格式。在 PyTorch 中,通常为 (C, H, W),此处灰度图像的通道数为 1。
  2. 可选的图像尺寸调整

    • 函数transforms.Resize(resize)
    • 作用:如果调用数据加载函数时指定了 resize 参数(例如设置为 64),则首先将图像尺寸调整为 64×64。
    • 注意:调整尺寸仅改变图像的高度和宽度,不改变通道数。
  3. 组合转换

    • 函数transforms.Compose(trans)
    • 作用:将多个转换(例如 Resize 和 ToTensor)按顺序组合成一个整体的转换函数,加载每个样本时依次应用这些转换。

最终,通过如下函数实现数据的下载、转换和加载:

def load_data_fashion_mnist(batch_size, resize=None):  #@save
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))  # 先进行尺寸调整
    trans = transforms.Compose(trans)             # 组合转换函数
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
  • 输出:返回两个数据迭代器(DataLoader),分别用于训练和测试。
  • 示例数据格式:调用 load_data_fashion_mnist(32, resize=64) 后,每个批次中图像数据 X 的形状为 (32, 1, 64, 64),标签 y 的形状为 (32,)

2. 从数据到模型预测:softmax 回归的手工实现

数据在模型中的输入

  • 展平处理
    在训练 softmax 回归模型时,每张图像需要先被“展平”为一个向量。例如,将 28×28 的图像展平成 784 维向量,通常使用:

    X.reshape((-1, W.shape[0]))
    

    其中 W 是模型的权重矩阵,其第一维即输入特征数(例如 784)。

模型参数与线性变换

  • 参数初始化

    • 权重矩阵 W:形状为 n u m _ i n p u t s × n u m _ o u t p u t s num\_inputs \times num\_outputs num_inputs×num_outputs(例如 784 × 10 784 \times 10 784×10),采用均值为 0、标准差为 0.01 的正态分布随机初始化。
    • 偏置 b:形状为 10 10 10(或 1 × 10 1 \times 10 1×10),初始化为 0。
  • 线性变换计算
    对每个展平后的样本 x x x,计算未归一化的输出(logits):

o = x W + b o = xW + b o=xW+b

其中:
- x x x 的形状为 ( 1 , 784 ) (1, 784) (1,784)
- W W W 的形状为 ( 784 , 10 ) (784, 10)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Elsa的迷弟

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值