torch.masked_select
# torch.masked_select demo: collect the entries of `x` where `mask` is True
# into a new 1-D tensor, in row-major order. The (3, 1) mask is broadcast
# across x's 4 columns, so every True row contributes all 4 of its values.
import torch

x = torch.randn(3, 4)
print(x)

# One random boolean per row; broadcasting turns it into a whole-row keep/drop.
mask = torch.randn(3, 1) > 0
print(mask)

# Method form of torch.masked_select(x, mask).
s = x.masked_select(mask)
print(s)
输出
tensor([[-2.7601, -0.9957, -0.2239, -1.2674],
[ 1.2190, 1.6782, -0.3635, 0.6803],
[-1.4541, -2.6175, 0.5655, 1.8493]])
tensor([[ True],
[False],
[ True]])
tensor([-2.7601, -0.9957, -0.2239, -1.2674, -1.4541, -2.6175, 0.5655, 1.8493])
# Tensor.masked_scatter_ demo: overwrite, in place, the entries of `x` where
# `mask` is True with values taken from `source`. Source values are consumed
# in row-major order rather than position-for-position. With a (3, 1) mask
# broadcast over 4 columns, the first True row receives source row 0 and the
# second True row receives source row 1.
import torch

x = torch.randn(3, 4)
print(x)

mask = torch.randn(3, 1) > 0
print(mask)

# Shifted by +100 so the scattered-in values are easy to spot in the printout.
source = 100 + torch.randn(3, 4)
print(source)

x.masked_scatter_(mask, source)
print(x)
输出
tensor([[-0.2986, -2.1390, -1.3152, 0.1101],
[-0.3291, -0.1943, 0.5282, -0.9538],
[-0.1157, 1.1725, -1.1309, 1.1864]])
tensor([[ True],
[False],
[ True]])
tensor([[100.2276, 99.5916, 100.3636, 98.6109],
[ 98.6181, 98.8684, 100.3451, 100.8166],
[100.1912, 100.7507, 100.6337, 100.2147]])
tensor([[100.2276, 99.5916, 100.3636, 98.6109],
[ -0.3291, -0.1943, 0.5282, -0.9538],
[ 98.6181, 98.8684, 100.3451, 100.8166]])