train.py

部署运行你感兴趣的模型镜像
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear
from torch.nn.modules.flatten import Flatten
from torch.utils.data import DataLoader

device = torch.device("cuda")

train_data = torchvision.datasets.CIFAR10("../tudui/data", train=True, transform=torchvision.transforms.ToTensor(),
                                          download=True)
test_data = torchvision.datasets.CIFAR10("../tudui/data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=False)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集长度为:{}".format(train_data_size))
print("测试数据集长度为:{}".format(test_data_size))

train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

tudui = Tudui()
tudui = tudui.to(device)

loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置网络参数
total_train = 0
total_test = 0
epoch = 50
writer = SummaryWriter("logs")

for i in range(epoch):
    print("--------------第{}轮训练----------".format(i + 1))
    tudui.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        output = tudui(imgs)
        loss = loss_fn(output, targets)

        #优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train += 1
        if total_train%100 == 0:
            print("训练次数:{},loss:{}".format(total_train, loss))

    #测试
    tudui.eval()
    total_test_loss = 0
    total_test_step = 0
    total_test_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
            accuracy = (outputs.argmax(1) == targets).sum()
            total_test_accuracy += accuracy
        print("整体测试集上loss:{}".format(total_test_loss))
        print("整体测试集上accuracy:{}".format(total_test_accuracy/test_data_size))
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        writer.add_scalar("test_accuracy", total_test_accuracy/test_data_size, total_test_step)
        total_test_step += 1
    
    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值