import torch
import torch.utils.data as data
1.
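class CustomDataset(data.Dataset):
    __init__(self, transform)       # store the transform to apply to each sample
    __len__(self)                   # return the number of samples in the dataset
    __getitem__(self, index):       # fetch the sample at `index` and return it transformed
        transform(data)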
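# torch_dataset: a Dataset instance (e.g. CustomDataset(...)); bz: the batch size
loader = data.DataLoader(dataset=torch_dataset, batch_size=bz, shuffle=True, num_workers=3)  # shuffle=True reshuffles every epoch; use False to keep a fixed order
for epoch in range(5):
    for step, (batch_x, batch_y) in enumerate(loader):
        pass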
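transform (torchvision.transforms):
transform = transforms.Compose([transforms.Resize(40), transforms.RandomHorizontalFlip(), transforms.RandomCrop(32), transforms.ToTensor()])  # Resize replaces the deprecated/removed transforms.Scale
transforms.Normalize((0.1307,), (0.3081,))                                    # MNIST (single channel) mean/std
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])   # ImageNet (RGB) mean/std
# Normalize operates on tensors, so place it after ToTensor() inside Compose.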
----------------------------------------------------------reference---------------------
1. https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
2. http://blog.csdn.net/u012436149/article/details/69061711