DL学习1--从零开始线性回归

该代码示例展示了如何在PyTorch中利用Fashion-MNIST数据集构建一个线性回归模型进行图像分类。模型训练和测试过程中计算了交叉熵损失和准确率,并通过梯度下降法更新参数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

DL学习1--从零开始线性回归

使用mnist数据集表写了一个线性回归分类器。

from torch.utils import data
import torchvision
from torchvision import transforms
import torch
import torch.nn as nn
from d2l import torch as d2l
import tqdm

def load_data_fashion_mnist(batch_size, resize=None):
    """下载Fashion-MNIST数据集,然后将其加载到内存中"""
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    return (data.DataLoader(mnist_train, batch_size, shuffle=True),
            data.DataLoader(mnist_test, batch_size, shuffle=False))

# batch size
batch_size = 64

train_iter,test_iter = load_data_fashion_mnist(batch_size=batch_size)
for X,y in train_iter:
    print(X.shape,y.shape)
    break

input_shape = 784
output_shape = 10

w = torch.normal(0,0.1,size=(input_shape,output_shape),requires_grad=True)
b = torch.zeros(output_shape,requires_grad=True)

def net(X):
    return softmax( torch.matmul(X.reshape(-1,w.shape[0]) , w ) + b )
def softmax(y_ht):
    y_hat = torch.exp(y_ht)
    tmp = y_hat.sum(1,keepdim=True)
    return y_hat / tmp

def cross_entropy(y_hat,y):
    return - torch.log( y_hat[range(len(y_hat)) , y] )

def accuracy(y_hat,y):
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:
        y_hat = y_hat.argmax(axis=1)
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum())

def sgd(pamars,lr,batch_size):
    with torch.no_grad():
        for pama in pamars:
            pama -= lr* pama.grad / batch_size
            pama.grad.zero_()

def updater(batch_size):
    return sgd([w,b],lr,batch_size)

# train
lr = 0.01
epochs = 20

for epoch in tqdm.tqdm(range(epochs)):
    if isinstance(net,torch.nn.Module):
        net.train()
    true_data,data_len = 0,0
    true_data_t,data_len_t = 0,0
    for X,y in train_iter:
        preb = net(X)
        l = cross_entropy(preb,y)
        if isinstance(updater, torch.optim.Optimizer):
            updater.zero_grad()
            l.sum().backward()
            updater.step()
        else:
            l.sum().backward()
            updater(X.shape[0])
        true_data+= accuracy(preb,y)
        data_len+=len(X)
    # 测试数据
    if isinstance(net, torch.nn.Module):
        net.eval()
    with torch.no_grad():
        for X,y in test_iter:
            pre = net(X)
            l = cross_entropy(pre,y)
            true_data_t += accuracy(pre,y)
            data_len_t += len(X)

    print("epoch =",epoch+1,' train acc=',true_data/data_len,' test acc=',true_data_t/data_len_t)
    true_data, data_len = 0, 0
    true_data_t, data_len_t = 0, 0

运行结果

E:\anaconda3\envs\torch1.13\python.exe C:/Users/Jie/Desktop/renwu/LiMu-AI/py/1.softmax回归实现.py
torch.Size([64, 1, 28, 28]) torch.Size([64])
  5%|▌         | 1/20 [00:11<03:38, 11.49s/it]epoch = 1  train acc= 0.6595166666666666  test acc= 0.7441
 10%|█         | 2/20 [00:22<03:23, 11.30s/it]epoch = 2  train acc= 0.7742666666666667  test acc= 0.7789
 15%|█▌        | 3/20 [00:33<03:11, 11.26s/it]epoch = 3  train acc= 0.7963333333333333  test acc= 0.7943
 20%|██        | 4/20 [00:45<03:00, 11.25s/it]epoch = 4  train acc= 0.8083833333333333  test acc= 0.8007
epoch = 5  train acc= 0.8153833333333333  test acc= 0.8067
 30%|███       | 6/20 [01:07<02:37, 11.21s/it]epoch = 6  train acc= 0.8207666666666666  test acc= 0.8139
 35%|███▌      | 7/20 [01:18<02:25, 11.21s/it]epoch = 7  train acc= 0.8238333333333333  test acc= 0.8146
epoch = 8  train acc= 0.8278833333333333  test acc= 0.8161
 45%|████▌     | 9/20 [01:41<02:03, 11.26s/it]epoch = 9  train acc= 0.8303333333333334  test acc= 0.8196
 50%|█████     | 10/20 [01:52<01:52, 11.26s/it]epoch = 10  train acc= 0.8333666666666667  test acc= 0.824
epoch = 11  train acc= 0.8350333333333333  test acc= 0.823
 60%|██████    | 12/20 [02:14<01:29, 11.23s/it]epoch = 12  train acc= 0.8369  test acc= 0.8237
 65%|██████▌   | 13/20 [02:25<01:18, 11.19s/it]epoch = 13  train acc= 0.8387166666666667  test acc= 0.8254
epoch = 14  train acc= 0.83915  test acc= 0.8264
 75%|███████▌  | 15/20 [02:48<00:56, 11.25s/it]epoch = 15  train acc= 0.8408833333333333  test acc= 0.8266
epoch = 16  train acc= 0.8413  test acc= 0.8278
 85%|████████▌ | 17/20 [03:11<00:34, 11.35s/it]epoch = 17  train acc= 0.8427666666666667  test acc= 0.8269
 90%|█████████ | 18/20 [03:22<00:22, 11.38s/it]epoch = 18  train acc= 0.84355  test acc= 0.828
 95%|█████████▌| 19/20 [03:34<00:11, 11.31s/it]epoch = 19  train acc= 0.8446166666666667  test acc= 0.8302
epoch = 20  train acc= 0.84585  test acc= 0.8297
100%|██████████| 20/20 [03:45<00:00, 11.27s/it]

Process finished with exit code 0

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值