报错信息详细为:
Traceback (most recent call last):
File "tools/train.py", line 104, in <module>
main()
File "tools/train.py", line 100, in main
runner.train()
File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 296, in run
self.runner.val_loop.run()
File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 379, in run
self.run_iter(idx, data_batch)
File "/root/miniconda3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/root/miniconda3/lib/python3.8/site-packages/mmengine/runner/loops.py", line 404, in run_iter
outputs = self.runner.model.val_step(data_batch)
File "/root/miniconda3/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 132, in val_step
data = self.data_preprocessor(data, False)
File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/root/autodl-tmp/mmsegmentation/mmseg/models/data_preprocessor.py", line 136, in forward
assert all(input_.shape[1:] == img_size for input_ in inputs), \
AssertionError: The image size in a batch should be the same.
从上面给出的错误信息可知,在验证阶段,代码碰到了一个断言错误(AssertionError
),具体是在 mmseg/models/data_preprocessor.py
文件的第 136 行。此错误表明在一个批次(batch)里,图像的尺寸不一致,而代码要求批次内所有图像的尺寸必须相同。
错误原因分析
在深度学习训练过程中,模型一般要求输入的图像尺寸是一致的,这样才能高效地进行批量处理。当一个批次里的图像尺寸不一样时,就会触发这个断言错误。
解决办法
1. 数据预处理阶段统一图像尺寸
在数据加载和预处理时,要保证所有图像的尺寸相同。可以借助图像缩放操作达成这一目的。以下是一个示例,展示如何在 mmsegmentation
里使用 Resize
变换来统一图像尺寸:
from mmseg.datasets import build_dataset
from mmseg.datasets.pipelines import Compose
# 定义数据预处理管道
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512), keep_ratio=False), # 统一图像尺寸
dict(type='RandomFlip', prob=0.5),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', img_scale=(512, 512), keep_ratio=False), # 统一图像尺寸
dict(type='PackSegInputs')
]
# 构建数据集
train_dataset = build_dataset(dict(
type='YourDatasetName',
data_root='your_data_root',
pipeline=train_pipeline,
test_mode=False
))
val_dataset = build_dataset(dict(
type='YourDatasetName',
data_root='your_data_root',
pipeline=test_pipeline,
test_mode=True
))
2. 检查数据加载过程
要确保数据加载过程中没有引入尺寸不一致的图像。可能是数据集中存在损坏或者尺寸异常的图像,我们可以编写一个脚本来检查数据集里所有图像的尺寸:
import os
from PIL import Image
data_root = 'your_data_root'
image_dir = os.path.join(data_root, 'images')
for filename in os.listdir(image_dir):
if filename.endswith(('.png', '.jpg', '.jpeg')):
image_path = os.path.join(image_dir, filename)
image = Image.open(image_path)
width, height = image.size
print(f'{filename}: {width}x{height}')
3. 调整批量大小
要是数据集里的图像尺寸差异很大,可以尝试减小批量大小,这样可以降低出现尺寸不一致的概率。在配置文件中修改 batch_size
参数:
train_dataloader = dict(
batch_size=4, # 减小批量大小
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=train_dataset
)
val_dataloader = dict(
batch_size=4, # 减小批量大小
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=val_dataset
)
通过上述方法,应该能够解决批次内图像尺寸不一致的问题。
在mmsegmentation
中,当你对图片进行Resize
操作时,通常也需要对对应的掩码图(mask)进行相同的Resize
操作,以保证图片和掩码图的尺寸始终一致。mmsegmentation
提供了数据预处理管道(pipeline),可以方便地实现这一过程。以下为你详细介绍具体的实现方法。
4. 使用Resize
变换
在数据预处理管道中使用Resize
变换时,它会同时对图片和掩码图进行缩放操作。下面是一个示例配置:
from mmseg.datasets.pipelines import Compose
# 定义数据预处理管道
train_pipeline = [
dict(type='LoadImageFromFile'), # 加载图片
dict(type='LoadAnnotations'), # 加载掩码图
dict(type='Resize', img_scale=(512, 512), keep_ratio=False), # 同时对图片和掩码图进行缩放
dict(type='RandomFlip', prob=0.5), # 随机翻转
dict(type='PackSegInputs') # 打包数据
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(512, 512), keep_ratio=False),
dict(type='PackSegInputs')
]
在上述代码中,Resize
变换会将输入的图片和对应的掩码图都缩放为指定的尺寸(512, 512)
。keep_ratio=False
表示不保持图片的原始宽高比,直接将图片和掩码图缩放到指定尺寸。
5. 构建数据集
使用上述定义的预处理管道来构建训练集和验证集:
from mmseg.datasets import build_dataset
# 构建训练集
train_dataset = build_dataset(dict(
type='YourDatasetName', # 替换为实际的数据集名称
data_root='your_data_root', # 替换为实际的数据根目录
pipeline=train_pipeline,
test_mode=False
))
# 构建验证集
val_dataset = build_dataset(dict(
type='YourDatasetName',
data_root='your_data_root',
pipeline=test_pipeline,
test_mode=True
))
6. 原理说明
mmsegmentation
的数据预处理管道是按照顺序依次执行每个变换操作的。当执行到Resize
变换时,它会根据输入的图片和掩码图的元数据(metadata),对它们同时进行缩放操作。具体来说,Resize
变换会调用相应的图像处理库(如OpenCV
)来实现缩放功能,并且会确保图片和掩码图的缩放比例一致,从而保证它们的尺寸始终匹配。
7. 注意事项
keep_ratio
参数:如果keep_ratio=True
,图片和掩码图会在保持原始宽高比的前提下进行缩放,可能会在边缘填充空白区域。因此,需要根据具体需求合理设置该参数。- 插值方法:
Resize
变换默认使用双线性插值方法对图片进行缩放,对于掩码图,通常使用最近邻插值方法,以避免引入新的类别标签。在mmsegmentation
中,这些默认设置已经处理好了,一般不需要额外修改。
通过以上步骤,可以在mmsegmentation
中方便地对图片和掩码图进行相同的Resize
操作。