RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1,3,840,840]
The above error appeared while training a semantic segmentation model. After a lot of searching, the explanations below were the only reliable ones I found. The cause, roughly, is that the label image must be a single-channel (grayscale) class-index map, but the labels being fed in were still 3-channel RGB images that had never been converted. Links to the reliable explanations:
RuntimeError: 1only batches of spatial targets supported (3D tensors) but got targets of size: : [1, 3, 96, 128] - vision - PyTorch Forums
https://discuss.pytorch.org/t/runtimeerror-1only-batches-of-spatial-targets-supported-3d-tensors-but-got-targets-of-size-1-3-96-128/95030
Training Semantic Segmentation - #3 by WeiQin_Chuah - vision - PyTorch Forums
https://discuss.pytorch.org/t/training-semantic-segmentation/49275/3
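To get targets into the shape the loss expects, each RGB colour in the mask has to be collapsed into a single class index, giving one 2-D (H, W) map per image ([N, H, W] per batch). A minimal numpy-only sketch of that conversion, assuming a hypothetical 4-colour palette (the colours and class ids here are made up for illustration):

```python
import numpy as np

# Hypothetical palette: map each RGB colour in the mask to a class index.
PALETTE = {(0, 0, 0): 0, (255, 0, 0): 1, (0, 255, 0): 2, (0, 0, 255): 3}

def rgb_mask_to_class_map(mask):
    """Collapse an (H, W, 3) RGB mask into a 2-D (H, W) class-index map."""
    out = np.zeros(mask.shape[:2], dtype=np.int64)
    for colour, cls in PALETTE.items():
        out[np.all(mask == np.array(colour, dtype=mask.dtype), axis=-1)] = cls
    return out

mask = np.zeros((2, 2, 3), dtype=np.uint8)  # all-black mask -> class 0
mask[0, 0] = (255, 0, 0)                    # one red pixel  -> class 1
print(rgb_mask_to_class_map(mask))          # [[1 0]
                                            #  [0 0]]
print(rgb_mask_to_class_map(mask).shape)    # (2, 2) -- 2-D, as the loss expects
```

With targets in this form, a batch stacks to [N, H, W] (long dtype), which matches what nn.CrossEntropyLoss expects alongside [N, C, H, W] logits.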
/tmp/pip-req-build-xlj_h8ax/aten/src/THCUNN/SpatialClassNLLCriterion.cu:106: cunn_SpatialClassNLLCriterion_updateOutput_kernel: block: [4,0,0], thread: [417,0,0] Assertion `t >= 0 && t < n_classes` failed.
RuntimeError: CUDA error: device-side assert triggered
The error above occurred because training used 4 classes, but the labels had been resized with cv2.resize using the default cv2.INTER_LINEAR interpolation, which blended in values that do not correspond to any class. Switching to nearest-neighbour interpolation, cv2.INTER_NEAREST, fixes it. The conversion statement is:
cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)
# Do not shorten this to cv2.resize(image, (tw, th), cv2.INTER_NEAREST): the third
# positional argument of cv2.resize is dst, not interpolation, so the default
# cv2.INTER_LINEAR would still be used
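The effect is easy to reproduce without OpenCV: linear interpolation averages neighbouring class ids and manufactures fractional "classes", while nearest-neighbour only ever copies existing ids. A small numpy sketch of the same idea on a toy 1-D row of labels:

```python
import numpy as np

# Toy 1-D "label row" containing only class ids 0 and 3.
row = np.array([0, 0, 3, 3], dtype=np.float64)

# Upsample 2x with linear interpolation (what cv2.INTER_LINEAR does):
x_new = np.linspace(0, len(row) - 1, 8)
linear = np.interp(x_new, np.arange(len(row)), row)

# Upsample 2x with nearest-neighbour (what cv2.INTER_NEAREST does):
nearest = row[np.round(x_new).astype(int)]

print(sorted(set(linear)))   # contains fractional values -- invalid class ids
print(sorted(set(nearest)))  # only 0.0 and 3.0 -- the original class ids
```

Feeding those fractional values to the loss as class indices is exactly what triggers the `t >= 0 && t < n_classes` assertion.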
Reference:
Is there anybody happen this error? - #14 by XiaoAHeng - autograd - PyTorch Forums
https://discuss.pytorch.org/t/is-there-anybody-happen-this-error/17416/14

Because PIL.Image and cv2 store images differently after loading (an Image's size is (width, height), while a cv2 array's shape is (height, width)), I clumsily rewrote all of my image-processing code (which was completely unnecessary). The methods below differ the most between the two libraries:
# Crop a region of size (th, tw) whose top-left corner is at (x, y)
label = label[y:y+th, x:x+tw]                    # cv2: index as [row, col] = [y, x]
label = label.crop((x, y, x + tw, y + th))       # PIL: box is (left, upper, right, lower)
# Horizontal (left-right) flip; both return a new image, so assign the result
image = cv2.flip(image, 1)                       # cv2
image = image.transpose(Image.FLIP_LEFT_RIGHT)   # PIL (FLIP_LEFT_RIGHT == 0)
# Resize; pay special attention to matching the interpolation modes
image = cv2.resize(image, (tw, th), interpolation=cv2.INTER_NEAREST)  # cv2
image = image.resize((tw, th), resample=Image.NEAREST)                # PIL
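The ordering difference is worth pinning down once: for the same picture, a cv2/numpy array reports (height, width[, channels]) via .shape, while PIL's Image.size reports (width, height). A numpy-only illustration of the cv2 side (the PIL half is noted in comments, since Pillow may not be installed):

```python
import numpy as np

img = np.zeros((2, 4, 3), dtype=np.uint8)   # a 2-pixel-tall, 4-pixel-wide image
h, w = img.shape[:2]                        # cv2/numpy order: (height, width)
print((h, w))                               # (2, 4)
# For the very same picture, PIL's Image.size is (width, height) -> (4, 2),
# and cv2.resize's dsize argument also takes (width, height), i.e. (tw, th).
```

This is why cv2.resize above is called with (tw, th) even though the array itself is indexed (h, w): dsize follows PIL-style (width, height) ordering.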