RuntimeError: each element in list of batch should be of equal size

RuntimeError: each element in list of batch should be of equal size
自己定义dataset类,返回需要返回的相应数据,发现报了以下错误

RuntimeError: each element in list of batch should be of equal size

百度一下说最直接的方法是吧batch_size的值改为1,报错解除,一试果然。但是我他妈是训练模型来的,不是为了仅仅改错。batch_size设为1还怎么训练模型,因此决定研究一下此错误。

Original Traceback (most recent call last):
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 83, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/home/cv/anaconda3/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 81, in default_collate
    raise RuntimeError('each element in list of batch should be of equal size')

根据报错信息可以查找错误来源在collate.py源码,错误就出现在default_collate()函数中。百度发现此源码的defaul_collate函数是DataLoader类默认的处理batch的方法,如果在定义DataLoader时没有使用collate_fn参数指定函数,就会默认调用以下源码中的方法。如果你出现了上述报错,应该就是此函数中出现了倒数第四行的错误

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int_classes):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))

此函数功能就是传入一个batch数据元组,元组中是每个数据是你定义的dataset类中__getitem__()方法返回的内容,元组长度就是你的batch_size设置的大小。但是DataLoader类中最终返回的可迭代对象的一个字段是将batch_size大小样本的相应字段拼接到一起得到的。因此默认调用此方法时,第一次会进入倒数第二行语句return [default_collate(samples) for samples in transposed]将batch元组通过zip函数生成可迭代对象。然后通过迭代取出相同字段递归重新传入default_collate()函数中,此时取出第一个字段判断数据类型在以上所列类型中,则可正确返回dateset内容。
如果batch数据是按以上顺序进行处理,则不会出现以上错误。如果进行第二次递归之后元素的数据还不在所列数据类型中,则依然会进入下一次也就是第三次递归,此时就算能正常返回数据也不符合我们要求,而且报错一般就是出现在第三次递归及之后。因此想要解决此错误,需要仔细检查自己定义dataset类返回字段的数据类型。也可以在defaule_collate()方法中输出处理前后batch内容,查看函数具体处理流程,以助力自己查找返回字段数据类型的错误。
友情提示:不要在源码文件中更改defaule_collate()方法,可以把此代码copy出来,定义一个自己的collate_fn()函数并在实例化DataLoader类时指定自己定义的collate_fn函数。
祝各位早日解决bug,跑通模型!

### 错误原因分析 在 PyTorch 中,`RuntimeError: each element in list of batch should be of equal size` 的主要原因是数据加载器 (`DataLoader`) 或者 `collate_fn` 函数尝试堆叠张量时发现这些张量的形状不一致[^1]。具体来说,在调用 `torch.stack()` 方法时,所有输入张量必须具有相同的尺寸才能成功执行操作。 此错误通常发生在自定义数据集类中的 `__getitem__` 方法返回的数据维度不同步的情况下。例如,某些图像可能有 RGB 三个通道 `[3, H, W]`,而另一些可能是灰度图只有单个通道 `[1, H, W]`[^3]。如果未处理好这种差异,则会引发上述运行时异常。 ### 解决方案 #### 方案一:标准化输入数据大小 确保所有样本都具备统一的形状是最直接有效的办法之一。可以通过预处理阶段调整每一张图片到固定分辨率以及相同颜色模式来实现这一点: ```python from torchvision import transforms transform = transforms.Compose([ transforms.Resize((448, 448)), # 统一分辨率 transforms.Grayscale(3), # 转换为三通道(即使原始图为灰度) transforms.ToTensor() # 转化成 Tensor 类型 ]) class CustomDataset(torch.utils.data.Dataset): def __init__(self, data_list, transform=None): self.data_list = data_list self.transform = transform def __len__(self): return len(self.data_list) def __getitem__(self, idx): img_path = self.data_list[idx] image = Image.open(img_path).convert('RGB') # 明确指定转换至 RGB 模式 if self.transform is not None: image = self.transform(image) return image ``` 通过以上方式可以有效避免因尺寸或通道数不同的情况引起的问题[^2]。 #### 方案二:自定义 collate_fn 函数 当无法完全控制源数据或者存在特殊情况不允许修改原数据结构的时候,可考虑重写 DataLoader 所使用的默认收集函数(`default_collate`) 。这样可以根据实际需求灵活地处理各种类型的批次数据而不强制要求它们严格匹配特定形式: ```python def custom_collate_fn(batch): imgs = [] labels = [] for item in batch: try: img, label = item imgs.append(transforms.functional.to_tensor(img)) labels.append(label) except Exception as e: print(f"Skipping corrupted example due to error {e}") # 使用 pad_sequence 来填充序列长度不一样的情况 from torch.nn.utils.rnn import pad_sequence padded_imgs = pad_sequence([img.permute(2,0,1) for img in imgs], padding_value=0).permute(1,2,3,0) return padded_imgs, torch.tensor(labels) data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn) ``` 这里展示了如何构建一个简单的自定义 `collate_fn` ,它能够容忍部分损坏的例子并适配变长特征向量的情况。 --- ### 总结 针对该问题的核心在于保证进入模型前后的每一个 mini-batch 数据内部各元素均保持一致的形式;要么提前做好充分准备使得全部实例遵循同一规格标准,亦或是借助定制化的逻辑去动态适应多样性的状况。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值