Tensor.size(), Tensor.view(), Tensor.add_()

1. Tensor.size() essentially returns a tuple

from __future__ import print_function
import torch

x = torch.Tensor(5, 3)  # uninitialized 5x3 tensor (torch.empty in modern PyTorch)
print(x.size()[0])  # Print the row dimension
print(x.size()[1])  # Print the column dimension

Output:
5
3
Note: the value returned by size() is a torch.Size, which behaves like a tuple, so the basic tuple operations can be used to read its entries.
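Because torch.Size is a subclass of Python's tuple, operations such as unpacking and len() work on it directly. A minimal sketch (the variable names are only illustrative):

from __future__ import print_function
import torch

x = torch.Tensor(5, 3)
rows, cols = x.size()    # tuple unpacking works on torch.Size
print(rows, cols)        # 5 3
print(len(x.size()))     # 2, the number of dimensions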

2. Reshaping a tensor with Tensor.view()

from __future__ import print_function
import torch

x = torch.randn(8, 8)
z = x.view(-1, 4)  # -1 asks view to infer this dimension: 64 elements / 4 columns = 16 rows
print(x.size(), z.size())

Output: (8L, 8L) (16L, 4L)  (recorded under Python 2; recent PyTorch prints torch.Size([8, 8]) torch.Size([16, 4]))

Note: calling x.view(-1, 4) by itself does not change the shape of x at all; view() returns a new tensor, so the result must be assigned to another variable.
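A related point worth illustrating: view() returns a new tensor object, but that tensor shares its underlying storage with the original, so writing through the view also changes the original data. A minimal sketch of both behaviors:

from __future__ import print_function
import torch

x = torch.randn(8, 8)
z = x.view(-1, 4)    # new tensor object; x.size() is still (8, 8)
z[0, 0] = 100        # storage is shared, so x[0, 0] becomes 100 as well
print(x[0, 0])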

3. Using add_() to add a value to every element of the original tensor

from __future__ import print_function
import torch
x = torch.rand(5, 3)
print(x)
print(x.add_(1))

Output:

 0.3659  0.1633  0.2380
 0.1108  0.9994  0.9355
 0.3237  0.9887  0.1847
 0.7046  0.0212  0.7640
 0.0731  0.2785  0.1676
[torch.FloatTensor of size 5x3]

 1.3659  1.1633  1.2380
 1.1108  1.9994  1.9355
 1.3237  1.9887  1.1847
 1.7046  1.0212  1.7640
 1.0731  1.2785  1.1676
[torch.FloatTensor of size 5x3]

Note: the trailing underscore in add_() marks it as an in-place function that will modify the value of x.

In general, functions whose names end with an underscore are in-place and will change the original value, while those without the underscore do not modify the original data; to keep their result, it must be assigned to another variable.
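For contrast, here is a minimal sketch placing the in-place add_() next to the out-of-place add():

from __future__ import print_function
import torch

x = torch.rand(5, 3)
y = x.add(1)    # out-of-place: returns a new tensor, x is unchanged
x.add_(1)       # in-place: modifies x itself
print(torch.equal(x, y))    # True: both now hold the original values plus 1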
