Pytorch的tensor的view方法
相当于numpy中resize()的功能,但是用法可能不太一样。
把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。比如说是不管你原先的数据是[[[1,2,3],[4,5,6]]]还是[1,2,3,4,5,6],因为它们排成一维向量都是6个元素,所以只要view后面的参数一致,得到的结果都是一样
X.view(x.size(0),-1) ,-1为推断是几(即一维X的len除以x.size(0))。但有的机器推断不出来。
这句话一般出现在model类的forward函数中,具体位置一般都是在调用分类器之前。分类器是一个简单的nn.Linear()结构,输入输出都是维度为一的值,x = x.view(x.size(0), -1) 这句话的出现就是为了将前面多维度的tensor展平成一维。下面是个简单的例子,我将会根据例子来对该语句进行解析。
class NET(nn.Module):
def __init__(self,batch_size):
super(NET,self).__init__()
self.conv = nn.Conv2d(outchannels=3