PyTorch函数中的__call__和forward函数

本文通过实例解析PyTorch中nn.Module的_init_()和_call_()方法,以及它们与forward的关系。nn.Module是构建神经网络的基础,它的_call_()内部调用了forward(),允许对象像函数一样被调用。通过重写_call_()或forward(),可以自定义网络行为。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

初学nn.Module,看不懂各种调用,后来看明白了,估计会忘,故写篇笔记记录

init & call

代码:

class A():
    def __init__(self):
        print('init函数')        
        
    def __call__(self, param):
        print('call 函数', param)
a = A()

输出
在这里插入图片描述
分析:A进行类的实例化,生成对象a,这个过程自动调用_init_(),没有调用_call_()


上面的代码加一行

class A():
    def __init__(self):
        print('init函数') 
        
    def __call__(self, param):
        print('call 函数', param)
a = A()
a(1)

输出
在这里插入图片描述
分析:a是对象,python中让对象有了像函数一样加括号(参数)的功能,使用这种功能时,自动调用_call_()


_ call_()中可以调用其它函数,如forward函数

class A():
    def __init__(self):
        print('init函数')
        
    def __call__(self, param):
        print('call 函数', param)
        res = self.forward(param)
        return res + 2
        
    def forward(self, input_): 
        print('forward 函数', input_)
        return input_
    
a = A()
b = a(1)
print('结果b =',b)

在这里插入图片描述
分析:_call _()成功调用了forward(),且返回值给了b


nn.Module

看了上面的例子,就知道了_call _()的作用,那下面看更接近CNN的例子

from torch import nn
import torch

class Ding(nn.Module):
    def __init__(self):
        print('init')
        super().__init__()
    
    def forward(self, input):
        output = input + 1
        print("forward")
        return output

dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)

结果:
在这里插入图片描述
分析:
这里并没有调用_call_() 和forward(),但还是显示了forward,原因是:Ding这个子类继承了父类nn.Module里的call函数,接下来去源码看
在这里插入图片描述
发现_call_调用了_call_impl这个函数,相当于起了个外号一样,那就去这个函数看

在这里插入图片描述
在这里插入图片描述

这里有很多参数,详细可见参考2。发现这里forward_call 要么是_slow_forward,要么是self.forward(),而这个_slow_forward()也会用self.forward()
在这里插入图片描述
所以: _call _()用了forward,而这个父类的forward在子类中重写了(简单代码)
在这里插入图片描述


当然,也可以重写__call__(),比如我们不让它使用forward()

from torch import nn
import torch

class Ding(nn.Module):
    def __init__(self):
        print('init')
        super().__init__()
        
    def __call__(self, input_):
        print('重写call, 不用forward')
        return 'hhh'
        
    def forward(self, input):
        output = input + 1
        print("forward")
        return output

dzy = Ding()
x = torch.tensor(1.0)
out = dzy(x)
print(out)

在这里插入图片描述

总结

使用对象dzy(x)时,用了父类nn.Module的call函数,调用了forward,而这个forward又被我们在子类里重写了。

参考

https://blog.youkuaiyun.com/dss_dssssd/article/details/83750838
https://zhuanlan.zhihu.com/p/366461413

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值