I. Background
The AE loss first appeared in "Associative Embedding: End-to-End Learning for Joint Detection and Grouping", which uses a simple embedding to perform the final grouping step in bottom-up keypoint detection.
Works such as CornerNet and HigherHRNet both use this trick: in the object detection task CornerNet it separates multiple objects of the same class, and in the pose estimation task HigherHRNet it groups keypoints into individual people.
TensorFlow implementations are scarce online, so I record a simple one here.
II. Implementation
The basic idea is straightforward: keypoint embeddings within a single instance should differ as little as possible, while the mean embeddings of different instances should differ as much as possible.
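For reference, the pull and push terms as written in the CornerNet paper (e_{t_k} and e_{b_k} are the embeddings of the top-left and bottom-right corners of object k, e_k is their mean, N is the number of objects, and the margin Δ is set to 1):

$$L_{pull}=\frac{1}{N}\sum_{k=1}^{N}\left[(e_{t_k}-e_k)^2+(e_{b_k}-e_k)^2\right]$$

$$L_{push}=\frac{1}{N(N-1)}\sum_{k=1}^{N}\sum_{j\neq k}\max\left(0,\,\Delta-|e_k-e_j|\right),\qquad \Delta=1$$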
The code mainly follows the official PyTorch version of CornerNet.
1. First version
Mostly a direct port that swaps a few PyTorch calls for their TF equivalents (e.g. unsqueeze → tf.expand_dims, Tensor.sum → tf.reduce_sum, F.relu → tf.nn.relu).
import tensorflow as tf

# PyTorch source: CornerNet models/py_utils/kp_utils.py#L180
# tag0: [b, max_tl=128, 1]
# tag1: [b, max_rb=128, 1]
# mask: [b, 128]
def ae_loss(tag0, tag1, mask):
    num = tf.reduce_sum(mask, axis=1, keepdims=True)  # (b, n) -> (b, 1), number of valid objects per sample
    tag0 = tf.squeeze(tag0, axis=-1)  # (b, n, 1) -> (b, n)
    tag1 = tf.squeeze(tag1, axis=-1)
    # pull (intra-object): per-object mean, then squared distance of each corner embedding to that mean
    tag_mean = (tag0 + tag1) / 2  # (b, n)
    tag0 = tf.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    tag0 = tf.reduce_sum(tag0 * mask)
    tag1 = tf.pow(tag1 - tag_mean, 2) / (num + 1e-4)
    tag1 = tf.reduce_sum(tag1 * mask)
    pull = tag0 + tag1
    print(tag_mean)  # debug print; the test output below is this tensor
    # push (inter-object)
    mask = tf.expand_dims(mask, 1) + tf.expand_dims(mask, 2)
    mask = tf.equal(mask, 2)  # (b, n, n), True where objects i and j are both valid
    num = tf.expand_dims(num, 2)  # (b, 1) -> (b, 1, 1)
    num2 = (num - 1) * num
    dist = tf.expand_dims(tag_mean, 1) - tf.expand_dims(tag_mean, 2)  # (b, n, n), entry (b, i, j) is mean_i - mean_j
    dist = 1 - tf.abs(dist)  # the diagonal becomes 1 here; it must not end up in the sum
    dist = tf.nn.relu(dist)
    dist = dist - 1 / (num + 1e-4)  # cancels the diagonal ones in aggregate, matching the paper's formula (see note below)
    dist = dist / (num2 + 1e-4)
    mask = tf.cast(mask, tf.float32)
    push = tf.reduce_sum(dist * mask)
    return pull, push
Test code (the tensor printed below is tag_mean, from the debug print inside ae_loss):
tag0 = tf.constant([[1., 2., 2.],
                    [1., 2., 3.]], dtype=tf.float32)
tag1 = tf.constant([[1., 2., 3.],
                    [1., 2., 3.]], dtype=tf.float32)
mask = tf.constant([[1., 1., 1.],
                    [1., 1., 0.]], dtype=tf.float32)
tag0 = tf.expand_dims(tag0, -1)
tag1 = tf.expand_dims(tag1, -1)
ae_loss(tag0, tag1, mask)
>> tf.Tensor(
[[1. 2. 2.5]
[1. 2. 3. ]], shape=(2, 3), dtype=float32)
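A note on the dist - 1/(num + 1e-4) line: it subtracts 1/N from every one of the N² valid entries rather than zeroing just the diagonal, but in aggregate this removes exactly the diagonal's contribution of N ones:

$$\sum_{\text{valid }(i,j)}\left(\mathrm{relu}\left(1-|e_i-e_j|\right)-\frac{1}{N}\right)=\underbrace{N\cdot 1}_{\text{diagonal}}+\sum_{i\neq j}\mathrm{relu}\left(1-|e_i-e_j|\right)-N^{2}\cdot\frac{1}{N}=\sum_{i\neq j}\mathrm{relu}\left(1-|e_i-e_j|\right)$$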
2. Second version
Two main changes:
a. In the first version tag0/tag1 come from two separate feature maps; here the embeddings come from a single feature map, and an arbitrary number of keypoints per object is supported (4 corner points in the code below).
b. The first version follows CornerNet's implementation, which has a bug when a sample contains only a single object (most COCO images contain multiple objects, so it rarely shows). In that case push should theoretically be 0, but it is not; this version fixes that, and the snippet below reproduces the problem first.
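A minimal repro of the single-object case, feeding the ae_loss above one object whose two corner embeddings are identical, so both pull and push should be 0:
tag0 = tf.constant([[5.]], dtype=tf.float32)  # (b=1, n=1): a single object
tag1 = tf.constant([[5.]], dtype=tf.float32)
mask = tf.constant([[1.]], dtype=tf.float32)
pull, push = ae_loss(tf.expand_dims(tag0, -1), tf.expand_dims(tag1, -1), mask)
# pull = 0 as expected, but push comes out ~1 instead of 0:
# the diagonal's relu(1 - 0) = 1 leaves a residue of 1 - 1/(1+1e-4) ~ 1e-4,
# and num2 = (1-1)*1 = 0, so dividing by num2 + 1e-4 blows it back up to ~1.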
# y_pred : (b, h, w, 1) embedding map
# mask   : (b, max_objects), e.g. [[1,0,0,...],[1,1,1,0,0,...],...]
# indices: (b, max_objects*4), e.g. [[312,555,666,777,0,...],...]
# returns: (b, max_objects, 4), e.g. [[1.0, 2.0, 3.0, 4.0],...]
def trans_objects(y_pred, mask, indices):
    # batch size, channels
    b, c = tf.shape(y_pred)[0], tf.shape(y_pred)[-1]
    # max_objects (integer division: / would produce a float tensor)
    max_corners = tf.shape(indices)[1]
    max_objs = max_corners // 4
    y_pred = tf.reshape(y_pred, (b, -1, c))  # (b, h*w, c=1)
    length = tf.shape(y_pred)[1]  # h*w
    indices = tf.cast(indices, tf.int32)  # (b, max_objs*4)
    # gather only the positions named by indices: (b, h*w, c) -> (b, max_objs, 4)
    batch_idx = tf.expand_dims(tf.range(0, b), 1)  # (b, 1)
    batch_idx = tf.tile(batch_idx, (1, max_corners))  # (b, max_corners): [[0,0,...],[1,1,...],...]
    # flat indices into the (b*h*w,) view, e.g. [312, 0, ..., h*w+65, h*w+203, ..., 2*h*w+..., ...]
    full_indices = (tf.reshape(batch_idx, [-1]) * tf.cast(length, tf.int32)
                    + tf.reshape(indices, [-1]))  # (b*max_objs*4,)
    y_pred = tf.gather(tf.reshape(y_pred, [-1, c]), full_indices)  # select along the flattened b*h*w axis
    y_pred = tf.reshape(y_pred, [b, -1, 4])  # (b*max_objs*4, c=1) -> (b, max_objs, 4)
    # expand the mask to cover the 4 corner points of each object
    mask = tf.tile(tf.expand_dims(mask, axis=-1), (1, 1, 4))  # (b, max_objs) -> (b, max_objs, 4)
    return y_pred * mask
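A quick sanity check of trans_objects on a tiny 2x2 map (values 1..4 in row-major order): with indices [0,1,2,3] the single valid object should collect all four pixels.
y_pred = tf.reshape(tf.constant([1., 2., 3., 4.]), (1, 2, 2, 1))  # (b=1, h=2, w=2, 1)
mask = tf.constant([[1.]], dtype=tf.float32)            # one valid object
indices = tf.constant([[0, 1, 2, 3]], dtype=tf.int32)   # its 4 corner positions in the flattened map
print(trans_objects(y_pred, mask, indices))  # [[[1. 2. 3. 4.]]], shape (1, 1, 4)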
# y_pred : embedding map, (batch_size, out_h, out_w, 1)
# mask   : (batch_size, max_objects), e.g. [[1,0,0,...],[1,1,1,0,0,...],...]
# indices: indices of the 4 corner points, (batch_size, max_objects*4), e.g. [[312,555,666,777,0,...],...]
def my_ae_loss(y_pred, mask, indices):
    # (b, h, w, 1) -> (b, n, 4)
    tag = trans_objects(y_pred, mask, indices)
    max_objs = tf.shape(mask)[1]  # n
    num = tf.reduce_sum(mask, axis=1, keepdims=True)  # (b, n) -> (b, 1); the object count can differ per sample
    tag_mean = tf.reduce_mean(tag, axis=-1, keepdims=True)  # (b, n, 4) -> (b, n, 1), per-object embedding mean
    pull = tf.pow(tag - tag_mean, 2)  # squared distance to the mean; (b, n, 4), tag_mean broadcasts over the 4 points
    pull = pull / tf.expand_dims(tf.tile(num + 1e-4, (1, max_objs)), axis=-1)
    pull = pull * tf.expand_dims(mask, axis=-1)  # keep only valid objects; mask (b, n) -> (b, n, 1), broadcast to (b, n, 4)
    pull = tf.reduce_sum(pull)
    #print(pull.shape)
    # push (inter-object)
    mask = mask * tf.cast(tf.greater(num, 1), tf.float32)  # fix (lvjj): zero the mask where num == 1, so push becomes 0 (same as num == 0)
    mask = tf.expand_dims(mask, 1) + tf.expand_dims(mask, 2)
    mask = tf.equal(mask, 2)  # (b, n, n)
    num = tf.expand_dims(num, 2)  # (b, 1) -> (b, 1, 1)
    num2 = (num - 1) * num
    tag_mean = tf.squeeze(tag_mean, axis=-1)  # (b, n, 1) -> (b, n)
    dist = tf.expand_dims(tag_mean, 1) - tf.expand_dims(tag_mean, 2)  # (b, n, n), entry (b, i, j) is mean_i - mean_j
    dist = 1 - tf.abs(dist)  # the diagonal becomes 1 here; it must not end up in the sum
    dist = tf.nn.relu(dist)
    dist = dist - 1 / (num + 1e-4)  # cancels the diagonal ones in aggregate, matching the paper's formula
    dist = dist / (num2 + 1e-4)
    mask = tf.cast(mask, tf.float32)
    push = tf.reduce_sum(dist * mask)
    return pull, push
Test code. A mask row of [1,1] means two objects and [1,0] means a single object; the latter is the case where the original implementation misbehaves.
# (2, 2, 2, 1) embedding values
y_pred = tf.constant([[[[1.], [2.]],
                       [[1.], [2.]]],
                      [[[2.], [1.]],
                       [[2.], [1.]]]], dtype=tf.float32)
# (2, max_objs=2)
mask = tf.constant([[1., 1.],
                    [1., 1.]], dtype=tf.float32)
# (2, max_objs*4=8)
indices = tf.constant([[0, 0, 2, 2, 1, 1, 3, 3],
                       [0, 0, 2, 2, 1, 1, 3, 3]], dtype=tf.int32)
my_ae_loss(y_pred, mask, indices)
>>(<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=9.9897385e-05>)
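The residual push ≈ 1e-4 comes entirely from the 1e-4 epsilons in the denominators, not from any real overlap between the two object means, so the fix behaves as intended. To plug the two terms into training, a minimal sketch following the CornerNet paper, which weights both pull and push by 0.1 (the 0.1 weights are from the paper; how you fold the sum into your total loss depends on your setup):
pull, push = my_ae_loss(y_pred, mask, indices)
ae_total = 0.1 * pull + 0.1 * push  # alpha = beta = 0.1 as in CornerNet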