PyTorch实战:Excel数据集到Tensor转换与鸢尾花分类BP神经网络

第一版本可以看
1.0

第二版本与1.0不一样的地方是使用了Dataset进行预处理数据,使用起来更加方便,同时使用了SummaryWriter保存准确率数据。其中SummaryWriter使用方法看
SummaryWriter
,同时增加了独热编码方法:

我将以鸢尾花数据集作为例子进行展示:

可以看到鸢尾花数据集有四个特征,分别是0,1,2,3,label是鸢尾花种类,共三种,分别以0,1,2表示。

1.使用Dataset预处理数据:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F


class Excel_dataset(Dataset):

    def __init__(self, dir, if_normalize=False, if_onehot=False):
        super(Excel_dataset, self).__init__()

        if (dir.endswith('.csv')):
            data = pd.read_csv(dir)
        elif (dir.endswith('.xlsx') or dir.endswith('.xls')):
            data = pd.read_excel(dir, engine="openpyxl")

        nplist = data.T.to_numpy()
        data = nplist[0:-1].T
        self.data = np.float64(data)
        self.target = nplist[-1]

        self.target_type = []
        #记录标签有几类
        for i in self.target:
            try:
                self.target_type.index(i)
            except(BaseException):
                self.target_type.append(i)
                # print(i)
        # 将标签转为自然数编码
        self.target_num = []
        for i in self.target:
            self.target_num.append(self.target_type.index(i))
            # print(self.target_type.index(i))

        # Tensor化
        self.data = np.array(self.data)
        self.data = torch.FloatTensor(self.data)
        self.target_num = np.array(self.target_num)
        self.target_num = self.target_num.astype(float)
        self.target_num = torch.LongTensor(self.target_num)
        self.if_onehot = if_onehot
        #生成独热编码
        self.target_onehot = []
        if if_onehot == True:

            for i in self.target_num:
                tar = F.one_hot(i.to(torch.int64), len(self.target_type))
                self.target_onehot.append(tar)
            # pass

        if if_normalize == True:
            self.data = nn.functional.normalize(self.data)

    def __getitem__(self, index):
        # tar = F.one_hot(self.target[index].to(torch.int64), len(self.target_type))
        # print(tar)
        if self.if_onehot == True:
            return self.data[index], self.target_onehot[index]

        else:
            return self.data[index], self.target_num[index]

    def __len__(self):
        return len(self.target)


def data_split(data, rate):
    train_l = int(len(data) * rate)
    test_l = len(data) - train_l
    """打乱数据集并且划分"""
    t
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值