【mmsegmentation】报错解决:AssertionError: The image size in a batch should be the same.

报错信息详细为:

Traceback (most recent call last):
  File "tools/train.py", line 104, in <module>
    main()
  File "tools/train.py", line 100, in main
    runner.train()
  File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 296, in run
    self.runner.val_loop.run()
  File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 379, in run
    self.run_iter(idx, data_batch)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 404, in run_iter
    outputs = self.runner.model.val_step(data_batch)
  File "/root/miniconda3/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 132, in val_step
    data = self.data_preprocessor(data, False)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/autodl-tmp/mmsegmentation/mmseg/models/data_preprocessor.py", line 136, in forward
    assert all(input_.shape[1:] == img_size for input_ in inputs),  \
AssertionError: The image size in a batch should be the same.

从上面给出的错误信息可知,在验证阶段,代码碰到了一个断言错误(AssertionError),具体是在 mmseg/models/data_preprocessor.py 文件的第 136 行。此错误表明在一个批次(batch)里,图像的尺寸不一致,而代码要求批次内所有图像的尺寸必须相同。

错误原因分析

在深度学习训练过程中,模型一般要求输入的图像尺寸是一致的,这样才能高效地进行批量处理。当一个批次里的图像尺寸不一样时,就会触发这个断言错误。

解决办法

1. 数据预处理阶段统一图像尺寸

在数据加载和预处理时,要保证所有图像的尺寸相同。可以借助图像缩放操作达成这一目的。以下是一个示例,展示如何在 mmsegmentation 里使用 Resize 变换来统一图像尺寸:

from mmseg.datasets import build_dataset
from mmseg.datasets.pipelines import Compose

# 定义数据预处理管道
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=False),  # 统一图像尺寸
    dict(type='RandomFlip', prob=0.5),
    dict(type='PackSegInputs')
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=False),  # 统一图像尺寸
    dict(type='PackSegInputs')
]

# 构建数据集
train_dataset = build_dataset(dict(
    type='YourDatasetName',
    data_root='your_data_root',
    pipeline=train_pipeline,
    test_mode=False
))

val_dataset = build_dataset(dict(
    type='YourDatasetName',
    data_root='your_data_root',
    pipeline=test_pipeline,
    test_mode=True
))
2. 检查数据加载过程

要确保数据加载过程中没有引入尺寸不一致的图像。可能是数据集中存在损坏或者尺寸异常的图像,我们可以编写一个脚本来检查数据集里所有图像的尺寸:

import os
from PIL import Image

data_root = 'your_data_root'
image_dir = os.path.join(data_root, 'images')

for filename in os.listdir(image_dir):
    if filename.endswith(('.png', '.jpg', '.jpeg')):
        image_path = os.path.join(image_dir, filename)
        image = Image.open(image_path)
        width, height = image.size
        print(f'{filename}: {width}x{height}')
3. 调整批量大小

要是数据集里的图像尺寸差异很大,可以尝试减小批量大小,这样可以降低出现尺寸不一致的概率。在配置文件中修改 batch_size 参数:

train_dataloader = dict(
    batch_size=4,  # 减小批量大小
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=train_dataset
)

val_dataloader = dict(
    batch_size=4,  # 减小批量大小
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=val_dataset
)

通过上述方法,应该能够解决批次内图像尺寸不一致的问题。

mmsegmentation中,当你对图片进行Resize操作时,通常也需要对对应的掩码图(mask)进行相同的Resize操作,以保证图片和掩码图的尺寸始终一致。mmsegmentation提供了数据预处理管道(pipeline),可以方便地实现这一过程。以下为你详细介绍具体的实现方法。

4. 使用Resize变换

在数据预处理管道中使用Resize变换时,它会同时对图片和掩码图进行缩放操作。下面是一个示例配置:

from mmseg.datasets.pipelines import Compose

# 定义数据预处理管道
train_pipeline = [
    dict(type='LoadImageFromFile'),  # 加载图片
    dict(type='LoadAnnotations'),  # 加载掩码图
    dict(type='Resize', img_scale=(512, 512), keep_ratio=False),  # 同时对图片和掩码图进行缩放
    dict(type='RandomFlip', prob=0.5),  # 随机翻转
    dict(type='PackSegInputs')  # 打包数据
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
    dict(type='PackSegInputs')
]

在上述代码中,Resize变换会将输入的图片和对应的掩码图都缩放为指定的尺寸(512, 512)keep_ratio=False表示不保持图片的原始宽高比,直接将图片和掩码图缩放到指定尺寸。

5. 构建数据集

使用上述定义的预处理管道来构建训练集和验证集:

from mmseg.datasets import build_dataset

# 构建训练集
train_dataset = build_dataset(dict(
    type='YourDatasetName',  # 替换为实际的数据集名称
    data_root='your_data_root',  # 替换为实际的数据根目录
    pipeline=train_pipeline,
    test_mode=False
))

# 构建验证集
val_dataset = build_dataset(dict(
    type='YourDatasetName',
    data_root='your_data_root',
    pipeline=test_pipeline,
    test_mode=True
))
6. 原理说明

mmsegmentation的数据预处理管道是按照顺序依次执行每个变换操作的。当执行到Resize变换时,它会根据输入的图片和掩码图的元数据(metadata),对它们同时进行缩放操作。具体来说,Resize变换会调用相应的图像处理库(如OpenCV)来实现缩放功能,并且会确保图片和掩码图的缩放比例一致,从而保证它们的尺寸始终匹配。

7. 注意事项
  • keep_ratio参数:如果keep_ratio=True,图片和掩码图会在保持原始宽高比的前提下进行缩放,可能会在边缘填充空白区域。因此,需要根据具体需求合理设置该参数。
  • 插值方法Resize变换默认使用双线性插值方法对图片进行缩放,对于掩码图,通常使用最近邻插值方法,以避免引入新的类别标签。在mmsegmentation中,这些默认设置已经处理好了,一般不需要额外修改。

通过以上步骤,可以在mmsegmentation中方便地对图片和掩码图进行相同的Resize操作。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旅途中的宽~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值