定义:
# Select the compute device: the CUDA GPU when one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
使用:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import torch.utils.data
from resnet import ResNet18
#gpu or not
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser(description='Python CIFAR10 Training')
parser.add_argument('--outf', default='./model', help='folder to output images and model checkpoints')
parser.add_argument('--net', default=',.model/Resnet18.pth', help='path to net(to continue training)')
args = parser.parse_args()
#超参数设置
EPOCH = 135
pre_epoch = 0
BATCH_SIZE = 128
LR