# In the TSN code, SegmentConsensus is a custom autograd function, so its
# corresponding gradient (backward) computation has to be written by hand.
# tj : https://blog.youkuaiyun.com/tsq292978891/article/details/79364140
class SegmentConsensus(torch.autograd.Function):
    """Custom autograd function that aggregates a tensor along one dimension.

    Supported values of ``consensus_type``:

    * ``'avg'``      -- mean over ``dim``; the reduced dimension is kept with
      size 1 (``keepdim=True``).
    * ``'identity'`` -- the input is passed through unchanged.

    Use it through ``SegmentConsensus.apply(input_tensor, consensus_type, dim)``.
    """

    @staticmethod
    def forward(ctx, input_tensor, consensus_type, dim):
        """Compute the consensus of ``input_tensor`` along ``dim``.

        Args:
            ctx: autograd context; receives ``consensus_type``, ``dim`` and
                the input size so that ``backward`` can use them.
            input_tensor: tensor to aggregate.
            consensus_type: ``'avg'`` or ``'identity'``.
            dim: dimension that ``'avg'`` reduces over.

        Returns:
            The mean over ``dim`` (reduced dimension kept) for ``'avg'``, or
            ``input_tensor`` itself for ``'identity'``.

        Raises:
            ValueError: if ``consensus_type`` is not a supported mode.
        """
        # Fail early with a clear message instead of handing ``None`` back to
        # the caller, where it would only surface later as an unrelated error.
        if consensus_type not in ('avg', 'identity'):
            raise ValueError(
                "unsupported consensus_type: {!r} "
                "(expected 'avg' or 'identity')".format(consensus_type)
            )

        # None of these are tensors, so plain attributes on ctx are enough.
        ctx.consensus_type = consensus_type
        ctx.dim = dim
        ctx.shape = input_tensor.size()

        if consensus_type == 'avg':
            return input_tensor.mean(dim=dim, keepdim=True)
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back to ``input_tensor``.

        Returns:
            A 3-tuple matching the three arguments of ``forward``: the
            gradient for ``input_tensor`` followed by ``None`` for
            ``consensus_type`` and ``dim``, which are not differentiable.
        """
        if ctx.consensus_type == 'avg':
            # Every element along ``dim`` contributed 1/N to the mean, so the
            # incoming gradient is broadcast back to the input shape and
            # scaled by 1/N.
            grad_in = grad_output.expand(ctx.shape) / float(ctx.shape[ctx.dim])
        else:
            # forward only admits 'avg' and 'identity'; identity passes the
            # gradient through untouched.
            grad_in = grad_output
        return grad_in, None, None
# tj : non-tensor (constant) inputs of forward must get None as their gradient
# tj : https://blog.youkuaiyun.com/tsq292978891/article/details/79364140
class ConsensusModule(torch.nn.Module):
    """``nn.Module`` wrapper around ``SegmentConsensus``.

    Stores the consensus mode and the dimension to operate on, and hands
    both to ``SegmentConsensus.apply`` on every forward pass.

    Args:
        consensus_type: name of the consensus mode. ``'rnn'`` is stored as
            ``'identity'``; any other value is stored unchanged.
        dim: dimension passed on to ``SegmentConsensus`` (default ``1``).
    """

    def __init__(self, consensus_type, dim=1):
        super(ConsensusModule, self).__init__()
        # 'rnn' is remapped so that this module acts as a pass-through for it.
        if consensus_type == 'rnn':
            consensus_type = 'identity'
        self.consensus_type = consensus_type
        self.dim = dim

    def forward(self, input):
        """Apply the configured consensus to ``input`` and return the result."""
        mode, axis = self.consensus_type, self.dim
        return SegmentConsensus.apply(input, mode, axis)