pytorch 写的一个 lenet的分类网络,不是百分百还原哈,结构是一样的, 简单训练一下自己的数据集。数据集格式如下,data内存放 自己的数据,每个类别放到一个文件夹中,文件夹名称为类别标签如下图
1.网络搭建
import torch
import torch.nn as nn
class Lenet(nn.Module):
def __init__(self, num_classes = 1000):
super(Lenet, self).__init__()
self.conv1 = nn.Conv2d(3,6,5,1,0)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = nn.Conv2d(6, 16, 5, 1, 0)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = nn.Conv2d(16, 120, 5, 1, 0)
self.fc6 = nn.Linear(120, 84)
self.fc7 = nn.Linear(84, 10)
self.relu = nn.ReLU(inplace=Tr