1.torch.numel() 返回tensor变量内所有元素的个数,也可以简单理解为矩阵内yu元素的个数
例如,a的size为([64, 3, 7, 7]),那么a.numel() 返回值为64*3*7*7=9408
2.torch.squeeze() 将输入张量形状中的1去除并返回,如果输入是形如(Ax1xBx1xCx1xD),那么输出形状就为(AxBxCxD)

3.torch.unsqueeze() 在指定维度上增加一个维度,该维度数为1

4.torch.sort() 返回元组,第一个是pa排序后的值,第二个是排序后的值在元数据中的index,如果不给dim,默认
按照最后一维

5.torch.mean() 如果不给维度,则fa返回所有值的均值

6.torch.gather(input, dim, index) 把输入按照给定的维度和index取input中的值

7.torch.stack(sequence,dim,out) stack()会增加一个维度出来,sequence是tensor列表,dim表示拼接的维度,而torch.cat()是在已有的维度上拼接,




本文详细介绍了PyTorch中的一些基本操作,包括torch.numel()用于计算tensor元素总数,torch.squeeze()和torch.unsqueeze()用于调整张量维度,torch.sort()进行排序,torch.mean()计算平均值,torch.gather()按指定维度和索引选取值,以及torch.stack()和torch.cat()的使用方法。
1011

被折叠的 条评论
为什么被折叠?



