import torch
# Build an all-ones tensor of shape (1, 2, 3, 4) and normalize it along each
# axis in turn (0, 1, 2 and the last one), so the effect of `dim` on the
# result can be compared directly in the printed values.
a = torch.ones(1, 2, 3, 4)
a_0, a_1, a_2, a_3 = (
    torch.nn.functional.normalize(a, dim=axis) for axis in (0, 1, 2, -1)
)

# Print the input first, then each normalized result preceded by a separator.
print(a)
for _result in (a_0, a_1, a_2, a_3):
    print("***********")
    print(_result)
输出结果（依次为 a、a_0、a_1、a_2、a_3，以 *********** 分隔；a 的形状为 [1, 2, 3, 4]，沿长度为 n 的维度做 L2 归一化后，每个元素等于 1/sqrt(n)）:
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
***********
tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])
***********
tensor([[[[0.7071, 0.7071, 0.7071, 0.7071],
          [0.7071, 0.7071, 0.7071, 0.7071],
          [0.7071, 0.7071, 0.7071, 0.7071]],

         [[0.7071, 0.7071, 0.7071, 0.7071],
          [0.7071, 0.7071, 0.7071, 0.7071],
          [0.7071, 0.7071, 0.7071, 0.7071]]]])
***********
tensor([[[[0.5774, 0.5774, 0.5774, 0.5774],
          [0.5774, 0.5774, 0.5774, 0.5774],
          [0.5774, 0.5774, 0.5774, 0.5774]],

         [[0.5774, 0.5774, 0.5774, 0.5774],
          [0.5774, 0.5774, 0.5774, 0.5774],
          [0.5774, 0.5774, 0.5774, 0.5774]]]])
***********
tensor([[[[0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000, 0.5000]],

         [[0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000, 0.5000]]]])