# Sum U_w and U_j element-wise (an addition, not a concatenation, despite
# the "concat:" debug label). Original author's shape note: b*c*es*1.
U = U_w + U_j
print("concat:", U.shape)

# Average over the last axis; per the original shape note this reduces
# b*c*es*1 to b*c*es.
p = torch.mean(U, dim=-1)
# The original printed `Ugp`, which is never defined in this fragment and
# would raise NameError; the reduced tensor is `p`.
print(p.shape)