Focal Loss for Multi-Class Classification

This post is about focal loss, which is used to address problems such as class imbalance in classification. The author records the process of implementing focal loss for the multi-class case, explains how binary and multi-class classification differ in their losses, gives the cross-entropy formulas for both, walks the procedure through a three-class example, and finishes with a Keras implementation.

Reposted from: https://blog.youkuaiyun.com/Umi_you/article/details/80982190




Focal loss comes from the paper Focal Loss for Dense Object Detection by Kaiming He's team, and is used to address class imbalance as well as the difference in difficulty between easy and hard examples in classification problems. Because the paper applies it to the binary foreground/background classification in object detection, its formulas are given for the binary case. My project needed a multi-class implementation of focal loss, so this post records the doubts, details, and my personal understanding along the way; the link to the Keras implementation is at the end.

Framework: Keras (TensorFlow backend)
Environment: Ubuntu 16.04, Python 3.5

Binary vs. multi-class classification

Ever since I started learning, I have had trouble telling apart the binary and multi-class losses. Although I understand that binary classification is really a special case of multi-class classification, the formulas in the focal loss paper still made my head spin, and details I was previously unwilling to deal with now have to be worked through carefully from the very basics.

Multi-class cross entropy:
H(y, y') = -\sum_i y_i' \log y_i

Binary cross entropy:
H(y, y') = -\sum_{i=0}^{1} y_i' \log y_i = -(y_0' \log y_0 + y_1' \log y_1) = -[y_0' \log y_0 + (1 - y_0') \log(1 - y_0)]
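As a quick numerical check (my own illustration, not from the original post) that the binary formula is just the two-class case of the multi-class sum, assuming a one-hot label:

```python
import numpy as np

# One sample, two classes: one-hot label and probabilities that sum to 1
y_true = np.array([1.0, 0.0])   # y'_0 = 1, y'_1 = 0
y_pred = np.array([0.8, 0.2])   # y_0 = 0.8, y_1 = 1 - y_0

# Multi-class form: H = -sum_i y'_i * log(y_i)
ce_multi = -np.sum(y_true * np.log(y_pred))

# Binary form: H = -[y'_0 * log(y_0) + (1 - y'_0) * log(1 - y_0)]
ce_binary = -(y_true[0] * np.log(y_pred[0]) + (1 - y_true[0]) * np.log(1 - y_pred[0]))

print(ce_multi, ce_binary)  # both equal -log(0.8), roughly 0.2231
```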

You can see that the binary cross entropy is really just a rearrangement of the multi-class form. In the Focal Loss paper, the authors use a piecewise function to represent the binary CE (cross entropy), and a piecewise variable p_t to represent the y_i whose label value is 1 in the binary case (here "label value 1" means the class where the 1 sits in the one-hot label [0 1] or [1 0]):
[Figure from the paper: CE(p, y) = -\log(p) if y = 1, and -\log(1 - p) otherwise]
[Figure from the paper: p_t = p if y = 1, and 1 - p otherwise]

The p in the paper's figures (short for predict or probability?) is equivalent to y in the multi-class cross-entropy formula, i.e. the probability obtained after the activation function (softmax for multi-class, sigmoid for binary), while the y in the paper corresponds to y' in the cross entropy, i.e. the label.

With the piecewise variable p_t as its argument, the CE can be rewritten as CE(p, y) = CE(p_t) = -\log(p_t). In fact, p_t is just the value of the y_i whose label y_i' equals 1 in the multi-class CE; it is only because y_0 and y_1 are mutually exclusive in the binary case (they sum to 1) that a single piecewise variable p_t can stand for y_i across the different values of i. I think of p_t as the confidence of the current sample: the larger p_t, the higher the confidence and the smaller the cross entropy. In summary: in the multi-class case, the p_t of each sample is the value of the prediction y_pred at the index where the one-hot label is 1, which in code is max(y_pred * y_label, axis=-1).
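As a quick sketch of this step (my own, with made-up numbers), p_t for a batch can be picked out with the Keras backend exactly as described:

```python
from tensorflow.keras import backend as K

# Hypothetical batch of 3 samples and 3 classes (numbers are made up for illustration)
y_pred = K.constant([[0.7, 0.2, 0.1],
                     [0.1, 0.6, 0.3],
                     [0.2, 0.3, 0.5]])   # softmax outputs
y_label = K.constant([[1., 0., 0.],
                      [0., 0., 1.],
                      [0., 1., 0.]])     # one-hot labels

# p_t of each sample = the prediction at the index where the one-hot label is 1
p_t = K.max(y_pred * y_label, axis=-1)
print(K.eval(p_t))  # [0.7, 0.3, 0.3]
```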

Once you know what p_t represents, the multi-class focal loss is easy to handle. Let's run through a three-class example to simulate the procedure; that makes it roughly clear how the code should be written:
Assume:
y_pred is the result after softmax:
[Figure: the y_pred matrix of the three-class example]
y_label is the one-hot label:
[Figure: the y_label matrix]
p_t = y_pred * y_label:
[Figure: the resulting p_t matrix]

1 - p_t:
[Figure: the 1 - p_t matrix]
log(p_t): (note that p_t can be 0, and log cannot take 0 as its argument, so epsilon is added)
[Figure: the log(p_t + epsilon) matrix]

FL = (1 - p_t)^\gamma \cdot (-\log(p_t)):
[Figures: the resulting focal loss]

You can see that the entries showing 3.4538... should really be 0; the reason is that the log is applied to epsilon rather than to an exact zero at those positions, yielding a non-zero value, so the log should be taken first and only then multiplied by y_label:
[Figure: the corrected focal loss]

Before: log(p_t) = log(y_label * y_pred)
After: log(p_t) = y_label * log(y_pred)
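Here is a small sketch of why the order matters (again with made-up numbers, and epsilon = 1e-7 as a stand-in for K.epsilon()): with log(y_label * y_pred), the positions where the label is 0 contribute -log(epsilon) instead of 0, while y_label * log(y_pred) zeroes them out exactly:

```python
import numpy as np

eps = 1e-7                              # stand-in for K.epsilon()
y_pred = np.array([[0.7, 0.2, 0.1]])    # hypothetical softmax output
y_label = np.array([[1.0, 0.0, 0.0]])   # one-hot label

# Original order: multiply first, then log -> off-label entries become -log(eps)
bad = -np.log(y_label * y_pred + eps)   # [[0.357, 16.118, 16.118]]

# Corrected order: log first, then mask -> off-label entries are exactly 0
good = -y_label * np.log(y_pred + eps)  # [[0.357, 0., 0.]]

print(bad)
print(good)
```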

Incidentally, in the multi-class case the alpha parameter has no effect, since every sample is multiplied by the same weight.

See the comments in the code for details.

Code: Keras version
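The following is a minimal sketch of a multi-class focal loss in Keras that follows the steps above (not necessarily identical to the implementation linked from the original post); it assumes y_true is one-hot encoded and y_pred comes from a softmax:

```python
from tensorflow.keras import backend as K

def categorical_focal_loss(gamma=2.0, alpha=0.25):
    """Multi-class focal loss following the walkthrough above.

    Assumes y_true is one-hot and y_pred is a softmax output. Note that a
    scalar alpha multiplies every sample by the same weight, so in the
    multi-class case it only rescales the loss.
    """
    def loss(y_true, y_pred):
        # Clip so that log() never sees exactly 0 or 1
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
        # Log first, then mask with the one-hot label (the corrected order)
        cross_entropy = -y_true * K.log(y_pred)
        # Modulating factor (1 - p_t)^gamma; only the label position survives
        # because cross_entropy is zero everywhere else
        weight = alpha * K.pow(1.0 - y_pred, gamma)
        # Sum over classes picks out the label position; average over the batch
        return K.mean(K.sum(weight * cross_entropy, axis=-1))
    return loss

# Hypothetical usage:
# model.compile(optimizer='adam', loss=categorical_focal_loss(gamma=2.0), metrics=['accuracy'])
```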

Reposted from: https://www.cnblogs.com/leebxo/p/10547091.html
