scatter_(input, dim, index, src)将src中数据根据index中的索引按照dim的方向填进input中.
1 >>> x = torch.rand(2, 5) 2 >>> x 3 4 0.4319 0.6500 0.4080 0.8760 0.2355 5 0.2609 0.4711 0.8486 0.8573 0.1029 6 [torch.FloatTensor of size 2x5]
1) dim = 0,分别对每列填充:
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) 0.4319 0.4711 0.8486 0.8760 0.2355 0.0000 0.6500 0.0000 0.8573 0.0000 0.2609 0.0000 0.4080 0.0000 0.1029 [torch.FloatTensor of size 3x5]
实现原理:
对于LoneTensor内的矩阵,暂且称为 tmp = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]];将最终的 3*5的矩阵,暂且称为result。result初始为全0,需要经过scatter_处理。
举例:
对于tmp[0][0] = 0 -> 取x中x[0][0] = 0.4319,将其插入到result第

本文介绍了PyTorch中的scatter_()函数,用于根据指定的索引在Tensor中填充数据。scatter_()沿特定维度将源数据src的值放置到输入input对应的位置,根据index的指示修改Tensor。该函数常用于one-hot编码和张量的特定位置赋值。
最低0.47元/天 解锁文章
1562





