nn.Sequential()引起的 forward() takes 1 positional argument but 2 were given

博客讲述了在训练PyTorch模型时尝试移除分类层以获取特征图,但遇到了`forward()`函数传参错误的问题。作者检查了常见的错误源,如`__init__`拼写、调用对象方式以及参数数量,最终发现问题出在模型内部的DWConv层,该层的`forward`方法需要多个参数。解决方案是不使用`nn.Sequential`,而是直接修改原始PvT模型的源码以适应需求,从而解决了问题。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

最近在训练模型时,想要将模型的分类层去除,输出模型的特征图,于是进行如下操作去除模型的最后两层结构,然后奇怪的事情就发生了,运行时程序老是报错, forward() takes 1 positional argument but 2 were given!!!

class PvT(nn.Module):
    def __init__(self):  # num_classes,此处为 二分类值为2
        super().__init__()
        #创建模型,并且加载预训练参数
        net= pvt_v2.PyramidVisionTransformerV2()
        #去除分类层
        self.feature = nn.Sequential(*list(net.children())[:-2])

    def forward(self, x):
        x = self.feature(x)
        B, N, C = x.shape
        N = int(N ** 0.5)
        x = rearrange(x, 'b (h w) c -> b c h w', h=N, w=N)
        return x

于是在网上搜寻答案,常见的几个回答有三个:

一. __init__拼写问题

这个就是注意init前后都有两个短的下划线,这个问题好解决,自己对照着改改就好,别整错了

二. 调用对象问题

python调用类时,需要先将类实例化,再给对象传入参数

错误示例:

net=PvT(input)

正确示例:

net=PvT()
net(input)

三. 就是forward需要传入的是一个参数,确实传了两个进来

net=PvT()
net(input1, input2)

如果检查以上三个都没有问题了,那就需要考虑是不是 nn.Sequential()引起的问题

nn.Sequential本质上定义了一个网络,这个网络里面有天然存在的输入输出继承关系。我们可以通过nn.Sequential的源码看到,其自带的forward() 函数不支持传递多个参数。经过查看PvT的源码,发现该模型里面有子模块DWCov,forward里面需要传入多个参数,故此不能使用nn.Sequential,所以会报错。

class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
        return x

最后直接将PvT源码拿过来用,将最后的head层去除,不使用nn.Sequential,问题得以解决。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值