PyTorch demo——基于MLP的鸢尾花分类

系统框架

在这里插入图片描述

1. 数据集加载

  继承torch.utils.data.Dataset类,重写__getitem__和__len__方法,并在__getitem__中预处理数据。

# load.py
import torch


class IrisDataset(torch.utils.data.Dataset):
    def __init__(self, data_file, iris_class):
        super(IrisDataset, self).__init__()

        self.iris_class = iris_class

        self.all_data = []
        with open(data_file, 'r') as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
            for l in lines:
                l = l.split(',')
                vec = [float(i) for i in l[:-1]]
                label = self.iris_class[str(l[-1])]
                self.all_data.append([vec, label])


    def __getitem__(self, item):
        fea, label = self.all_data[item]
        fea, label = torch.tensor(fea, dtype=torch.float32), torch.tensor(label, dtype=torch.long)
		# No data augmentation
		
        return fea, label


    def __len__(self):

        return len(self.all_data)


if __name__ == "__main__":
    import config
    dataset = IrisDataset("iris/train", config.iris_class)
    print(dataset.__getitem__(0))

2. 网络模型——MLP

在这里插入图片描述

# net.py
import torch
import torch.nn as nn


class Net(torch.nn.Module):
    def __init__(self, input_dim=4, num_class=3):
        super(Net, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Linear(64, num_class),
            nn.Softmax()
        )

    def forward(self, x):

        return self.fc(x)


if __name__ == "__main__":
    net = Net()
    print(net)

    x = torch.randn(2, 4)
    print(net(x).shape)

3. 配置文件——网络参数、训练参数整理

# config.py
import warnings
warnings.filterwarnings('ignore')

"""dataset"""
iris_class = {
   
    "Iris-setosa": 0,
    "Iris-versicolor": 1,
    "Iris-virginica": 2
}

"""net args"""
input_dim = 4
num_class = 3

"""train & valid"""
train_data = 'iris/train'
valid_data = 'iris/valid'
batch_size = 10
nworks = 1
max_epoch = 200
lr = 1e-3
factor = 0.9

""" test """
test_data = "iris/test"
pre_model = "pth/model_100.pth"

4. 训练

# train.py
import torch, os, tqdm
from torch.utils.data import DataLoader

import load, net, config
import matplotlib.pyplot as plt


def train():
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda:" + str(0))   
        torch.backends.cudnn.benchmark = True
    else:
        DEVICE = torch.device("cpu")
    print("current deveice:", DEVICE)

    # load dataset for train and eval
    train_dataset = load.IrisDataset(config.train_data, config.iris_class)
    train_batchs = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.nworks, pin_memory=True)
    valid_dataset = load.IrisDataset(config.valid_data, config.iris_class)
    valid_batchs = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=config.nworks, pin_memory&#
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值