PyTorch 运行 MNIST 示例源码（Normalize 通道数不匹配报错及解决）
# Import things like usual
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import numpy as np
import torch
import helper
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Define a transform to normalize the data.
# NOTE: this is the transform that triggers the RuntimeError shown in the
# traceback below. Normalize receives three mean values and three std values
# (one per RGB channel), but the MNIST images here are single-channel:
# ToTensor() yields a tensor of shape [1, 28, 28], and the in-place
# sub_/div_ inside Normalize cannot write a result broadcast to
# [3, 28, 28] back into it. Passing one value each, i.e.
# transforms.Normalize((0.5,), (0.5,)), avoids the mismatch.
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Download and load the training split into 'MNIST_data/', served in
# shuffled batches of 64 samples.
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test split with the same transform.
# NOTE(review): shuffling is usually unnecessary for evaluation data.
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
# Fetch one batch from the training loader. The transform is applied per
# sample (in the dataset's __getitem__) when the batch is built, so the
# Normalize error surfaces here rather than when the dataset is constructed.
# NOTE(review): recent PyTorch releases may not provide `.next()` on
# DataLoader iterators; `next(dataiter)` is the portable spelling. Confirm
# against the installed version.
dataiter = iter(trainloader)
images, labels = dataiter.next()
# 运行上面的代码后报错如下：
RuntimeError Traceback (most recent call last)
<ipython-input-22-840309f5aa1d> in <module>
1 dataiter = iter(trainloader)
----> 2 images, labels = dataiter.next()
C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
558 if self.num_workers == 0: # same-process loading
559 indices = next(self.sample_iter) # may raise StopIteration
--> 560 batch = self.collate_fn([self.dataset[i] for i in indices])
561 if self.pin_memory:
562 batch = _utils.pin_memory.pin_memory_batch(batch)
C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in <listcomp>(.0)
558 if self.num_workers == 0: # same-process loading
559 indices = next(self.sample_iter) # may raise StopIteration
--> 560 batch = self.collate_fn([self.dataset[i] for i in indices])
561 if self.pin_memory:
562 batch = _utils.pin_memory.pin_memory_batch(batch)
C:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
93
94 if self.transform is not None:
---> 95 img = self.transform(img)
96
97 if self.target_transform is not None:
C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
59 def __call__(self, img):
60 for t in self.transforms:
---> 61 img = t(img)
62 return img
63
C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, tensor)
162 Tensor: Normalized Tensor image.
163 """
--> 164 return F.normalize(tensor, self.mean, self.std, self.inplace)
165
166 def __repr__(self):
C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
206 mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
207 std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
--> 208 tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
209 return tensor
210
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
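解决办法：
MNIST 中的图片是单通道（灰度）图像，经过 ToTensor() 后张量形状为 [1, 28, 28]；而上面的 Normalize 传入了 3 个均值和 3 个标准差（对应 RGB 三通道），会按 [3, 28, 28] 进行广播，与输入张量的形状不匹配，因此报错。只需把均值和标准差都改为单通道的一个值即可，如下：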
# Single-channel normalization for MNIST (grayscale images).
# The earlier three-channel version,
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# fails because ToTensor() produces a [1, 28, 28] tensor, which cannot be
# normalized in place against three per-channel mean/std values.
# Normalize computes (x - mean) / std per channel; with mean = std = 0.5 and
# ToTensor's [0, 1] output, pixel values end up in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# 问题解决。参考来源：
https://blog.youkuaiyun.com/weixin_43159148/article/details/88778371