系统框架
1. 数据集加载
继承 `torch.utils.data.Dataset` 类，重写 `__getitem__` 和 `__len__` 方法：在 `__init__` 中逐行解析数据文件，在 `__getitem__` 中将样本转换为张量。
# load.py
import torch
class IrisDataset(torch.utils.data.Dataset):
    """Iris samples read from a comma-separated text file.

    Each non-blank line holds the numeric feature values followed by the class
    name in the last column, e.g. ``5.1,3.5,1.4,0.2,Iris-setosa``. Samples are
    parsed once in ``__init__`` and converted to tensors on access.
    """

    def __init__(self, data_file, iris_class):
        """Parse every sample in ``data_file``.

        Args:
            data_file: path to the comma-separated data file.
            iris_class: mapping from class name (last column) to integer label.

        Raises:
            KeyError: if a line's class name is not a key of ``iris_class``.
            ValueError: if a feature column cannot be parsed as a float.
        """
        super().__init__()
        self.iris_class = iris_class
        self.all_data = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                # Data files commonly end with (or contain) blank lines; without
                # this guard they would fail with KeyError('') on the class name.
                if not line:
                    continue
                *features, name = line.split(',')
                vec = [float(v) for v in features]
                # Strip so "..., Iris-setosa" (space after comma) still matches.
                label = self.iris_class[name.strip()]
                self.all_data.append([vec, label])

    def __getitem__(self, item):
        """Return sample ``item`` as ``(features float32 tensor, label int64 tensor)``."""
        fea, label = self.all_data[item]
        # No data augmentation is applied.
        return (torch.tensor(fea, dtype=torch.float32),
                torch.tensor(label, dtype=torch.long))

    def __len__(self):
        """Return the number of samples parsed from ``data_file``."""
        return len(self.all_data)
if __name__ == "__main__":
    import config

    # Smoke test: load the training split and print its first (features, label) pair.
    train_set = IrisDataset("iris/train", config.iris_class)
    first_sample = train_set[0]
    print(first_sample)
2. 网络模型——MLP
# net.py
import torch
import torch.nn as nn
class Net(torch.nn.Module):
    """Fully connected (MLP) classifier.

    Every hidden layer is ``Linear -> BatchNorm1d -> ReLU``. A final ``Linear``
    maps to ``num_class`` scores, followed by a softmax over the class axis.
    ``forward`` therefore returns per-class probabilities (each row sums to 1),
    not raw logits.
    """

    def __init__(self, input_dim=4, num_class=3, hidden_dims=(512, 128, 64)):
        """Build the layer stack.

        Args:
            input_dim: number of input features per sample.
            num_class: number of output classes.
            hidden_dims: widths of the hidden layers. The default reproduces the
                original 512-128-64 network. Layers are appended in the same
                order, so ``fc.<i>`` state-dict keys of existing checkpoints
                still match.
        """
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(inplace=True),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, num_class))
        # Bare nn.Softmax() falls back on a deprecated implicit-dim heuristic and
        # warns; dim=1 is the class axis of the (batch, num_class) scores.
        # NOTE(review): if training uses nn.CrossEntropyLoss, that loss applies
        # log-softmax itself, so this layer double-normalizes -- confirm.
        layers.append(nn.Softmax(dim=1))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (expected shape (batch, input_dim)) to (batch, num_class) probabilities."""
        return self.fc(x)
if __name__ == "__main__":
    model = Net()
    print(model)
    # Use two random samples: BatchNorm1d in training mode needs more than one
    # value per channel.
    dummy_batch = torch.randn(2, 4)
    print(model(dummy_batch).shape)
3. 配置文件——网络参数、训练参数整理
# config.py
import warnings
# Module-level side effect: every process importing this module ignores all warnings.
# NOTE(review): this also hides library deprecation warnings (e.g. from a bare
# nn.Softmax() without dim); consider narrowing the filter.
warnings.filterwarnings('ignore')

# ---- dataset ----
# Class name (last column of each data line) -> integer label, in this order.
iris_class = {
    name: idx
    for idx, name in enumerate(("Iris-setosa", "Iris-versicolor", "Iris-virginica"))
}

# ---- network ----
input_dim = 4                    # features per sample
num_class = len(iris_class)      # one output per class (3)

# ---- training & validation ----
train_data, valid_data = 'iris/train', 'iris/valid'
batch_size = 10
nworks = 1                       # passed to DataLoader as num_workers
max_epoch = 200
lr = 0.001
factor = 0.9                     # NOTE(review): consumer not visible here; presumably an LR-decay factor

# ---- testing ----
test_data = "iris/test"
pre_model = "pth/model_100.pth"  # presumably the checkpoint loaded for testing
4. 训练
# train.py
import torch, os, tqdm
from torch.utils.data import DataLoader
import load, net, config
import matplotlib.pyplot as plt
def train():
if torch.cuda.is_available():
DEVICE = torch.device("cuda:" + str(0))
torch.backends.cudnn.benchmark = True
else:
DEVICE = torch.device("cpu")
print("current deveice:", DEVICE)
# load dataset for train and eval
train_dataset = load.IrisDataset(config.train_data, config.iris_class)
train_batchs = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.nworks, pin_memory=True)
valid_dataset = load.IrisDataset(config.valid_data, config.iris_class)
valid_batchs = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.nworks, pin_memory&#