使用CNN实现CIFAR-10图像分类

get_data.py

import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader



# Per-channel mean / std handed to Normalize by both pipelines
# (presumably the usual CIFAR-10 RGB statistics -- TODO confirm).
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)


def _to_normalized_tensor():
    """Return fresh ToTensor + Normalize steps shared by both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


# Training pipeline: random augmentation first, then tensor conversion
# and normalization.
_augmentation = [
    # Pad 4 pixels on every side, then take a random 32x32 crop.
    transforms.RandomCrop(32, padding=4),
    # Mirror the image horizontally with probability 0.5.
    transforms.RandomHorizontalFlip(),
]
transform_train = transforms.Compose(_augmentation + _to_normalized_tensor())

# Evaluation pipeline: no augmentation, only tensor conversion and
# normalization.
transform_test = transforms.Compose(_to_normalized_tensor())



# Mini-batch size used by both loaders.  It is a module-level name because
# train.py imports `batch_size` from this module; without it that import
# raises ImportError.
batch_size = 100

# Training split, augmented by transform_train; downloaded into ./data if
# it is not already there.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Test split, normalized only; kept in dataset order (shuffle=False).
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)

# Same test split without any transform, so items keep their original image
# form; eval.py indexes it to display the picture that was classified.
# download=False relies on `testset` above having fetched the files already.
show_test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False)






 

normalnet.py(原文标为 net.py;train.py 通过 `from normalnet import CNN` 导入本文件,文件名需与之一致)

from torch import nn




class CNN(nn.Module):
    """Five-stage convolutional classifier for 3x32x32 inputs with 10 outputs.

    Every stage is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU;
    stages 2-5 additionally end with a 2x2 max-pool that halves the spatial
    size.  The resulting 256x2x2 feature map is flattened and passed through
    a three-layer fully connected head that returns raw class scores.
    """

    @staticmethod
    def _conv_stage(c_in, c_out, pool=True):
        """Build one Conv-BN-ReLU stage, optionally followed by 2x2 max-pool."""
        modules = [
            nn.Conv2d(in_channels=c_in, out_channels=c_out,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self):
        super().__init__()
        # Shape annotations are (channels, height, width) for a 3x32x32 input.
        self.layer1 = self._conv_stage(3, 16, pool=False)  # 16, 32, 32
        self.layer2 = self._conv_stage(16, 48)             # 48, 16, 16
        self.layer3 = self._conv_stage(48, 96)             # 96, 8, 8
        self.layer4 = self._conv_stage(96, 192)            # 192, 4, 4
        self.layer5 = self._conv_stage(192, 256)           # 256, 2, 2

        flat_features = 256 * 2 * 2
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return a (batch, 10) tensor of class scores for a batch of images."""
        stages = (self.layer1, self.layer2, self.layer3,
                  self.layer4, self.layer5)
        for stage in stages:
            x = stage(x)
        # Collapse everything except the batch dimension before the FC head.
        flat = x.reshape(x.size(0), -1)
        return self.fc(flat)






 

train.py

import torch
from torch import nn, optim
from torch.autograd import Variable
import random
from mmcv import ProgressBar
from normalnet import CNN
from get_data import trainloader, testloader, testset,trainset,batch_size
from tensorboardX import SummaryWriter
# TensorBoard event writer used by the training loop to log accuracies.
writer = SummaryWriter()


# Keep the original choice (GPU index 5) when that device exists; otherwise
# fall back to the default GPU, then to the CPU, so the script does not
# fail in model.to(device) on machines with fewer than six GPUs.
if torch.cuda.device_count() > 5:
    device = torch.device('cuda:5')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = CNN()
model.to(device)
model.train()

# Cross-entropy over the raw class scores; Adam with its default settings.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())


def _count_correct(loader):
    """Return how many samples yielded by `loader` the model gets right.

    Uses the module-level `model` and `device`.  Runs under torch.no_grad()
    so no autograd graph is built during evaluation, and accumulates a plain
    Python int so an empty loader simply yields 0.
    """
    hits = 0
    with torch.no_grad():
        for batch_imgs, batch_labels in loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            # Predicted class = index of the highest score in each row.
            pred = model(batch_imgs).argmax(dim=1)
            hits += (pred == batch_labels).sum().item()
    return hits


def _report_accuracy(loader, total_num, tag):
    """Evaluate `model` on `loader`, print the result and return accuracy.

    The model is put in eval mode for the pass and switched back to train
    mode afterwards (also when the pass raises), as the callers expect.
    """
    model.eval()
    try:
        correct = _count_correct(loader)
    finally:
        model.train()
    acc = correct / total_num
    print('{} set, correct:{}, ACC:{}'.format(tag, correct, acc))
    return acc


def get_ACC_test():
    """Return the model's accuracy over the test set (correct / len(testset))."""
    return _report_accuracy(testloader, len(testset), 'test')


def get_ACC_train():
    """Return the model's accuracy over the training set.

    Samples come from `trainloader`, so they pass through the training
    transform pipeline (including its random augmentation) before scoring.
    """
    return _report_accuracy(trainloader, len(trainset), 'train')


# Number of mini-batches per epoch.  Taken from the loader itself so the
# progress bar gets an integer task count and the loss average uses the
# true batch count (the old counter started at 1 and over-counted by one).
num_batches = len(trainloader)

for epoch in range(10):
    sum_loss = 0
    bar = ProgressBar(num_batches)
    for batch_input, batch_label in trainloader:
        batch_input = batch_input.to(device)
        batch_label = batch_label.to(device)
        out = model(batch_input)
        loss = criterion(out, batch_label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the graph is not kept alive.
        sum_loss += loss.item()
        bar.update()

    # Mean loss per batch; max() guards against an empty loader.
    ave_loss = sum_loss / max(num_batches, 1)
    print('epoch:{},loss:{}'.format(epoch, ave_loss))
    acc_test = get_ACC_test()
    acc_train = get_ACC_train()
    print('ACC train:{}, ACC test:{}'.format(acc_train,acc_test))
    writer.add_scalars('CIFAR10', {'acc_test': acc_test, 'acc_train': acc_train}, epoch)
    print()

# Persist the whole trained module: eval.py restores it with
# torch.load('model'), and nothing else in the scripts writes that file.
torch.save(model, 'model')
# Flush any buffered TensorBoard events before the script exits.
writer.close()


 

 

eval.py

import random

from torch.autograd import Variable

from get_data import show_test_set,testset
import torch

# Use the GPU when one is present, otherwise run the demo on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 'model' is a fully pickled module.  map_location lets a checkpoint that was
# saved on another device be restored here.
# NOTE(review): torch.load unpickles arbitrary objects -- only load
# checkpoints you trust.  Recent PyTorch releases default to
# weights_only=True and may refuse a pickled full model; pass
# weights_only=False there for a trusted file.
model = torch.load('model', map_location=device).to(device)
model.eval()

# Human-readable names indexed by class id (presumably the CIFAR-10 label
# order -- TODO confirm against the dataset's own class list).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


# Interactive loop: classify one random test image, show it, and repeat
# until the user answers 'n'.
while True:

    # randrange excludes the upper bound; randint(0, len(testset)) could
    # return len(testset) and raise IndexError.
    index = random.randrange(len(testset))
    img, label = testset[index]
    # Add a batch dimension; inference needs no autograd graph.
    with torch.no_grad():
        out = model(img.unsqueeze(0).to(device))
    pred = out.argmax(dim=1)
    print('predict:{},label:{}'.format(classes[pred.item()], classes[label]))
    # Show the untransformed picture stored at the same index.
    img_show = show_test_set[index][0]
    img_show.show()

    go_on = input('go on:')
    if go_on == 'n':
        break

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值