AE(associative embedding) loss TF版本实现

本文介绍了AE(Associative Embedding)损失函数的起源和作用,它常用于底部向上关键点检测的聚类。在CornerNet和HigherHRNet等目标检测与姿态估计模型中,AE损失帮助区分同一类别的多个目标并完成关键点的精确聚类。文章提供了TensorFlow版本的AE损失函数实现,包括两个版本,修复了单目标情况下的错误,并对其工作原理进行了详细说明。

一、背景

        AE loss最早出自《Associative Embedding:End-to-End Learning for Joint Detection and Grouping》这篇文章,通过简单的embedding方式,用于bottom-up关键点检测中进行最后的聚类。

        在CornerNet、HigherHRNet等文章中的均使用了该技巧。在目标检测任务ConrnerNet中用来区分多个同类别目标,在姿态估计任务HigherHRNet中用来完成关键点聚类。

        目前网上tensorflow版本的实现较少,简单实现后记录于此。

二、实现

        基本原理其实很不难:单个体内的关键点间差值要小,不同个体间的关键点均值差值要大

        源码主要参考 CornerNet 官方pytorch版本 

1.第一版

        替换几个tf函数

#pytorch版本源码:cornerNet  /models/py_utils/kp_utils.py#L180
#tag0, [b, max_tl=128, 1]
#tag1, [b, max_rb=128, 1]
#mask, [b, 128]
def ae_loss(tag0, tag1, mask):
    num = tf.reduce_sum(mask, axis=1, keepdims=True) # b,n -> b,1 .  sum(n)
    tag0 = tf.squeeze(tag0,axis=-1) # b,n,1 -> b,n
    tag1 = tf.squeeze(tag1,axis=-1)

    #pull计算(类内差):每个batch 每个instance内均值, 再计算instance内每个joint与均值的差
    tag_mean = (tag0 + tag1) / 2 # b,n
    tag0 = tf.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    tag
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值