# -*- coding: utf-8 -*-
"""
📌 中国研究生数学建模竞赛 E题 · 任务三(增强版)
🔧 方法:CORAL + Random Forest + 伪标签迭代训练(Self-Training)
🎯 目标:提升目标域预测置信度与一致性
"""
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
# ====================== Hyperparameters ======================
# Feature-vector width and number of fault classes.
INPUT_DIM: int = 41
NUM_CLASSES: int = 4

# Random-forest configuration.
N_TREES: int = 100
MAX_DEPTH: int = 10
RANDOM_STATE: int = 42

# Self-training (pseudo-label) loop configuration.
MAX_ITER: int = 3  # maximum number of pseudo-label rounds
CONFIDENCE_THRESHOLD: float = 0.6  # a sample is pseudo-labelled only at or above this confidence
EARLY_STOPPING: bool = True  # stop early when a round adds no new samples

# Display names, indexed by integer class label.
class_names = [
    'Normal',
    'Outer Race Fault',
    'Inner Race Fault',
    'Ball Fault',
]

# Columns of the input table that are metadata rather than features.
feature_cols_ignore = [
    'filename',
    'label',
    'domain',
]
# ====================== CORAL alignment helpers ======================
def coral_loss(source_features, target_features):
    """Return the squared Frobenius distance between the two domains' covariances.

    Each feature column is first standardised within its own domain
    (zero mean, unit population standard deviation, with 1e-6 added to the
    denominator to avoid division by zero), then the sample covariance
    matrices of the two standardised sets are compared.

    Args:
        source_features: 2-D array, rows are samples and columns are features.
        target_features: 2-D array with the same number of columns.

    Returns:
        ``||C_s - C_t||_F ** 2`` as a float.
    """
    def _standardised_cov(features):
        # Same arithmetic order as a plain (x - mean) / (std + eps).
        centred = features - features.mean(0, keepdims=True)
        spread = features.std(0, keepdims=True) + 1e-6
        return np.cov((centred / spread).T)

    cov_gap = _standardised_cov(source_features) - _standardised_cov(target_features)
    return np.linalg.norm(cov_gap, 'fro') ** 2
def coral_align(Xs, Xt):
    """Transform target features so their mean and covariance match the source.

    The target data is centred, whitened with the inverse square root of its
    own (ridge-regularised) covariance, re-coloured with the square root of
    the source covariance, and finally shifted to the source mean.

    Args:
        Xs: 2-D source array, shape (n_source, n_features).
        Xt: 2-D target array, shape (n_target, n_features).

    Returns:
        The aligned target array, shape (n_target, n_features), as float32.

    Raises:
        ValueError: if either input is not 2-D, the feature counts differ,
            or either domain has fewer than two samples (the unbiased
            covariance divides by ``n - 1``).
    """
    Xs = np.asarray(Xs)
    Xt = np.asarray(Xt)
    if Xs.ndim != 2 or Xt.ndim != 2:
        raise ValueError("coral_align expects 2-D arrays (samples x features)")
    if Xs.shape[1] != Xt.shape[1]:
        raise ValueError(
            f"feature count mismatch: source has {Xs.shape[1]}, target has {Xt.shape[1]}"
        )
    ns, nt = len(Xs), len(Xt)
    if ns < 2 or nt < 2:
        raise ValueError("coral_align needs at least two samples in each domain")

    d = Xs.shape[1]
    Xs_mean = Xs.mean(axis=0, keepdims=True)
    Xt_mean = Xt.mean(axis=0, keepdims=True)
    Xs_c = Xs - Xs_mean
    Xt_c = Xt - Xt_mean

    # A small ridge keeps both covariances strictly positive definite so the
    # inverse square root below is well defined.
    ridge = 1e-6 * np.eye(d)
    Cs = (Xs_c.T @ Xs_c) / (ns - 1) + ridge
    Ct = (Xt_c.T @ Xt_c) / (nt - 1) + ridge

    # For a symmetric positive-definite matrix the SVD coincides with the
    # eigendecomposition, so U diag(f(S)) U^T gives the matrix function f(C).
    U_s, S_s, _ = np.linalg.svd(Cs)
    U_t, S_t, _ = np.linalg.svd(Ct)
    Cs_sqrt = U_s @ np.diag(np.sqrt(S_s)) @ U_s.T
    Ct_inv_sqrt = U_t @ np.diag(S_t ** -0.5) @ U_t.T

    Xt_aligned = Xt_c @ Ct_inv_sqrt @ Cs_sqrt + Xs_mean
    return Xt_aligned.astype(np.float32)
# ====================== Main routine: self-training with pseudo-labels ======================
def _run_self_training(Xs_scaled, ys, Xt_aligned):
    """Fit a random forest repeatedly on source data plus pseudo-labelled target data.

    Each round trains on the source set plus every target sample that has
    been pseudo-labelled so far, predicts the whole target set, and
    pseudo-labels the target samples that newly reach CONFIDENCE_THRESHOLD.
    A target sample is added at most once, so the training set never
    contains duplicated pseudo-labelled rows.

    Args:
        Xs_scaled: scaled source features, shape (n_source, n_features).
        ys: integer source labels, shape (n_source,).
        Xt_aligned: CORAL-aligned target features, shape (n_target, n_features).

    Returns:
        A dict with per-round lists under the keys 'iter', 'added',
        'conf_avg', 'preds' and 'probas'. Every 'probas' entry has one
        column per entry of ``class_names``.

    Raises:
        ValueError: if MAX_ITER is smaller than 1.
    """
    if MAX_ITER < 1:
        raise ValueError("MAX_ITER must be at least 1")

    n_target = len(Xt_aligned)
    pseudo_mask = np.zeros(n_target, dtype=bool)  # target rows already pseudo-labelled
    pseudo_labels = np.zeros(n_target, dtype=int)
    X_train, y_train = Xs_scaled, ys
    history = {
        'iter': [], 'added': [], 'conf_avg': [], 'preds': [], 'probas': []
    }

    for it in range(MAX_ITER):
        print(f"\n🔄 进行第 {it+1} 轮伪标签训练...")
        rf_model = RandomForestClassifier(
            n_estimators=N_TREES,
            max_depth=MAX_DEPTH,
            random_state=RANDOM_STATE,
            class_weight='balanced'
        )
        rf_model.fit(X_train, y_train)

        # predict_proba columns follow rf_model.classes_, which may omit a
        # class that never occurs in the training labels. Scatter them into a
        # fixed layout so column i always corresponds to class_names[i].
        raw_probas = rf_model.predict_proba(Xt_aligned)
        classes = rf_model.classes_.astype(int)
        probas = np.zeros((n_target, len(class_names)), dtype=raw_probas.dtype)
        probas[:, classes] = raw_probas

        # Same result as rf_model.predict, without a second pass over the forest.
        predictions = classes[np.argmax(raw_probas, axis=1)]
        confidences = raw_probas.max(axis=1)
        avg_conf = confidences.mean()
        pred_labels = [class_names[i] for i in predictions]

        # Only samples that were not pseudo-labelled in an earlier round count as new.
        high_conf_mask = confidences >= CONFIDENCE_THRESHOLD
        new_mask = high_conf_mask & ~pseudo_mask
        num_new = int(new_mask.sum())
        print(f"📈 第{it+1}轮平均置信度: {avg_conf:.3f}")
        print(f"🟢 新增高置信样本数: {num_new}")

        history['iter'].append(it+1)
        history['added'].append(num_new)
        history['conf_avg'].append(avg_conf)
        history['preds'].append(pred_labels)
        history['probas'].append(probas)

        if num_new == 0 and EARLY_STOPPING:
            print("🔚 无新增高置信样本,提前结束迭代。")
            break

        # Rebuild the training set from scratch: source + all pseudo-labelled targets.
        pseudo_labels[new_mask] = predictions[new_mask]
        pseudo_mask |= new_mask
        X_train = np.vstack([Xs_scaled, Xt_aligned[pseudo_mask]])
        y_train = np.hstack([ys, pseudo_labels[pseudo_mask]])
        print(f"🧠 当前训练集大小: {len(X_train)} (源域{len(Xs_scaled)}, 伪标签{int(pseudo_mask.sum())})")

    return history


def _plot_training_progress(history, file_ids):
    """Plot per-round statistics and the first-to-last probability change heatmap."""
    plt.figure(figsize=(12, 5))
    epochs = history['iter']

    # Left panel: counts and confidences live on very different scales, so
    # the confidence line gets its own y-axis.
    ax_count = plt.subplot(1, 2, 1)
    bars = ax_count.bar(epochs, history['added'], alpha=0.6, color='orange')
    ax_count.set_xlabel('Iteration')
    ax_count.set_ylabel('New Pseudo Labels')
    ax_count.set_xticks(epochs)
    ax_count.set_title('Self-Training Progress')
    ax_count.grid(True, alpha=0.3)
    ax_conf = ax_count.twinx()
    (line,) = ax_conf.plot(epochs, history['conf_avg'], 'bo-')
    ax_conf.set_ylabel('Avg Confidence')
    ax_conf.set_ylim(0, 1.05)
    ax_count.legend([line, bars], ['Avg Confidence', 'New Pseudo Labels'])

    # Right panel: change in class probabilities between the first and last round.
    ax_heat = plt.subplot(1, 2, 2)
    delta_probs = history['probas'][-1] - history['probas'][0]
    sns.heatmap(delta_probs.T, annot=True, fmt=".2f",
                xticklabels=file_ids,
                yticklabels=class_names, cmap='RdBu_r', center=0, ax=ax_heat)
    ax_heat.set_title('Probability Change Before vs After Self-Training')
    ax_heat.set_xlabel('Sample')
    ax_heat.set_ylabel('Class')

    plt.tight_layout()
    plt.savefig('self_training_progress.png', dpi=150)
    plt.show()


def _plot_confidence_change(history, file_ids):
    """Plot, per target file, the change in top-class confidence from first to last round."""
    init_conf = history['probas'][0].max(axis=1)
    last_conf = history['probas'][-1].max(axis=1)
    diff_conf = last_conf - init_conf
    plt.figure(figsize=(10, 6))
    colors = ['green' if x > 0 else 'red' for x in diff_conf]
    plt.barh(file_ids, diff_conf, color=colors, edgecolor='black', alpha=0.8)
    plt.axvline(0, color='gray', linestyle='--')
    plt.xlabel('Confidence Change (After - Before)')
    plt.title('Change in Prediction Confidence after Self-Training')
    plt.grid(True, axis='x', alpha=0.5)
    plt.tight_layout()
    plt.savefig('confidence_change_selftrain.png', dpi=150)
    plt.show()


def main_with_self_training():
    """Run the CORAL + random forest + self-training pipeline end to end.

    Reads 'extracted_features_with_domain.csv', standardises the features,
    aligns the target domain to the source domain with CORAL, runs the
    pseudo-label loop, writes 'predicted_labels_CORAL_RF_SELFTRAIN.csv' and
    saves the diagnostic figures as PNG files.

    Raises:
        ValueError: if the table has no 'source' rows or no 'target' rows.
    """
    print("🚀 开始执行任务三(增强版):CORAL + RF + 伪标签迭代训练")

    # 1. Load the data and split it by domain.
    df = pd.read_csv('extracted_features_with_domain.csv')
    feature_cols = [col for col in df.columns if col not in feature_cols_ignore]
    source_data = df[df['domain'] == 'source']
    target_data = df[df['domain'] == 'target']
    if source_data.empty or target_data.empty:
        raise ValueError("input table must contain both 'source' and 'target' rows")
    Xs = source_data[feature_cols].values.astype(np.float32)
    ys = source_data['label'].values.astype(int)
    Xt = target_data[feature_cols].values.astype(np.float32)
    filenames_target = target_data['filename'].values
    print(f"✅ 源域样本数: {len(Xs)}")
    print(f"✅ 目标域样本数: {len(Xt)}")

    # 2. Standardise. NOTE: the scaler is fitted on source and target rows
    #    together, not on the source domain alone.
    scaler = StandardScaler()
    X_all_scaled = scaler.fit_transform(np.vstack((Xs, Xt)))
    Xs_scaled = X_all_scaled[:len(Xs)]
    Xt_scaled = X_all_scaled[len(Xs):]

    # 3. CORAL alignment (done once, before self-training).
    print("🔄 正在使用 CORAL 对齐特征...")
    loss_before = coral_loss(Xs_scaled, Xt_scaled)
    Xt_aligned = coral_align(Xs_scaled, Xt_scaled)
    loss_after = coral_loss(Xs_scaled, Xt_aligned)
    print(f"📊 CORAL Loss Before: {loss_before:.4f} → After: {loss_after:.4f}")

    # 4. Self-training loop.
    history = _run_self_training(Xs_scaled, ys, Xt_aligned)

    # 5. Final predictions come from the last model that was actually trained.
    final_probas = history['probas'][-1]
    final_preds = history['preds'][-1]
    file_ids = [f.split('.')[0] for f in filenames_target]
    result_df = pd.DataFrame({
        'File': file_ids,
        'Predicted_Label': final_preds,
        'Confidence': final_probas.max(axis=1).round(3)
    })
    for i, cls_name in enumerate(class_names):
        result_df[f'Prob_{cls_name}'] = final_probas[:, i].round(3)
    result_df.to_csv('predicted_labels_CORAL_RF_SELFTRAIN.csv', index=False)
    print("\n📋 最终预测结果:")
    print(result_df.to_string(index=False))
    print("💾 已保存至: predicted_labels_CORAL_RF_SELFTRAIN.csv")

    # 6. Visualise the iteration process.
    _plot_training_progress(history, file_ids)
    if len(history['iter']) > 1:
        # A before/after comparison needs at least two rounds.
        _plot_confidence_change(history, file_ids)

    print("🎉 伪标签迭代训练完成!请查看输出文件与图表。")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main_with_self_training()
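# TODO: Change the visualisations in this script to show:
#   1. t-SNE plots of the four source-domain fault classes together with the
#      target-domain data, before and after transfer;
#   2. a table of the differences before and after CORAL alignment;
#   3. a distribution plot of the prediction confidence intervals.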