When doing tensor operations in deep learning you constantly run into slicing expressions like the ones below. They can be confusing at first, so this post summarizes how they are used (additions are welcome). A quick preview of what each form does appears right after the list:
data[x:y:z]
data[::-1]
data[:-1]
data[-3:]
data[1:4, 2:4]
data[..., 1:2]
data[:, None]
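Here is a minimal preview sketch of what each form evaluates to, shown on small NumPy arrays with made-up shapes and values. PyTorch tensors accept the same syntax, with one caveat: negative steps such as [::-1] are not supported on tensors, where torch.flip is used instead.
import numpy as np

data = np.arange(10)               # [0 1 2 3 4 5 6 7 8 9]
print(data[1:8:2])                 # [1 3 5 7] -> from index 1 to 8 (exclusive), step 2
print(data[::-1])                  # [9 8 7 6 5 4 3 2 1 0] -> reversed
print(data[:-1])                   # [0 1 2 3 4 5 6 7 8] -> drop the last element
print(data[-3:])                   # [7 8 9] -> the last three elements

m = np.arange(25).reshape(5, 5)    # a 5x5 matrix
print(m[1:4, 2:4].shape)           # (3, 2) -> rows 1..3, columns 2..3
print(m[..., 1:2].shape)           # (5, 1) -> keep all leading axes, slice out column 1
print(m[:, None].shape)            # (5, 1, 5) -> None (np.newaxis) inserts a new axis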
The code below is something I have been reading recently. It computes the loss and relies heavily on slicing, which is what prompted this summary; after reading this post you should be able to read similar code.
import torch
import torch.nn as nn


def compute_loss(p, targets, model):  # predictions, targets, model
    ft = torch.cuda.FloatTensor if p[0].is_cuda else torch.Tensor
    lxy, lwh, lcls, lconf = ft([0]), ft([0]), ft([0]), ft([0])
    # build_targets (defined elsewhere in the repo) pre-processes the targets vector
    txy, twh, tcls, indices = build_targets(model, targets)

    # Define criteria
    MSE = nn.MSELoss()
    CE = nn.CrossEntropyLoss()  # (weight=model.class_weights)
    BCE = nn.BCEWithLogitsLoss()

    # Compute losses
    h = model.hyp  # hyperparameters
    bs = p[0].shape[0]  # batch size
    k = bs  # loss gain
    for i, pi0 in enumerate(p):  # layer i predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tconf = torch.zeros_like(pi0[..., 0])  # conf

        # Compute losses
        if len(b):  # number of targets
            pi = pi0[b, a, gj, gi]  # predictions closest to anchors
            tconf[b, a, gj, gi] = 1  # conf
            # pi[..., 2:4] = torch.sigmoid(pi[..., 2:4])  # wh power loss (uncomment)

            lxy += (k * h['xy']) * MSE(torch.sigmoid(pi[..., 0:2]), txy[i])  # xy loss
            lwh += (k * h['wh']) * MSE(pi[..., 2:4], twh[i])  # wh yolo loss
            lcls += (k * h['cls']) * CE(pi[..., 5:], tcls[i])  # class_conf loss

        # pos_weight = ft([gp[i] / min(gp) * 4.])
        # BCE = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        lconf += (k * h['conf']) * BCE(pi0[..., 4], tconf)  # obj_conf loss
    loss = lxy + lwh + lconf + lcls
    return loss, torch.cat((lxy, lwh, lconf, lcls, loss)).detach()
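To connect this back to slicing, here is a minimal, self-contained sketch that reproduces the indexing patterns used in compute_loss. The tensor shapes, anchor counts and index values are made up for illustration and are not taken from the repo above; it assumes a YOLO-style prediction tensor of shape (batch, anchors, grid_y, grid_x, 5 + num_classes).
import torch

bs, na, ng, nc = 2, 3, 13, 80                 # batch, anchors, grid size, classes (illustrative values)
pi0 = torch.randn(bs, na, ng, ng, 5 + nc)     # raw predictions of one detection layer

b = torch.tensor([0, 1])                      # image index of each matched target
a = torch.tensor([1, 2])                      # anchor index
gj = torch.tensor([4, 7])                     # grid y
gi = torch.tensor([5, 9])                     # grid x

pi = pi0[b, a, gj, gi]                        # advanced (fancy) indexing: one row per matched target
print(pi.shape)                               # torch.Size([2, 85])
print(pi[..., 0:2].shape)                     # torch.Size([2, 2])  -> xy offsets
print(pi[..., 2:4].shape)                     # torch.Size([2, 2])  -> wh
print(pi[..., 5:].shape)                      # torch.Size([2, 80]) -> class scores
print(pi0[..., 4].shape)                      # torch.Size([2, 3, 13, 13]) -> objectness for every grid cell
The key point is that pi0[b, a, gj, gi] with four index tensors picks out exactly the grid cells matched to targets, while pi[..., 0:2], pi[..., 2:4] and pi[..., 5:] slice the last dimension into the xy, wh and class parts of each prediction.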
1. Slicing a single list
For the general form data[x:y:z], here is what it means:
The syntax [x:y:z] means "take every z-th element of a list from index x to index y". When z is negative, it indicates going backwards. When x isn't specified, it defaults to the first element of the list in the direction you are traversing the list. When y isn't specified, it defaults to the last element of the list. So if we want to take every 2nd element of a list, we use [::2].
– from github: python-is-cool
In plain terms: x is the start index, y is the end index (not included), and z is the step, i.e. take one element every z positions; a negative z means traversing from back to front.
Let's look at an example:
>>> data = list(range(10))
>>> data
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]