def scatter_(self, dim, index, src, reduce=None):
'''
dim: 在哪个维度上进行赋值(操作)
index: 可以理解为标签
src: 具体给什么数值(通常给 1,毕竟是one-hot)
'''
这里只对四维数据分析one-hot,即 (batch_size, num_class, h, w) ------图像------
假设输入的图像是batch_size = 2,且是灰度图像,即数据类型是 (2, 1, h, w)。
import torch
batch_size = 2
class_num = 3
H = 5
W = 5
# 灰度图像,channel == 1
label = torch.randint(0, class_num, size=(batch_size, 1, H, W))
print(label)
one_hot = torch.zeros(batch_size, class_num, H, W)
target = one_hot.scatter_(1, label.type(torch.long), 1)
print(target)
理解(分析):
以第一个batch分析,第一个batch的label是
以这个标签(class=2)即左上角第一个数据(红色框)为例,由于dim=1,所以将src=1赋值给one-hot中索引为 [0, 2, 0, 0] 这个位置 (前面第一个位置 0是 batch_size索引,由于这里选择第一个batch,所以索引是0,第二个位置2是标签索引,也就是class=2,后面的h和w是原来标签所在的h和w位置)。同理,绿色框的数据和蓝色框的数据)。