代码来源
一、对于图像的训练
数据集:https://pan.baidu.com/s/18Fz9Cpj0Lf9BC7As8frZrw 提取码:xhgk
注意:训练时需要添加 --datapath 参数,即训练集所在的目录
训练集有60000张0-9的手写数字图片
import torch import math import torch.nn as nn from torch.autograd import Variable from torchvision import transforms, models import argparse import os from torch.utils.data import DataLoader from dataloader import mnist_loader as ml from models.cnn import Net from toonnx import to_onnx parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--datapath', required=True, help='data path') parser.add_argument('--batch_size', type=int, default=256, help='training batch size') parser.add_argument('--epochs', type=int, default=30, help='number of epochs to train') parser.add_argument('--use_cuda', default=False, help='using CUDA for training') args = parser.parse_args() args.cuda = args.use_cuda and torch.cuda.is_available() if args.cuda: