import torch
from torch import nn
class Nn(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
def forward(self,input):
output =2*input+1
return output
forward_nn=Nn()
x =torch.tensor(2.0)
output =forward_nn(x)
print(output)
解释一下:
forward_nn=Nn() # 对象的实例化 output =forward_nn(2) ,这里没用forward_nn.forward ,是因为我们继承了 nn.modules ,
-
当你继承了
nn.Module时,nn.Module内部实现了__call__方法:
def __call__(self, *args, **kwargs): # 一些内部处理,比如 hooks return self.forward(*args, **kwargs)
-
所以:
forward_nn(2)
实际上就是:
forward_nn.__call__(2) # 进一步调用 forward() forward_nn.forward(2)

被折叠的 条评论
为什么被折叠?



