import torch
from torch.utils.data import Dataset  # Dataset is abstract: it cannot be instantiated, only subclassed
from torch.utils.data import DataLoader
import numpy as np
class Set_Dataset(Dataset):
    """Map-style Dataset over a delimited numeric text file.

    Dataset is an abstract class: it cannot be instantiated directly and
    must be subclassed, implementing ``__getitem__`` and ``__len__``.

    The whole file is loaded into memory once; every column except the
    last is treated as a feature, the last column as the label.
    """

    def __init__(self, filepath, delimiter=','):
        # delimiter generalized to a parameter (default ',' keeps the old behavior).
        xy = np.loadtxt(filepath, delimiter=delimiter, dtype=np.float32)
        self.len = xy.shape[0]
        # Features: all columns but the last.
        self.x_data = torch.from_numpy(xy[:, :-1])
        # Label: indexing with the list [-1] keeps a 2-D (N, 1) column,
        # matching the (N, 1) shape BCELoss expects from the model output.
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        """Return the ``(features, label)`` tensor pair for row *index*."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        """Return the number of samples (rows) in the file."""
        return self.len
class Model(torch.nn.Module):
    """Three-layer fully connected binary classifier (8 -> 6 -> 4 -> 1).

    Every linear layer is followed by a shared sigmoid activation, so the
    final output lies in (0, 1) — a probability suitable for BCELoss.
    """

    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        # Pipe the input through each linear layer + sigmoid in turn.
        out = x
        for layer in (self.linear1, self.linear2, self.linear3):
            out = self.sigmoid(layer(out))
        return out
if __name__ == "__main__":
dataset = Set_Dataset('C:/python3/envs/pytorch/atest_torch/data/diabetes.csv')
train_loader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=2)
model = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(20):
for i, data in enumerate(train_loader, 0):
# 1. Prepare data
inputs, labels = data
# 2. Forward
y_pred = model(inputs)
loss = criterion(y_pred, labels)
print(epoch, i, loss.item())
# 3. Backward
optimizer.zero_grad()
loss.backward()
# 4. Update
optimizer.step()
# Sample run output (epoch, batch index, loss):
# 19 4 0.6433802247047424
# 19 5 0.6984792947769165
# 19 6 0.6620081663131714
# 19 7 0.6072109341621399
# 19 8 0.6251083016395569
# 19 9 0.6068292260169983
# 19 10 0.6617796421051025
# 19 11 0.6436219811439514
# 19 12 0.735340416431427
# 19 13 0.6616771817207336
# 19 14 0.6253309845924377
# 19 15 0.6617479920387268
# 19 16 0.5704611539840698
# 19 17 0.5883565545082092
# 19 18 0.6800984740257263
# 19 19 0.6618121862411499
# 19 20 0.6064887642860413
# 19 21 0.6620679497718811
# 19 22 0.6439443826675415
# 19 23 0.6199027299880981
# Summary:
# 1. train_loader = DataLoader(dataset, shuffle=True, batch_size=32, num_workers=2)
#    shuffle — set to True to have the data reshuffled at every epoch
#    (default: False), i.e. whether the sampling order is randomized.
#    batch_size — number of samples per mini-batch.
#    num_workers — number of worker subprocesses used for data loading.
# 2. enumerate(train_loader, 0)
#    class enumerate:
#        def __init__(self, iterable: Iterable[_T], start: int = 0)
#    Initialize self. See help(type(self)) for accurate signature.
#    iterable: Iterable[_T] — any iterable sequence (here, train_loader).
#    start: int = 0 — the index to start counting from (here, 0).