pytorch-custom dataset

import torch

import torch.utils.data as data


1.

 class CustomDataset(data.Dataset):

__init__(self, transform)

__len__(self)

__getitem__(self, index): 

    transform(data)


loader = data.DataLoader(dataset=torch_dataset, batch_size= bz, shuffle=True/False, num_workers=3)


for epoch in range(5):

    for step, (batch_x, batch_y) in enumerate(loader):

          pass


transform:


transform = transforms.Compose([transforms.Scale(40), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()])

transforms.Normalize((0.1307,), (0.3081,))、


    mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]











----------------------------------------------------------reference---------------------

1. https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py

2. http://blog.youkuaiyun.com/u012436149/article/details/69061711

### 定义和使用 `Dataset` 类 为了在 PyTorch 中定义自定义的数据集类,通常继承 `torch.utils.data.Dataset` 并实现两个方法:`__len__()` 和 `__getitem__()`. 这种方式允许灵活地加载各种形式的数据。 对于图像分类任务中的 CIFAR-10 数据集,可以利用内置的 `datasets.CIFAR10` 来简化流程[^1]. 不过,当面对更复杂的情况或特定需求时,则需创建自己的数据集类. 下面是一个简单的例子展示如何构建一个用于训练神经网络模型的自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader from torchvision import transforms, datasets class CustomImageDataset(Dataset): """Custom dataset for images.""" def __init__(self, annotations_file, img_dir, transform=None, target_transform=None): """ Args: annotations_file (string): Path to the csv file with annotations. img_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on an image. target_transform (callable, optional): Optional transform to be applied on a label. """ self.img_labels = pd.read_csv(annotations_file) self.img_dir = img_dir self.transform = transform self.target_transform = target_transform def __len__(self): return len(self.img_labels) def __getitem__(self, idx): img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0]) image = read_image(img_path) label = self.img_labels.iloc[idx, 1] if self.transform: image = self.transform(image) if self.target_transform: label = self.target_transform(label) sample = {"image": image, "label": label} return sample ``` 上述代码展示了如何通过读取 CSV 文件获取图片路径及其对应的标签,并应用转换操作以准备输入给模型使用的张量. 此外,在实例化此类对象之后还可以将其传递给 `DataLoader`, 实现批量处理等功能: ```python dataset = CustomImageDataset('annotations.csv', 'img_folder') dataloader = DataLoader(dataset, batch_size=4, shuffle=True) ``` 关于将 Tensor 转换为 Python 基本类型的变量,如果遇到单元素 Tensors 的情况可以直接调用 `.item()` 方法获得其数值表示[^2]. 最后值得注意的是,在定义线性层(`Linear`)时不显式指定权重矩阵是因为这些参数已经被封装到了模块内部并自动初始化了合适的尺寸[^3].
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值