First, here is the relevant entry from the official documentation, torch.nn.functional in the PyTorch 1.13 docs.
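The signature quoted next is easy to misuse. `input` is expected to hold log-probabilities, while `target` holds plain probabilities under the default `log_target=False` setting available in PyTorch 1.13. To make that concrete, here is a minimal plain-Python sketch of the pointwise formula and the reductions as PyTorch documents them. `kl_div_reference` is a hypothetical helper written for illustration. It is not PyTorch's implementation, and it assumes 2-D nested lists (batch x classes) instead of tensors.

```python
import math


def kl_div_reference(input_log_probs, target_probs, reduction="mean"):
    """Hypothetical sketch of the formula documented for
    torch.nn.functional.kl_div with log_target=False.

    input_log_probs: 2-D list (batch x classes) of log-probabilities.
    target_probs:    2-D list of the same shape holding probabilities.
    """
    # Pointwise loss: target * (log(target) - input).
    # A zero target contributes 0, following the convention 0 * log(0) = 0.
    pointwise = [
        [t * (math.log(t) - x) if t > 0 else 0.0 for x, t in zip(row_in, row_t)]
        for row_in, row_t in zip(input_log_probs, target_probs)
    ]
    if reduction == "none":
        return pointwise
    total = sum(sum(row) for row in pointwise)
    if reduction == "sum":
        return total
    if reduction == "batchmean":
        # Divide by the batch size: KL divergence averaged per sample.
        return total / len(pointwise)
    if reduction == "mean":
        # Divide by the total number of elements, not the batch size.
        n_elements = sum(len(row) for row in pointwise)
        return total / n_elements
    raise ValueError(f"unknown reduction: {reduction}")


# Example: KL(p || q) for one sample, passing log(q) as input.
p = [[0.5, 0.5]]
q = [[0.25, 0.75]]
log_q = [[math.log(v) for v in row] for row in q]
print(kl_div_reference(log_q, p, reduction="batchmean"))
```

In PyTorch itself, the corresponding call would look like `F.kl_div(F.log_softmax(logits, dim=-1), target_probs, reduction="batchmean")`, where `logits` and `target_probs` are placeholder names. `"batchmean"` is the reduction that matches the mathematical definition of KL divergence. The default `"mean"` divides by the number of elements instead, so its result is smaller by a factor equal to the number of classes.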
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')

Parameters:
input – Tensor of arbitrary shape
target – Tensor of the same shape as input
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there are multiple elements per sample. If the field size_average is set to False