索引与切片
Indexing
>>> a=torch.rand(4,3,28,28)
>>> a[0].shape
torch.Size([3, 28, 28])
>>> a[0,0].shape
torch.Size([28, 28])
>>> a[0,0,2,4]
tensor(0.5978)
select first/last N
>>> a.shape
torch.Size([4, 3, 28, 28])
>>> a[:2].shape
torch.Size([2, 3, 28, 28])
>>> a[:2,:1,:,:].shape
torch.Size([2, 1, 28, 28])
>>> a[:2,1:,:,:].shape
torch.Size([2, 2, 28, 28])
>>> a[:2,-1:,:,:].shape
torch.Size([2, 1, 28, 28])
select by step (start:end:step slicing)
>>> a[:,:,0:28:2,0:28:2].shape
torch.Size([4, 3, 14, 14])
>>> a[:,:,::2,::2].shape
torch.Size([4, 3, 14, 14])
select by specific index
>>> a.shape
torch.Size([4, 3, 28, 28])
>>> a.index_select(0,torch.tensor([0,2])).shape
torch.Size([2, 3, 28, 28])
>>> a.index_select(1,torch.tensor([1,2])).shape
torch.Size([4, 2, 28, 28])
>>> a.index_select(2,torch.arange(28)).shape
torch.Size([4, 3, 28, 28])
>>> a.index_select(2,torch.arange(8)).shape
torch.Size([4, 3, 8, 28])
select with ... (Ellipsis stands for all remaining dimensions)
>>> a[...].shape
torch.Size([4, 3, 28, 28])
>>> a[0,...].shape
torch.Size([3, 28, 28])
>>> a[:,1,...].shape
torch.Size([4, 28, 28])
>>> a[...,:2].shape
torch.Size([4, 3, 28, 2])
select by mask (masked_select returns a flattened 1-D tensor)
>>> x=torch.randn(3,4)
>>> x
tensor([[-0.0417, -0.7808,  0.4916, -0.7928],
        [ 0.1324,  1.8244, -0.7456,  0.4702],
        [-0.4376,  1.6554, -0.3548,  0.9621]])
>>> mask=x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
        [False,  True, False, False],
        [False,  True, False,  True]])
>>> torch.masked_select(x,mask)
tensor([1.8244, 1.6554, 0.9621])
>>> torch.masked_select(x,mask).shape
torch.Size([3])
select by flattened index (torch.take treats the input as a 1-D, row-major flattened tensor)
>>> src=torch.tensor([[4,3,5],[6,7,8]])
>>> src
tensor([[4, 3, 5],
        [6, 7, 8]])
>>> torch.take(src,torch.tensor([0,2,5]))
tensor([4, 5, 8])