自定义autograd function

这篇博客详细介绍了如何在PyTorch中为自定义函数SegmentConsensus实现梯度运算。SegmentConsensus根据输入的共识类型(如平均或身份)进行操作,并在前向传播中计算输出。在反向传播时,它根据共识类型计算输入的梯度。此功能对于视频理解等任务中的时空特征聚合至关重要。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在TSN代码中, segmentconsensus是一个自定义函数, 所以要写一下它对应的梯度运算


# tj : https://blog.youkuaiyun.com/tsq292978891/article/details/79364140
class SegmentConsensus(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor, consensus_type, dim):


        ctx.consensus_type = consensus_type
        ctx.dim = dim
        ctx.shape = input_tensor.size()
        if consensus_type == 'avg':
            output = input_tensor.mean(dim=dim, keepdim=True)
        elif consensus_type == 'identity':
            output = input_tensor
        else:
            output = None

        return output



    @staticmethod
    def backward(ctx, grad_output):


        if ctx.consensus_type == 'avg':
            grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim])
        elif ctx.consensus_type == 'identity':
            grad_in = grad_output
        else:
            grad_in = None

        return grad_in, None, None
        # tj : 常数要返回 None

# tj : https://blog.youkuaiyun.com/tsq292978891/article/details/79364140
class ConsensusModule(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(ConsensusModule, self).__init__()
        self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
        self.dim = dim

    def forward(self, input):
        return SegmentConsensus.apply(input, self.consensus_type, self.dim)
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值