torch.scatter_()函数主要用于生成one_hot编码
torch.scatter_(dim, index, src):
dim:来确定是按照行还是列进行映射,dim=0=列,dim=1=行
index:映射按照index的元素进行
src:需要映射的值
dim=0 列映射:
实例:
对于scatter_()函数接受三个参数:维度,映射索引,映射数值
a=torch.zeros(3,5)生成三行五列元素均为0的张量
index=torch.tensor([ [0, 1, 2, 0, 0], [2, 0, 0, 1, 2] ])映射索引
b=torch.rand(2,5)映射数值
a.scatter_(dim=0,index,b)
split()以空格进行切割
strip([char])用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。