python 定义数字_Python自定义数字数据集,pytorch,数值,型

该博客主要介绍使用Python和PyTorch处理数据集及训练模型。需继承torch.utils.data.Dataset并override三个函数,使用ICU的adult数据集,要自行预处理和清洗数据。还给出了自定义数据集类和网络模型类的代码,最后用tensorboard可视化训练过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

看了许多blog,不想翻源码。说来说去,就是要继承torch.utils.data.Dataset,override三个函数,分别是__len__、__getitem__、__init__。如果你做分类问题,那么getitem函数不仅返回x,还要返回label。

我自己用了ICU的adult数据集,预处理和清洗数据都要自己做。

简单贴一下我的set,我在init里面做了对原数据的变动。

import torch

import torch.utils.data.dataset

import torch.nn.modules

import torch.nn

import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

NUM_OF_TEST_SET=2000

NUM_OF_PARTY=5

PARTY_DATA=6000

NUM_OF_ALL_DATA=32561

NUM_OF_FEATURE=14#NOT including label.

def get_num_correct(preds,labels):

return preds.argmax(dim=1).eq(labels).sum().item()

class Net(torch.nn.Module):

def __init__(self,NF,NH1,NH2,NO):

super(Net,self).__init__()

self.layer1=torch.nn.Linear(NF,NH1)

self.layer2=torch.nn.Linear(NH1,NH2)

self.out=torch.nn.Linear(NH2,NO)

def forward(self, x):

x=F.relu(self.layer1(x))

x=F.relu(self.layer2(x))

x=self.out(x)

return x

class MySet(torch.utils.data.Dataset):

def __init__(self,PATH):

self.RawData = []

self.Data = []

self.GetClass = {}

self.SecondFeatureMin = 12285

self.SecondFeatureMax = 1484705

self.Second = self.SecondFeatureMax - self.SecondFeatureMin

self.LoadData(PATH)

self.WashData()

def LoadData(self, PATH):

with open(PATH, 'r') as f:

for line in f:

self.RawData.append([thing.strip() for thing in line.split(',')])

self.RawData = self.RawData[:-1]

def WashData(self):

attribute = len(self.RawData[0])

index = 0

while index < attribute:

cnt = 0

temp = [line[index] for line in self.RawData]

for item in temp:

if item not in self.GetClass:

self.GetClass[item] = cnt

cnt += 1

index += 1

self.NUMOFFEATURE = index

for line in self.RawData:

t = []

for j, f in enumerate(line):

if j == 0:

t.append(float(f) / 10)

elif j == 2:

t.append(round(float(f) / self.Second, 4))

elif j == 10 or j == 11:

t.append(float(f) / 2000)

elif j == 12:

t.append(float(f) / 10)

elif j == 14:

t.append(self.GetClass[f])

else:

t.append(self.GetClass[f] / 10)

self.Data.append(t)

def __len__(self):

return len(self.RawData)

def __getitem__(self, index):

return torch.tensor(self.Data[index][:-1],dtype=torch.float), \

self.Data[index][-1]

#[1,0] if self.Data[index][-1] is 1 else [0,1]

if __name__=="__main__":

file_path= r'data-source\adult.csv'

train_set=MySet(file_path)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)

network = Net(14, 10, 9, 2)

optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

x,y=next(iter(train_loader))

print(x,y)

然后是训练一下,接着main函数写下来。我直接用tensorboard可视化,看上去没什么问题。

if __name__=="__main__":

file_path= r'data-source\adult.csv'

train_set=MySet(file_path)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)

network = Net(14, 10, 9, 2)

optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

x,y=next(iter(train_loader))

print(x,y)

#tensor board

# comment=f'bath_size={batch_size} lr={lr} shuffle={shuffle}'

tb=SummaryWriter(log_dir='number_run')#,comment=comment)

tb.add_graph(model=network,input_to_model=x)

for epoch in range(10):

total_loss = 0

total_correct = 0

for batch in train_loader:

x,labels=batch

preds=network(x)

loss=F.cross_entropy(preds,labels)

optimizer.zero_grad()

loss.backward()

optimizer.step()

total_loss+=loss.item()

total_correct+=get_num_correct(preds,labels)

tb.add_scalar('Loss',total_loss,epoch)

tb.add_scalar('correct number', total_correct, epoch)

tb.add_scalar('Accuracy', total_correct/len(train_set), epoch)

# for name,weight in network.named_parameters():

# tb.add_histogram(name,weight,epoch)

# tb.add_histogram(f'{name}.grad',weight.grad,epoch)

# print("epoch",epoch,"total_correct:",total_correct,"loss:",total_loss)

tb.close()

exit(0)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值