import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor
from torch import nn
from torchsummary import summary
from torch import optim
from torch.utils.data import DataLoader
# 模型构建 (model construction)
class imagecls(nn.Module):
    """Small three-stage CNN classifier for 3x32x32 images (CIFAR-10 sized).

    Three [conv3x3 -> ReLU -> maxpool2x2] stages (3 -> 6 -> 16 -> 128
    channels) are followed by fully connected layers 512 -> 120 -> 4096 -> 10.
    The 512 flattened features assume a 32x32 input:
    32 -conv-> 30 -pool-> 15 -conv-> 13 -pool-> 6 -conv-> 4 -pool-> 2,
    and 128 * 2 * 2 = 512.

    ``forward`` returns raw logits of shape (batch, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and registration order are kept unchanged so that
        # existing state_dict checkpoints (e.g. model.pth) still load.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, stride=1, padding=0)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=128, kernel_size=3, stride=1, padding=0)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=512, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=4096)
        self.out = nn.Linear(in_features=4096, out_features=10)

    def forward(self, x):
        """Map a batch of images ``x`` (batch, 3, 32, 32) to class logits (batch, 10)."""
        # Move the batch to wherever the weights live instead of hard-coding
        # .cuda(): callers may hand over CPU tensors, and the model must also
        # work when it has been placed on the CPU.
        x = x.to(self.conv1.weight.device)
        x = self.pool1(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))
        x = self.pool3(torch.relu(self.conv3(x)))
        # Flatten everything but the batch dimension; for 32x32 inputs this
        # is (batch, 128, 2, 2) -> (batch, 512), matching fc1.
        x = x.reshape(x.size(0), -1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.out(x)
#模型训练
def train(epochs=1, batch_size=128, lr=1e-3, save_path='model.pth', net=None, dataset=None):
    """Train a classifier with Adam + cross-entropy, then save its weights.

    Args:
        epochs: Number of full passes over ``dataset``. Defaults to 1 because
            the original loop declared 10000 epochs but hit an unconditional
            ``break`` right after the per-epoch print, so only one pass ran.
        batch_size: Mini-batch size for the DataLoader.
        lr: Adam learning rate (betas are fixed at (0.9, 0.99)).
        save_path: File the trained ``state_dict`` is written to.
        net: Model to train; defaults to the module-level ``model`` created in
            the ``__main__`` block.
        dataset: Dataset yielding ``(input, label)`` pairs; defaults to the
            module-level ``train_dataset``.

    Returns:
        list[float]: Mean training loss of each epoch (``nan`` for an epoch
        with no batches).
    """
    if net is None:
        net = model
    if dataset is None:
        dataset = train_dataset
    # Feed batches to the device the model already lives on instead of
    # hard-coding .cuda().
    device = next(net.parameters()).device
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.99))
    criterion = nn.CrossEntropyLoss()
    # Built once: with shuffle=True the loader reshuffles each time it is
    # iterated, so recreating it every epoch is unnecessary.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net.train()
    history = []
    for _ in range(epochs):
        loss_sum = 0.0
        num_batches = 0
        for x, y in dataloader:
            logits = net(x.to(device))
            loss = criterion(logits, y.to(device))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            num_batches += 1
        # True per-batch mean; the original counter started at 0.1, which
        # skewed every printed average low.
        avg_loss = loss_sum / num_batches if num_batches else float('nan')
        print(avg_loss)
        history.append(avg_loss)
    torch.save(net.state_dict(), save_path)
    return history
# 模型预测 (model prediction / evaluation)
def test(batch_size=8, load_path='model.pth', net=None, dataset=None):
    """Load saved weights into a model, evaluate it and print the accuracy.

    Args:
        batch_size: Evaluation batch size.
        load_path: Checkpoint (``state_dict``) written by ``train``.
        net: Model to evaluate; defaults to the module-level ``model``
            created in the ``__main__`` block.
        dataset: Dataset yielding ``(input, label)`` pairs; defaults to the
            module-level ``test_dataset``.

    Returns:
        float: Fraction of correctly classified samples (``nan`` when the
        dataset is empty).
    """
    if net is None:
        net = model
    if dataset is None:
        dataset = test_dataset
    device = next(net.parameters()).device
    test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    # map_location lets a checkpoint saved on one device load on another.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    net.load_state_dict(torch.load(load_path, map_location=device))
    net.eval()
    corr = 0
    num = 0
    # Evaluation needs no gradients; without no_grad every batch would build
    # an autograd graph for nothing.
    with torch.no_grad():
        for x, y in test_dataloader:
            # Move the *input* to the model's device (the original moved the
            # output after the forward pass).
            logits = net(x.to(device))
            pred = torch.argmax(logits, dim=-1)
            corr += (pred == y.to(device)).sum().item()
            num += len(y)
    accuracy = corr / num if num else float('nan')
    print(accuracy)
    return accuracy


# Keep pytest from collecting this module-level ``test`` function as a test
# case and calling it with no arguments.
test.__test__ = False
if __name__ == '__main__':
    # Prefer the GPU when one is available, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # train() and test() read these module-level names when called without
    # arguments.
    model = imagecls().to(device)
    summary(model, input_size=(3, 32, 32), batch_size=5, device=device)
    transform = Compose([ToTensor()])
    # download=True fetches CIFAR-10 into ./data on first run instead of
    # raising because the files are missing.
    train_dataset = CIFAR10(root='data', train=True, download=True, transform=transform)
    test_dataset = CIFAR10(root='data', train=False, download=True, transform=transform)
    train()
    test()