I read a lot of blog posts and didn't feel like digging through the source code. It all boils down to this: subclass torch.utils.data.Dataset and override three methods, namely __len__, __getitem__ and __init__. If you are doing a classification problem, __getitem__ should return not only x but also the label.
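In its smallest form that pattern looks roughly like the sketch below (ToySet and its random data are made up purely for illustration; my real Dataset for the Adult data comes further down):

import torch
import torch.utils.data

class ToySet(torch.utils.data.Dataset):
    # a made-up two-feature dataset, just to show the three required methods
    def __init__(self):
        self.x = torch.randn(100, 2)             # 100 samples, 2 features
        self.y = (self.x.sum(dim=1) > 0).long()  # a 0/1 label per sample
    def __len__(self):
        return len(self.x)                       # number of samples
    def __getitem__(self, index):
        return self.x[index], self.y[index]      # (features, label) pair

loader = torch.utils.data.DataLoader(ToySet(), batch_size=10, shuffle=True)
xb, yb = next(iter(loader))                      # xb: [10, 2], yb: [10]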
I used the UCI Adult dataset, so the preprocessing and data cleaning all had to be done by hand.
Below is my Dataset class; the changes to the raw data are made in __init__ (through the helpers it calls).
import torch
import torch.optim
import torch.utils.data
import torch.nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

# dataset bookkeeping constants (not all of them are used in this snippet)
NUM_OF_TEST_SET = 2000
NUM_OF_PARTY = 5
PARTY_DATA = 6000
NUM_OF_ALL_DATA = 32561
NUM_OF_FEATURE = 14  # not including the label

def get_num_correct(preds, labels):
    # count how many predictions in this batch match the labels
    return preds.argmax(dim=1).eq(labels).sum().item()
class Net(torch.nn.Module):
    # simple MLP: NF inputs, two hidden layers of NH1/NH2 units, NO outputs
    def __init__(self, NF, NH1, NH2, NO):
        super(Net, self).__init__()
        self.layer1 = torch.nn.Linear(NF, NH1)
        self.layer2 = torch.nn.Linear(NH1, NH2)
        self.out = torch.nn.Linear(NH2, NO)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = self.out(x)  # raw logits; cross_entropy applies softmax itself
        return x
class MySet(torch.utils.data.Dataset):
    def __init__(self, PATH):
        self.RawData = []   # raw string fields, one list per CSV row
        self.Data = []      # cleaned numeric rows, label in the last position
        self.GetClass = {}  # maps each categorical string to an integer code
        # hard-coded min/max of feature 2 (fnlwgt in the usual UCI column order)
        self.SecondFeatureMin = 12285
        self.SecondFeatureMax = 1484705
        self.Second = self.SecondFeatureMax - self.SecondFeatureMin
        self.LoadData(PATH)
        self.WashData()

    def LoadData(self, PATH):
        with open(PATH, 'r') as f:
            for line in f:
                self.RawData.append([thing.strip() for thing in line.split(',')])
        self.RawData = self.RawData[:-1]  # the file ends with an empty line; drop it
    def WashData(self):
        # first pass: give every distinct string value an integer code
        # (one shared dict for all columns; the counter restarts per column)
        attribute = len(self.RawData[0])
        index = 0
        while index < attribute:
            cnt = 0
            temp = [line[index] for line in self.RawData]
            for item in temp:
                if item not in self.GetClass:
                    self.GetClass[item] = cnt
                    cnt += 1
            index += 1
        self.NUMOFFEATURE = index
        # second pass: convert every row to numbers
        # (the comments assume the usual UCI Adult column order)
        for line in self.RawData:
            t = []
            for j, f in enumerate(line):
                if j == 0:                    # age, scaled down by 10
                    t.append(float(f) / 10)
                elif j == 2:                  # fnlwgt, divided by its range
                    t.append(round(float(f) / self.Second, 4))
                elif j == 10 or j == 11:      # capital-gain / capital-loss
                    t.append(float(f) / 2000)
                elif j == 12:                 # hours-per-week
                    t.append(float(f) / 10)
                elif j == 14:                 # income label, becomes class index 0/1
                    t.append(self.GetClass[f])
                else:                         # remaining categorical columns: code / 10
                    t.append(self.GetClass[f] / 10)
            self.Data.append(t)
    def __len__(self):
        return len(self.RawData)

    def __getitem__(self, index):
        # features as a float tensor, label as a plain int (0 or 1)
        return torch.tensor(self.Data[index][:-1], dtype=torch.float), \
               self.Data[index][-1]
        # [1,0] if self.Data[index][-1] is 1 else [0,1]
if __name__=="__main__":
file_path= r'data-source\adult.csv'
train_set=MySet(file_path)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
network = Net(14, 10, 9, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)
x,y=next(iter(train_loader))
print(x,y)
Then comes the training, continuing the main function from above. I visualize with TensorBoard directly, and it all looks fine.
if __name__=="__main__":
file_path= r'data-source\adult.csv'
train_set=MySet(file_path)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1000, shuffle=True)
network = Net(14, 10, 9, 2)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)
x,y=next(iter(train_loader))
print(x,y)
#tensor board
# comment=f'bath_size={batch_size} lr={lr} shuffle={shuffle}'
tb=SummaryWriter(log_dir='number_run')#,comment=comment)
tb.add_graph(model=network,input_to_model=x)
for epoch in range(10):
total_loss = 0
total_correct = 0
for batch in train_loader:
x,labels=batch
preds=network(x)
loss=F.cross_entropy(preds,labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss+=loss.item()
total_correct+=get_num_correct(preds,labels)
tb.add_scalar('Loss',total_loss,epoch)
tb.add_scalar('correct number', total_correct, epoch)
tb.add_scalar('Accuracy', total_correct/len(train_set), epoch)
# for name,weight in network.named_parameters():
# tb.add_histogram(name,weight,epoch)
# tb.add_histogram(f'{name}.grad',weight.grad,epoch)
# print("epoch",epoch,"total_correct:",total_correct,"loss:",total_loss)
tb.close()
exit(0)
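To actually look at the curves, the run directory written above ('number_run') can be opened with the regular TensorBoard command line, something like:

tensorboard --logdir number_run

and the dashboard then shows up in the browser (by default at http://localhost:6006).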