pytorch分类

本文介绍了如何使用PyTorch对CIFAR-10数据集进行图像分类。首先,通过torch.utils.data.DataLoader自动下载并处理数据。接着,展示了数据集的预览方式。接着,讲解了模型的保存和加载方法。此外,文章还详细解释了如何计算模型的准确率,并提供了完整的代码实现。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数据集

trainset = torchvision.datasets.CIFAR10(root=root, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=root, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

采用公开数据集 cifar-10-python.tar.gz ,使用参数 download=True,代码自动下载数据,root 指向数据集存储路径。torch.utils.data.DataLoader() 打包的数据格式与上一篇读取本地图像的格式一致。

数据集预览

dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))

打包后的数据集使用 iter() 获取一个批次的数据。

模型存储与加载

torch.save(net.state_dict(), PATH)//模型参数保存至PATH
#模型加载:1、初始化网络 2、load模型
net = Net()
net.load_state_dict(toch.load(PATH))

计算准确率

correct = 0
total = 0
with torch.no_grad():
	for data in testloader:
	    images, labels = data
	    outputs = net(images)
	    _, predicted = torch.max(outputs.data, 1)
	    total += labels.size(0)
	    correct += (predicted == labels).sum().item()

1、计算准确率在不更新参数的上下文条件下进行。
2、torch.max(outputs.data, 1) 类似 one-hot 编码。
3、labels.size(0) 记录的是一个批次数据的个数。
4、predicted == labels 是两个 tensor 计算,相同为 True, 不同为 False。.sum().item() 统计 True 的总数。从而,准确率 = correct / total

全部代码

# -*- coding: utf-8 -*-
"""
Training a Classifier
=====================

This is it. You have seen how to define neural networks, compute loss and make
updates to the weights of the network.

Now you might be thinking,

What about data?
----------------

Generally, when you have to deal with image, text, audio or video data,
you can use standard python packages that load data into a numpy array.
Then you can convert this array into a ``torch.*Tensor``.

-  For images, packages such as Pillow, OpenCV are useful
-  For audio, packages such as scipy and librosa
-  For text, either raw Python or Cython based loading, or NLTK and
   SpaCy are useful

Specifically for vision, we have created a package called
``torchvision``, that has data loaders for common datasets such as
Imagenet, CIFAR10, MNIST, etc. and data transformers for images, viz.,
``torchvision.datasets`` and ``torch.utils.data.DataLoader``.

This provides a huge convenience and avoids writing boilerplate code.

For this tutorial, we will use the CIFAR10 dataset.
It has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. The images in CIFAR-10 are of
size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

.. figure:: /_static/img/cifar10.png
   :alt: cifar10

   cifar10


Training an image classifier
----------------------------

We will do the following steps in order:

1. Load and normalizing the CIFAR10 training and test datasets using
   ``torchvision``
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data

1. Loading and normalizing CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using ``torchvisi
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值