B站--刘二大人《PyTorch深度学习实践》完结合集 08. 加载数据集
PPT 链接:网盘 提取码:cxe4
dataset:构建数据集,
dataloader:为训练提供mini-batch数据
Manual data feed
- mini-batch:
mini-batch可以有效解决按点问题,并可以降低随机性
- epoch
一个周期,所有样本进行了前馈-反馈-更新的过程,所有样本都进行了一次训练
- batch-size
批量大小,训练一次所需要的样本数量,也就是一个mini-batch的大小。
for epoch in range(training_epochs): #对所有数据进行重复训练
for i in range(total_batch):#遍历每个mini-batch数据
- literation
迭代次数,一个mini-batch经过多少次迭代把所有样本训练完一次,
直观上来看就是总的batch所包含的mini-batch的数量。
- dataloader
shuffle=True表示打乱样本顺序;然后将样本分成2个一组batch。如图所示
代码实现dataset,dataloader
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# 新定义一个类继承自Dataset
class DiabetesDataset(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
# 直接取xy数据集的形状的行,表示长度,即样本数量
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[ :, :-1]) #取所有行,除最后一列之外的其他所有列
self.y_data = torch.from_numpy(xy[ :, [-1]]) #取所有行以及最后一列
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# 实例化数据集对象
dataset = DiabetesDataset('../dataset/diabetes.csv.gz')
# 加载数据集
# dataset是数据集对象,batch-size是批次大小,shuffle表示是否打乱样本顺序,num_workers表示使用多线程
#多线程进行,提高运行效率
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
self.sigmoid = torch.nn.Sigmoid()
# self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x
model = Model()
criterion = torch.nn.BCELoss(size_average=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_list = []
if __name__ == '__main__':
for epoch in range(100):
for i, data in enumerate(train_loader, 0):
# 1. 准备数据
inputs, labels = data
# 2. 前馈
y_pred = model(inputs)
loss = criterion(y_pred, labels)
# print(epoch, i, loss.item())
# 3. 反馈
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(epoch, loss.item())
loss_list.append(loss.item())
print(loss_list)
plt.plot(range(100), loss_list)
plt.xlabel('epoch')
plt.ylabel('cost')
plt.show()
参考链接: