PyTorch: the second way to create a Dataset (via a custom Dataset class)

import torch
from torch import nn
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import glob
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data.dataloader import default_collate

# The four weather classes in the dataset
spices = ['cloudy', 'rain', 'shine', 'sunrise']


# Create the dataset by defining a custom Dataset class.
# Two key points: (1) inherit from torch.utils.data.Dataset and implement
# __getitem__ and __len__; (2) since __getitem__ may return None for bad samples,
# override collate_fn and pass collate_fn=MYDataset.collate_fn to the DataLoader.
class MYDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, labels, transformation):
        self.images = image_path
        self.labels = labels
        self.transformation = transformation

    def __getitem__(self, item):
        img = Image.open(self.images[item])
        label = self.labels[item]
        # Keep only three-channel (RGB) images
        if np.asarray(img).ndim == 3 and np.asarray(img).shape[-1] == 3:
            img = self.transformation(img)
            return img, torch.tensor(label).type(torch.LongTensor)
        else:
            # Non-RGB images yield None; collate_fn drops them from the batch
            return None

    def __len__(self):
        return len(self.images)

    @staticmethod
    def collate_fn(batch):
        # Drop the None placeholders produced for non-RGB images, then collate normally
        batch = [sample for sample in batch if sample is not None]
        return default_collate(batch)
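

# A minimal sketch of what the custom collate_fn does, using hypothetical
# stand-in samples: default_collate cannot handle None entries, so they are
# filtered out before the remaining samples are stacked.
fake_batch = [(torch.zeros(3, 96, 96), torch.tensor(0)),
              None,
              (torch.ones(3, 96, 96), torch.tensor(1))]
demo_imgs, demo_lbls = MYDataset.collate_fn(fake_batch)
print(demo_imgs.shape, demo_lbls.shape)  # torch.Size([2, 3, 96, 96]) torch.Size([2])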


train_transformation = transforms.Compose([
    transforms.Resize((96, 96)),
    # Data augmentation
    # transforms.RandomCrop(64),
    transforms.RandomHorizontalFlip(0.2),
    transforms.RandomVerticalFlip(0.2),
    transforms.RandomRotation(90),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1])
])
valid_transformation = transforms.Compose([
    transforms.Resize((96, 96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[1, 1, 1])
])
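
# Quick sketch: run the training transform on a hypothetical dummy RGB image to
# confirm it yields a normalized 3 x 96 x 96 tensor.
dummy_img = Image.fromarray(np.uint8(np.random.rand(128, 128, 3) * 255))
print(train_transformation(dummy_img).shape)  # torch.Size([3, 96, 96])
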
file_names2 = glob.glob('./dataset/*.jpg')
# Shuffle the file order with a random permutation before splitting into train/valid
index = np.random.permutation(len(file_names2))
labels = []
# Map each class name to an integer label, e.g. {'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
dict_spec = dict((c, i) for i, c in enumerate(spices))
for file_name in file_names2:
    for spec in spices:
        if spec in file_name:
            labels.append(dict_spec[spec])
labels = np.array(labels)[index]
file_names2 = np.array(file_names2)[index]
# 80/20 train/validation split
train_split = int(len(index) * 0.8)
train_files, valid_files = file_names2[: train_split], file_names2[train_split:]
train_labels, valid_labels = labels[: train_split], labels[train_split:]
train_ds2 = MYDataset(train_files, train_labels, train_transformation)
valid_ds2 = MYDataset(valid_files, valid_labels, valid_transformation)
train_dl2 = DataLoader(train_ds2, batch_size=32, shuffle=True, drop_last=True, collate_fn=MYDataset.collate_fn)
valid_dl2 = DataLoader(valid_ds2, batch_size=64, drop_last=True, collate_fn=MYDataset.collate_fn)
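
# Sanity check: pull one batch to make sure the whole pipeline runs end to end
# (this assumes the weather images are present under ./dataset, as the script requires).
sample_imgs, sample_lbls = next(iter(train_dl2))
print(sample_imgs.shape, sample_lbls.shape)  # torch.Size([32, 3, 96, 96]) torch.Size([32])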


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128 * 12 * 12, 1024, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 256, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU()
        )
        self.fc3 = nn.Linear(256, 4)

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x
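

# Shape sanity check (a sketch): 96 -> 48 -> 24 -> 12 after three 2x2 max-pools,
# which is why fc1 expects 128 * 12 * 12 input features. eval() sidesteps the
# BatchNorm1d error that a single-sample batch would raise in training mode.
probe = Net().eval()
with torch.no_grad():
    print(probe(torch.zeros(1, 3, 96, 96)).shape)  # torch.Size([1, 4])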


model2 = Net()
loss_fn2 = torch.nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model2 = model2.to(device)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=0.001)


def cal_acc(y_pred, y_true):
    # Return the number of correct predictions in the batch (a count, not a ratio)
    acc = (torch.argmax(y_pred, dim=1) == y_true).sum().item()
    return acc
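

# Example with hypothetical logits: only the first prediction matches its label,
# so cal_acc returns 1.
demo_logits = torch.tensor([[2.0, 0.1, 0.1, 0.1],
                            [0.1, 0.1, 3.0, 0.1]])
print(cal_acc(demo_logits, torch.tensor([0, 1])))  # 1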


def train_step(epoch, model, train_dl, valid_dl, loss_func, optimizer):
    total_train_loss = 0
    total_train_acc = 0
    total_train = 0
    total_valid_loss = 0
    total_valid_acc = 0
    total_valid = 0
    model.train()
    for x, y in train_dl:
        x, y = x.to(device), y.to(device)
        prediction = model(x)
        train_loss = loss_func(prediction, y)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
        with torch.no_grad():
            total_train_loss += train_loss.item()
            total_train_acc += cal_acc(prediction, y)
            total_train += y.size(0)
    total_train_loss = total_train_loss / total_train
    total_train_acc = total_train_acc / total_train
    model.eval()
    for x, y in valid_dl:
        with torch.no_grad():
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            valid_loss = loss_func(prediction, y)
            total_valid_loss += valid_loss.item()
            total_valid_acc += cal_acc(prediction, y)
            total_valid += y.size(0)
    total_valid_loss = total_valid_loss / total_valid
    total_valid_acc = total_valid_acc / total_valid
    print('epoch: %d, train_loss: %3.3f, train_acc: %3.3f, valid_loss: %3.3f, valid_acc: %3.3f' % (
    epoch, total_train_loss, total_train_acc, total_valid_loss, total_valid_acc))
    return total_train_loss, total_train_acc, total_valid_loss, total_valid_acc


def model_fit(model, train_dl, valid_dl, loss_func, optimizer, epochs=10):
    history = {'train_loss': [],
               'train_acc': [],
               'valid_loss': [],
               'valid_acc': []}
    for epoch in range(epochs):
        train_loss, train_acc, valid_loss, valid_acc = train_step(epoch, model, train_dl, valid_dl, loss_func, optimizer)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['valid_loss'].append(valid_loss)
        history['valid_acc'].append(valid_acc)
    return history


def draw_curve(history):
    pd.DataFrame(history).plot()
    plt.gca().set_ylim(0, 1)
    plt.grid(True)
    plt.show()


history2 = model_fit(model2, train_dl2, valid_dl2, loss_fn2, optimizer2)
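
# Plot the recorded loss/accuracy curves for this run.
draw_curve(history2)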
