import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
dataloader = DataLoader(dataset, batch_size=64)
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.linear1 = Linear(196608, 10)
def forward(self, input):
output = self.linear1(input)
return output
tudui = Tudui()
for data in dataloader:
imgs, targets = data
output = torch.reshape(imgs, (1, 1, 1, -1))
print(output.shape)
output = tudui(output)
print(output)
nn_linear.py
最新推荐文章于 2025-01-19 20:29:04 发布
该博客展示了如何利用PyTorch加载CIFAR10数据集,并定义一个简单的全连接网络(Tudui)进行图像分类。通过DataLoader加载数据,对输入图像进行reshape,然后通过网络进行前向传播,输出模型的预测结果。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
1187

被折叠的 条评论
为什么被折叠?



