RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

PyTorch 执行 MNIST 的源码:
# Import things like usual

%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import torch

import helper

import matplotlib.pyplot as plt
from torchvision import datasets, transforms
# Define a transform to normalize the data.
# NOTE(review): this is the failing version that produces the traceback below.
# MNIST images are single-channel: ToTensor() yields a [1, 28, 28] tensor
# (see the "output with shape [1, 28, 28]" in the error message). Normalize is
# given THREE mean/std values here, so the in-place per-channel subtraction
# (tensor.sub_(mean[:, None, None])) tries to broadcast to [3, 28, 28] and
# raises the RuntimeError. The single-channel fix appears later in this post.
transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                             ])
# Download and load the training data
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

# Fetch one training batch. The transform is applied lazily per sample when
# the batch is built, so the Normalize shape error surfaces on this line,
# not when the dataset is constructed.
# NOTE(review): calling `.next()` on the DataLoader iterator is not supported
# in newer PyTorch releases; the portable spelling is `next(dataiter)`.
# Confirm against the installed version.
dataiter = iter(trainloader)
images, labels = dataiter.next()

#报错如下



RuntimeError                              Traceback (most recent call last)
<ipython-input-22-840309f5aa1d> in <module>
      1 dataiter = iter(trainloader)
----> 2 images, labels = dataiter.next()

C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

C:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in <listcomp>(.0)
    558         if self.num_workers == 0:  # same-process loading
    559             indices = next(self.sample_iter)  # may raise StopIteration
--> 560             batch = self.collate_fn([self.dataset[i] for i in indices])
    561             if self.pin_memory:
    562                 batch = _utils.pin_memory.pin_memory_batch(batch)

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\datasets\mnist.py in __getitem__(self, index)
     93 
     94         if self.transform is not None:
---> 95             img = self.transform(img)
     96 
     97         if self.target_transform is not None:

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, img)
     59     def __call__(self, img):
     60         for t in self.transforms:
---> 61             img = t(img)
     62         return img
     63 

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, tensor)
    162             Tensor: Normalized Tensor image.
    163         """
--> 164         return F.normalize(tensor, self.mean, self.std, self.inplace)
    165 
    166     def __repr__(self):

C:\ProgramData\Anaconda3\lib\site-packages\torchvision\transforms\functional.py in normalize(tensor, mean, std, inplace)
    206     mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
    207     std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
--> 208     tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    209     return tensor
    210 

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

 

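解决办法:

将三通道的标准化参数改为单通道的。MNIST 是灰度图像集,经过 ToTensor() 之后每张图片的张量形状是 [1, 28, 28](即报错信息中的 output shape),只有 1 个通道;而 Normalize 传入了 3 个 mean/std 值,逐通道相减时会被广播成 [3, 28, 28],与原张量形状不符,因此报错。Normalize 的 mean 和 std 应各只给 1 个值,修改如下: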

# Broken 3-channel version, kept for comparison:
# transform = transforms.Compose([transforms.ToTensor(),
#                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
#                              ])

# MNIST is grayscale, so ToTensor() yields a single-channel [1, 28, 28] tensor.
# Normalize therefore takes exactly one mean and one std value (one per channel).
# Note the trailing commas: (0.5,) is a one-element tuple, (0.5) is just a float.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

#解决

https://blog.youkuaiyun.com/weixin_43159148/article/details/88778371

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值