U-Net Defect Detection: The Pixel-Level Class Imbalance Problem
The U-Net architecture follows the figure above;
alternatively, you can design your network following the paper Automatic Metallic Surface Defect Detection and Recognition with Convolutional Neural Networks.
Current model input: input_image 1 x 512 x 512 x 1, input_label 1 x 512 x 512
Current model output: prediction 1 x 512 x 512 x 2
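A minimal TF1.x sketch of these tensor interfaces (the single conv layer below is only a hypothetical stand-in for the real U-Net backbone):

import tensorflow as tf  # TensorFlow 1.x

# placeholders matching the shapes above
input_image = tf.placeholder(tf.float32, [1, 512, 512, 1], name='input_image')
input_label = tf.placeholder(tf.int32, [1, 512, 512], name='input_label')

# stand-in for the U-Net backbone: any network mapping the image to
# per-pixel 2-class scores works here (a single conv, just for shape checking)
logits = tf.layers.conv2d(input_image, filters=2, kernel_size=3, padding='same')
prediction = tf.nn.softmax(logits, name='prediction')  # 1 x 512 x 512 x 2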
In general, a captured image of the metallic surface has more background pixels than defective pixels.
With roughly balanced samples, positive:negative ≈ 1:3; defect detection, however, often faces imbalance where positive:negative is far from 1:3. For example, a 50 x 50 defect in a 512 x 512 image gives about 1:104.
Solutions to the class imbalance problem:
solution-1 focal loss
def focal_loss(labels, logits, gamma=2, alpha=0.25):
    # labels: B x H x W integer mask; logits: B x H x W x 2 softmax probabilities
    labels = tf.one_hot(labels, depth=2)
    # clip to avoid log(0)
    epsilon = 1.e-3
    logits = tf.clip_by_value(logits, epsilon, 1. - epsilon)
    # cross entropy
    cross_entropy = -labels * tf.log(logits)
    # focal loss: (1 - p)^gamma down-weights easy pixels,
    # alpha_t balances the contribution of the two classes
    alpha_t = labels * alpha + (1. - labels) * (1. - alpha)
    focal_loss = alpha_t * tf.pow(1. - logits, gamma) * cross_entropy
    # reduce to a scalar loss
    loss = tf.reduce_mean(tf.reduce_sum(focal_loss, axis=-1))
    return loss
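Note that this focal_loss expects softmax probabilities rather than raw logits, since it applies tf.log directly. A usage sketch, assuming the input_label and prediction tensors from the sketch above:

loss = focal_loss(input_label, prediction, gamma=2, alpha=0.25)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)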
solution-2 re-weight the imbalanced classes
Following the idea of the paper Automatic Metallic Surface Defect Detection and Recognition with Convolutional Neural Networks:
def imbalance_cross_entropy(labels, logits):
    # labels: B x H x W integer mask (BG = 1, FG = 0); logits: B x H x W x 2 softmax probabilities
    labels_one_hot = tf.one_hot(labels, depth=2)
    # clip to avoid log(0)
    epsilon = 1.e-3
    logits = tf.clip_by_value(logits, epsilon, 1. - epsilon)
    # per-pixel cross entropy, summed over the class dimension -> B x H x W
    cross_entropy = -labels_one_hot * tf.log(logits)
    cross_entropy_sum = tf.reduce_sum(cross_entropy, axis=-1)
    # per-pixel imbalance weights: 0.1 for background (label 1), 0.9 for foreground (label 0)
    labels_reshape = tf.reshape(labels, [-1])
    weight_BG = tf.fill(tf.shape(labels_reshape), 0.1)
    weight_FG = tf.fill(tf.shape(labels_reshape), 0.9)
    labels_weight = tf.where(tf.equal(labels_reshape, 1), weight_BG, weight_FG)
    labels_imbalance = tf.reshape(labels_weight, tf.shape(labels))
    # weighted cross entropy
    loss = cross_entropy_sum * labels_imbalance
    imbalance_cross_entropy_loss = tf.reduce_mean(loss)
    return imbalance_cross_entropy_loss
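Usage mirrors the focal loss; again a sketch that reuses the hypothetical tensors defined earlier:

loss = imbalance_cross_entropy(input_label, prediction)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)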
The re-weight parameters used here are 0.1 and 0.9:
weight_BG is the weight applied to background pixels,
weight_FG is the weight applied to foreground (i.e. target) pixels.
Note: choose the weights such that weight_BG + weight_FG = 1.
After testing, I can summarize as follows (this is empirical; run plenty of experiments on your own problem):
If the weights favor FG, FN decreases and Recall increases: the model would rather over-predict than miss a labeled defect.
If the weights favor BG, FP decreases and Precision increases: the model would rather under-predict than make a wrong prediction.
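To watch this trade-off while tuning the weights, per-image Precision and Recall on the FG class (label 0 in this demo) can be monitored; a sketch using the hypothetical tensors from above:

pred_class = tf.cast(tf.argmax(prediction, axis=-1), tf.int32)  # 1 x 512 x 512
is_fg_pred = tf.equal(pred_class, 0)
is_fg_true = tf.equal(input_label, 0)
TP = tf.reduce_sum(tf.cast(tf.logical_and(is_fg_pred, is_fg_true), tf.float32))
FP = tf.reduce_sum(tf.cast(tf.logical_and(is_fg_pred, tf.logical_not(is_fg_true)), tf.float32))
FN = tf.reduce_sum(tf.cast(tf.logical_and(tf.logical_not(is_fg_pred), is_fg_true), tf.float32))
precision = TP / (TP + FP + 1e-7)
recall = TP / (TP + FN + 1e-7)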
solution-3 lovasz loss
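No code is given for this solution here; a minimal sketch of the binary Lovász hinge, following the public reference implementation by Berman et al., is below. Note it takes raw (pre-softmax) scores for the foreground class and a 0/1 mask with 1 = FG, so in this demo (where FG is encoded as 0) the mask would first need to be flipped:

def lovasz_grad(gt_sorted):
    # gradient of the Lovasz extension w.r.t. sorted errors (Berman et al.)
    gts = tf.reduce_sum(gt_sorted)
    intersection = gts - tf.cumsum(gt_sorted)
    union = gts + tf.cumsum(1. - gt_sorted)
    jaccard = 1. - intersection / union
    jaccard = tf.concat((jaccard[0:1], jaccard[1:] - jaccard[:-1]), 0)
    return jaccard

def lovasz_hinge_flat(logits, labels):
    # logits: [P] raw FG scores; labels: [P] binary ground truth (1 = FG)
    labelsf = tf.cast(labels, logits.dtype)
    signs = 2. * labelsf - 1.
    errors = 1. - logits * tf.stop_gradient(signs)
    # sort errors in descending order
    errors_sorted, perm = tf.nn.top_k(errors, k=tf.shape(errors)[0], name='descending_sort')
    gt_sorted = tf.gather(labelsf, perm)
    grad = lovasz_grad(gt_sorted)
    loss = tf.tensordot(tf.nn.relu(errors_sorted), tf.stop_gradient(grad), 1)
    return loss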
solution-4 online hard negative mining with stop_gradient
def stop_and_imbalance(labels, logits):
    # labels: B x H x W integer mask (BG = 1, FG = 0); logits: B x H x W x 2 softmax probabilities
    # total number of pixels H*W
    labels_num = tf.multiply(tf.shape(labels)[1], tf.shape(labels)[2], name='label_num')
    # one-hot labels, shape B x H x W x C
    labels_one_hot = tf.one_hot(labels, depth=2, name='one_hot')
    # clip to avoid log(0)
    epsilon = 1.e-3
    logits = tf.clip_by_value(logits, epsilon, 1. - epsilon)
    # cross entropy
    cross_entropy = -tf.multiply(labels_one_hot, tf.log(logits, name='log_logits'), name='cross_entropy')
    # sum over the class dimension, shape B x H x W
    cross_entropy_sum = tf.reduce_sum(cross_entropy, axis=-1, name='cross_entropy_sum')
    # flatten to [262144] for a 1 x 512 x 512 input
    cross_entropy_sum_flatten = tf.reshape(cross_entropy_sum, [-1], name='cross_entropy_sum_flatten')
    labels_flatten = tf.reshape(labels, [-1], name='reshape_label')
    # count BG pixels: BG = 1 in this demo, so a plain reduce_sum counts them
    imbalance_BG_NUM = tf.reduce_sum(labels_flatten, name='imbalance_BG_NUM')
    # count FG pixels: FG = 0 in this demo, so labels_num - imbalance_BG_NUM
    imbalance_FG_NUM = tf.subtract(labels_num, imbalance_BG_NUM, name='imbalance_FG_NUM')
    # gather FG losses (label == 0); all of them contribute loss and gradient
    label_0 = tf.equal(labels_flatten, 0, name='label_0')
    label_0_cross_entropy_sum_flatten = tf.boolean_mask(cross_entropy_sum_flatten, label_0, name='pos_label_FG')
    FG_Loss = tf.cast(tf.reduce_sum(label_0_cross_entropy_sum_flatten), dtype=tf.float32)
    # gather BG losses (label != 0)
    label_1 = tf.not_equal(labels_flatten, 0, name='label_1')
    label_1_cross_entropy_sum_flatten = tf.boolean_mask(cross_entropy_sum_flatten, label_1, name='pos_label_BG')
    # sort BG losses in descending order (hardest background pixels first)
    value, index = tf.nn.top_k(label_1_cross_entropy_sum_flatten,
                               tf.shape(label_1_cross_entropy_sum_flatten)[0], name='sort_BG')
    # keep the top imbalance_FG_NUM * 3 hardest BG losses (capped at the BG count)
    keep_num = tf.minimum(tf.cast(imbalance_FG_NUM * 3, dtype=tf.int32), tf.shape(value)[0])
    value_BG_Loss_Top_K = tf.slice(value, [0], [keep_num])
    BG_Loss = tf.cast(tf.reduce_sum(value_BG_Loss_Top_K), dtype=tf.float32)
    # the remaining easy BG losses are excluded from the loss, so no gradient flows
    # through them anyway; the explicit stop_gradient calls make that intent visible
    value_BG_Loss_Stop = tf.slice(value, [keep_num], [tf.shape(value)[0] - keep_num])
    index_BG_Loss_Stop = tf.slice(index, [keep_num], [tf.shape(value)[0] - keep_num])
    tf.stop_gradient(value_BG_Loss_Stop, name='stop_gradient_value_BG_Loss_Stop')
    tf.stop_gradient(index_BG_Loss_Stop, name='stop_gradient_index_BG_Loss_Stop')
    # average both parts over the total pixel count
    labels_num_f = tf.cast(labels_num, dtype=tf.float32)
    stop_and_imbalance_cross_entropy_loss = FG_Loss / labels_num_f + BG_Loss / labels_num_f
    return stop_and_imbalance_cross_entropy_loss
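A usage sketch as before, reusing the hypothetical tensors from above. The factor of 3 keeps roughly a 1:3 FG:BG ratio among the pixels that actually contribute gradient, echoing the hard-negative-mining ratio popularized by detectors such as SSD:

loss = stop_and_imbalance(input_label, prediction)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)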