adj.sum(dim=-1, keepdim=True).clamp(min=1)
将张量里面的最后一个维度数值相加,即把一行的数值相加得到一个数,且保持数值最小值为1
比如:
import torch
import numpy as np
x = torch.arange(0,48)
b = x.view(2,2,4,3)
print(b)
print(b.sum(dim=-1, keepdim=True))
初始(2x2x4x3)
tensor([[[[ 0, 1, 2],
[ 3,</