还是把所有代码放在三个.py文件里:
- net.py用来定义卷积网络
- readpic.py用来读取自己手动画的一张图片,测试着玩
- CNN.py就是代码的主体部分
net.py
这里使用的是LeNet,LeNet是整个神经网络的开山之作,1998年由LeCun提出,它的结构特别简单。
import torch
import torch.nn as nn
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 6, 3, padding=1), # 6@28*28
nn.MaxPool2d(2, 2) # 6@14*14
)
self.layer2 = nn.Sequential(
nn.Conv2d(6, 16, 5), # 16@10*10
nn.MaxPool2d(2, 2) # 16@5*5
)
self.layer3 = nn.Sequential(
nn.Linear(400, 120), # 16*5*5=400
nn.Linear(120, 84),
nn.Linear(84, 10)
)
def forward(sel