scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中。
import torch
import numpy as np
def one_hot(labels):
_labels = torch.zeros([2,2,4])
_labels.scatter_(dim=0, index=labels.long(), value=1)#scatter_(input, dim, index, src)
return _labels
temp = torch.Tensor([[[1,1,0,0],
[1,1,0,0]]])
target2 = one_hot(temp)
print(target2)
输出:channel0是背景,channel1是前景。
tensor([[[0., 0., 1., 1.],
[0., 0., 1., 1.]],
[[1., 1., 0., 0.],
[1., 1., 0., 0.]]])