pytorch CNN 实现手写板识别

pytorch CNN 实现手写板识别

前面介绍了 CNN : 一般 涉及 卷积 ,池化 ,单位化,激活函数

cnn 代码实现

import  torch
import   torch.nn  as  nn
from  torch.autograd  import  Variable

import  torch.utils.data  as  Data

import   torchvision


class  CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        #[1,28,28]  ==>[16,28,28]
        #torch 有现成的卷积方法 
        # in_channels一张图片 只有一个通道 一般会将图片转换成 黑白图
        #out_channels  16 个特征 会卷积出 16 个图片出来
        #kernel_size 特征 5*5
        #stride 卷积的时候每次移动一个像素
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=16,
                               kernel_size=5,
                               stride=1,
                               padding=2)
                              
        #激活函数
        self.relu1 = nn.ReLU()

        # [16,28,28]  ==>[16,14,14]
        #池化
        self.maxPool1= nn.MaxPool2d(kernel_size=2)


        #[16,14,14] ==>[32,14,14]
        self.conv2 = nn.Conv2d(in_channels=16,
                               out_channels=32,
                               kernel_size=5,
                               stride=1,
                               padding=2)
        #self.relu2 = nn.ReLU()

        # [32,14,14] ==>[32,7,7]
        self.maxPool2= nn.MaxPool2d(kernel_size=2)
        # 预测出 10 个结果出来的概率
        #比如[ n , 32*7*7 ] * [ 32*7*7 ,10] ===>[n,10] 求出 N 张图片 每个图片从 0--9 个数字概率
        self.full = nn.Linear(32*7*7,10)
        #采用的损失函数
        self.optimizer = torch.optim.Adam(self.parameters(),lr=LR)
        self.lossFunc= nn.CrossEntropyLoss()
    #调用CNN 的时候执行
    def  forward(self,x):
        #卷积  激活 池化
        result= self.conv1(x)
        result = self.relu1(result)
        result = self.maxPool1(result)
        #卷积  激活 池化
        result = self.conv2(result)
        result = self.relu1(result)
        result = self.maxPool2(result)
        #reshape = batchsize ,32*7*7 重新调整 大小
        result= result.view(x.size(0),-1)
        #全链层
        result = self.full(result)

        return   result

    #损失函数  预测值  真实值 
    def  lossFunction(self,predict ,batchY):
        #预测值   和  真实值 差
        loss = self.lossFunc(predict,batchY)
        #先将记录清零
        self.optimizer.zero_grad()
        #向后做残差
        loss.backward()
        print("loss==",loss.data)
        #重新调整参数
        self.optimizer.step()


以后在用到CNN 直接调整里面的参数 就可以了。

用 手写板 验证 CNN


#下载数据  torchvision 里面自带 手写板数据

#一次性传递多少个图片 给CNN
BATCH_SIZE=50
#学习率
LR= 0.001
#是否下载 第一次启动下载  第二次以后 就不用
DOWNLOAD=Ture

#用CNN 来预测还是训练
TRAIN =True

#下载训练数据   每张图片 [28,28]的一个矩阵
tranData =  torchvision.datasets.MNIST(
    root="d:/mnist/",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD

)

#下载测试数据
testData = torchvision.datasets.MNIST(
    root="d:/mnist/",
    train=False
)

#加载数据到内存
trainLoader =  Data.DataLoader(dataset=tranData,
                               batch_size=BATCH_SIZE,
                               shuffle=True
                               )

# 为了节约时间, 我们测试时只测试前2000个
# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)

test_x = Variable(torch.unsqueeze(testData.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
test_y = testData.test_labels[:2000]



myCNN= CNN()


if(TRAIN):
   # 数据训练3轮 
    for  epoch in  range(3):
      for  step  ,(x,y) in enumerate(trainLoader):
            #[50, 1, 28, 28]
            trainX = Variable(x)
            tranY = Variable(y)
            #预测数据
            predict= myCNN(trainX)
            #backforward  修正参数
            myCNN.lossFunction(predict,tranY)
    #保存参数到 硬盘
    torch.save(myCNN.state_dict(), "d:/mnist/cnn.pkl")
else:

    #从硬盘加载参数
    myCNN.load_state_dict(torch.load("d:/mnist/cnn.pkl"))




#预测10个结果概率 
testOut = myCNN(test_x[:20])

print(testOut)
#找到一个概率最大的 最为结果
testPredict = torch.max(testOut,1)[1]

print(testPredict,test_y[:20])

输出结果:
tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4]) tensor([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4])


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值