今天学习了神经网络-线性层及其他层。这段代码的主要目的是为了实现一个简单的神经网络模型,并在 CIFAR-10 数据集上进行预测。
import torch
import torchvision
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
download=True)
dataloader = DataLoader(dataset, batch_size=64,drop_last=True)
class Yang(nn.Module):
def __init__(self):
super(Yang, self).__init__()
self.linear1 = Linear(196608, 10)
def forward(self, input):
output = self.linear1(input)
return output
yang = Yang()
for data in dataloader:
imgs, targets = data
print(