自定义DataSet
数字
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import torch
# 自构建数据集
dataset = TensorDataset(torch.arange(1, 40))
dl = DataLoader(dataset, batch_size=10)
# 数据输出
for batch in dl:
print(batch)
图片
from torch.utils.data import Dataset
from PIL import Image #这里用来读取图片数据
import os
class MyData(Dataset): # 这里定义了一个MyData类继承Dataset读取数据
def __init__(self, root_dir, label_dir): #进行初始化数据
self.root_dir = root_dir #比如 "data/train" 相对路径
self.label_dir = label_dir #这里可以是 "ant"
self.path = os.path.join(self.root_dir,self.label_dir) #利用一个函数将其合并为ant标签对应的路径
# os.listdir()会将path路径下对应的所有文件的“文件名”转化成列表中的一个个元素
self.img_path = os.listdir(self.path)
def __getitem__(self, idx): #实现这个方法是为了获取第idx个文件对应的数据和标签的
img_name = self.img_path[idx] #这个列表就可以返回名字呢
img_item_path = os.path.join(self.path,img_name) #将其之前的路径和文件名拼起来就是最终路径了
img = Image.open(img_item_path) #数据读取到img变量里了
label = self.label_dir #标签就是上层文件夹名
return img,label #返回img,label完成重写函数任务
def __len__(self):
return len(self.img_path) #列表的长度就是数据的长度
root_dir = "../../../../../datas/hymenoptera_data/train"
ants_lable_dir = "ants"
bees_lable_dir = "bees"
imageList = MyData(root_dir, ants_lable_dir)
imge, lable = imageList[1]
imge.show()
print(lable)
显示带有地标的图片
from __future__ import print_function, division
import os
import torch
import pandas as pd #用于更容易地进行csv解析
from skimage import io, transform #用于图像的IO和变换
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# 忽略警告
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
# scikit-image:用于图像的IO和变换
# pandas:用于更容易地进行csv解析
# 读取数据
# 将csv中的标注点数据读入(N,2)数组中,其中N是特征点的数量。读取数据代码如下:
landmarks_frame = pd.read_csv('../../../../../datas/faces/face_landmarks.csv')
n = 67
# 获取第n行的图像名称(第0列)
img_name = landmarks_frame.iloc[n, 0]
print('Image name: {}'.format(img_name))
# 获取第n行的地标数据(第1列及之后的所有列),并转换为矩阵
landmarks = landmarks_frame.iloc[n, 1:].to_numpy()
print(landmarks)
# 将地标数据类型转换为浮点型,并重塑为二维数组,每个地标点包含两个值(x和y坐标)
print(landmarks.astype('float'))
landmarks = landmarks.astype('float').reshape(-1, 2)
print(landmarks)
print('Landmarks shape: {}'.format(landmarks.shape))
print('First 4 Landmarks: {}'.format(landmarks[:4]))
def show_landmarks(image, landmarks):
"""显示带有地标的图片"""
# 使用 matplotlib.pyplot 的 imshow 函数显示图像。
plt.imshow(image)
# 使用 scatter 函数在图像上绘制地标点。
# landmarks[:, 0] 和 landmarks[:, 1] 分别表示地标点的 x 坐标和 y 坐标,
# s=10 设置点的大小,marker='.' 设置点的形状为小圆点,c='r' 设置点的颜色为红色。
plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
# 暂停0.001秒,以确保图像更新并显示出来。这在某些环境中是必要的,以确保绘图窗口能够及时刷新。
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
show_landmarks(io.imread(os.path.join('../../../../../datas/faces/', img_name)),landmarks)
plt.show()
dataloader() 方法
上面我们知道了dataset 这个数据集之后,为啥还需要适用datalaoder, 这个时候如果你是从java过来的, 应该知道classloader,需要对内容进行加载,如果不是从java过来也没关系。
dataset 我们可以比作时一副扑克牌,但是我们怎么抓到每个人手里是有dataloader说了算
简单的例子
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import torch
# 自构建数据集
dataset = TensorDataset(torch.arange(1, 40))
r"""
batch_size 一批加载多少个数据
shuffle 随机的意思, 默认是false, 如果是True 每次加载的数据顺序就不一样
drop_last 最后不够batch_size 这么多时,数据是否丢弃
num_workers 表示几个子进程进行数据加载, 0 表示在主现场中加载
"""
dl = DataLoader(dataset, batch_size=9, shuffle = True, drop_last = True, num_workers = 0)
# 数据输出
for batch in dl:
print(batch)
图片加载的例子
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
data_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_data = torchvision.datasets.CIFAR10("../../../../../datas/download",train=True,transform=data_transform,download=True)
test_data = torchvision.datasets.CIFAR10("../../../../../datas/download",train=False,transform=data_transform,download=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=0, drop_last= True)
image, target = test_data[0]
print(image.shape)
print(target)
writer = SummaryWriter("../../../../../datas/03_tensorboard")
step = 0
for data in test_dataloader:
imgs, targets =data
# print(imgs.shape)
# print(targets)
writer.add_images("test_data_loader", imgs, step)
step = step + 1
for epoch in range(2):
step = 0
for data in test_dataloader:
imgs, targets =data
# print(imgs.shape)
# print(targets)
writer.add_images("test_data_loader_epoch_{}".format(epoch), imgs, step)
step = step + 1
writer.close()