torch.as_strided(input, size, stride, storage_offset=0)→ Tensor
1. 官方文档
Create a view of an existing torch.Tensor input with specified size, stride and storage_offset.
parameters
- input (Tensor) – the input tensor.
- size (tuple or ints) – the shape of the output tensor
- stride (tuple or ints) – the stride of the output tensor
- storage_offset (int, optional) – the offset in the underlying storage of the output tensor
WARNING
More than one element of a created tensor may refer to a single memory location. As a result, in-place operations (especially ones that are vectorized) may result in incorrect behavior. If you need to write to the tensors, please clone them first.
Many PyTorch functions, which return a view of a tensor, are internally implemented with this function. Those functions, like
torch.Tensor.expand(), are easier to read and are therefore more advisable to use.
e.g.
>>> x = torch.randn(3, 3)
>>> x
tensor([[ 0.9039, 0.6291, 1.0795],
[ 0.1586, 2.1939, -0.4900],
[-0.1909, -0.7503, 1.9355]])
>>> t = torch.as_strided(x, (2, 2), (1, 2))
>>> t
tensor([[0.9039, 1.0795],
[0.6291, 0.1586]])
>>> t = torch.as_strided(x, (2, 2), (1, 2), 1)
tensor([[0.6291, 0.1586],
[1.0795, 2.1939]])
2. 个人理解
基于指定的参数(size、stride、offset)在原内存进行扩展生成的Tensor。因为按内存顺序读元素,会造成元素共享,因此如果对此过程生成的Tensor直接in_place修改可能会出现意想不到的错误,建议使用深拷贝。
关于其原理,不要对拘泥于原数据的行列关系,按照对应的stride从内存中读数即可。如下图所示,从4x4的源数据中按要求抽取对应的元素。
注意,这里起点为1,W,H方向上都是基于该起点进行。如果加了offset,则是将起点进行offset,见后续代码测试结果。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8eiIiw8s-1633252597768)

测试代码:
import torch
a = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]])
b = torch.as_strided(a, (2, 2), (0, 0))
c = torch.as_strided(a, (2, 2), (0, 1))
d = torch.as_strided(a, (2, 2), (1, 0))
e = torch.as_strided(a, (2, 2), (2, 0))
f = torch.as_strided(a, (2, 2), (0, 2))
g = torch.as_strided(a, (2, 2), (1, 1))
h = torch.as_strided(a, (2, 2), (2, 1))
i = torch.as_strided(a, (2, 2), (1, 2))
j = torch.as_strided(a, (2, 2), (2, 1), 1)
k = torch.as_strided(a, (2, 2), (1, 2), 1)
print(' >> input data: {}'.format(a))
print(' >> Test stride: (0, 0), result: {}'.format(b))
print(' >> Test stride: (0, 1), result: {}'.format(c))
print(' >> Test stride: (1, 0), result: {}'.format(d))
print(' >> Test stride: (2, 0), result: {}'.format(e))
print(' >> Test stride: (0, 2), result: {}'.format(f))
print(' >> Test stride: (1, 1), result: {}'.format(g))
print(' >> Test stride: (2, 1), result: {}'.format(h))
print(' >> Test stride: (1, 2), result: {}'.format(i))
print(' >> Test stride: (2, 1), offset: 1, result: {}'.format(h))
print(' >> Test stride: (1, 2), offset: 1, result: {}'.format(i))
测试结果:
>> input data: tensor([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12],
[13, 14, 15, 16]])
>> Test stride: (0, 0), result: tensor([[1, 1],
[1, 1]])
>> Test stride: (0, 1), result: tensor([[1, 2],
[1, 2]])
>> Test stride: (1, 0), result: tensor([[1, 1],
[2, 2]])
>> Test stride: (2, 0), result: tensor([[1, 1],
[3, 3]])
>> Test stride: (0, 2), result: tensor([[1, 3],
[1, 3]])
>> Test stride: (1, 1), result: tensor([[1, 2],
[2, 3]])
>> Test stride: (2, 1), result: tensor([[1, 2],
[3, 4]])
>> Test stride: (1, 2), result: tensor([[1, 3],
[2, 4]])
>> Test stride: (2, 1), offset: 1, result: tensor([[2, 3],
[4, 5]])
>> Test stride: (1, 2), offset: 1, result: tensor([[2, 4],
[3, 5]])
本文深入介绍了PyTorch中的torch.as_strided函数,用于创建Tensor视图并指定大小、步长和存储偏移。由于生成的Tensor元素可能共享内存,直接修改可能导致错误,建议先深拷贝。通过示例展示了不同参数组合下的结果,揭示了该函数的工作原理和潜在风险。
335

被折叠的 条评论
为什么被折叠?



