PyTorch 数据包(torch.utils.data)的使用
PyTorch 提供了 torch.utils.data 包来进行数据的处理。
import numpy as np
import pandas as pd
import torch
import torch.utils.data as Data
# Demonstrates reading a small in-memory dataset in shuffled mini-batches
# with torch.utils.data (TensorDataset + DataLoader).

# Generate toy data: a 6x4 integer table with columns 'a', 's', 'd', 'f'.
a = np.arange(24).reshape(6, 4)
a12 = pd.DataFrame(a, columns=list('asdf'))
# The first three columns are used as features, column 'f' as the label.
list_feature = a12.iloc[:, :3].values
list_labels = a12['f'].values

# Convert the NumPy arrays to tensors and pair them up sample by sample.
data_all = Data.TensorDataset(
    torch.from_numpy(list_feature),
    torch.from_numpy(list_labels),
)

# Mini-batch size.
batch_size = 2

# Read the dataset in randomly shuffled mini-batches.
data_iter = Data.DataLoader(data_all, batch_size=batch_size, shuffle=True)

# Number of training epochs (full passes over the dataset).
num_epochs = 5

# Loop over the mini-batches for each epoch.
for epoch in range(num_epochs):
    for x, y in data_iter:
        print(x, y)
        # Training and related per-batch operations would go here.
    # NOTE(review): the original indentation was lost; the separator is
    # assumed to be printed once per epoch rather than per batch -- confirm.
    print('***' * 12)