网上搜的代码,加了个损失曲线可视化,需要改下,保存最优模型,懒得改,没力气,代码就是拿着搜到的代码改来改去,感觉结果还不错,数据集再多点应该更好。
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings(“ignore”)
device = torch.device(“cuda:0” if torch.cuda.is_available() else ‘cpu’)
n_classes = 3 # 几种分类的
preteain = False # 是否下载使用训练参数 有网true 没网false
epoches &#