定义:
# Select the compute device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
使用:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import torch.utils.data
from resnet import ResNet18
#gpu or not
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(description='Python CIFAR10 Training')
parser.add_argument('--outf', default='./model', help='folder to output images and model checkpoints')
parser.add_argument('--net', default=',.model/Resnet18.pth', help='path to net(to continue training)')
args = parser.parse_args()
# Hyperparameter settings.
# NOTE(review): the code that consumes these values is outside this chunk;
# the meanings below are inferred from the names — confirm against usage.
LR = 0.1            # presumably the initial learning rate
BATCH_SIZE = 128    # presumably the mini-batch size
pre_epoch = 0       # presumably the number of epochs already completed
EPOCH = 135         # presumably the total number of epochs to run
#准备数据集并预处理
transform_train = transforms.Compose([
transforms.RandomCrop(size=32, padding=4), #先padding,再随机截取32*32
transforms.RandomHor