继续生成新的任务三代码,如题:迁移诊断:在任务2设计的诊断模型基础上,充分考虑源域与目标域的共性与差异特征,设计合适的迁移学习方法,构建目标域诊断模型,对目标域未知标签的数据进行分类和标定,给出迁移结果的可视化展示和分析,并给出数据对应的标签。已知任务一的代码为:
# -*- coding: utf-8 -*-
"""
📌 任务一:改进版特征提取(含包络谱 + 故障频率敏感特征)
🔧 目标:解决任务二准确率低的问题
💡 核心改进:加入 Hilbert 包络解调特征,极大增强对 Outer/Ball 故障的识别能力
"""
import os
import numpy as np
import scipy.io as sio
from scipy.fft import fft
from scipy.stats import skew, kurtosis
from scipy.signal import butter, filtfilt, hilbert
import pywt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from scipy.stats import gaussian_kde
import matplotlib.pyplot as plt
import pandas as pd
# ====================== Custom JS divergence (compatible with older SciPy) ======================
def jenshannon(p, q):
    """Return the Jensen-Shannon distance (square root of the JS divergence).

    Fallback implementation for SciPy versions that lack
    scipy.spatial.distance.jensenshannon.

    Parameters
    ----------
    p, q : array_like
        Two (possibly unnormalized) probability distributions of equal length.

    Returns
    -------
    float
        Non-negative JS distance between the normalized distributions.
    """
    eps = 1e-8

    def _shannon_entropy(dist):
        # The small epsilon keeps log() finite for zero-probability bins.
        return -np.sum(dist * np.log(dist + eps))

    p_norm = np.asarray(p) / (np.asarray(p).sum() + eps)
    q_norm = np.asarray(q) / (np.asarray(q).sum() + eps)
    mixture = 0.5 * (p_norm + q_norm)
    divergence = (_shannon_entropy(mixture)
                  - 0.5 * _shannon_entropy(p_norm)
                  - 0.5 * _shannon_entropy(q_norm))
    # Numerical noise can push the divergence slightly negative; clamp at 0.
    return np.sqrt(max(divergence, 0))
# ====================== Path settings ======================
# Root folders of the source-domain and target-domain .mat datasets.
source_data_dir = r'C:\Users\1\Desktop\新建文件夹\数据集\源域数据集'
target_data_dir = r'C:\Users\1\Desktop\新建文件夹\数据集\目标域数据集'
# Output CSV holding all extracted features plus domain/label columns.
output_csv = 'extracted_features_with_domain.csv'
fs_target = 32000  # resample every signal to a common 32 kHz rate
target_length = 320000  # fixed segment length (~10 s at 32 kHz)
# Bearing geometry (SKF6205) used for the characteristic fault frequencies.
n_rollers = 9  # number of rolling elements
diameter_ball = 0.3126 * 0.0254  # ball diameter, inches -> metres
pitch_diameter = 1.537 * 0.0254  # pitch diameter, inches -> metres
contact_angle = np.radians(15)  # contact angle, radians
# Band-pass filter parameters (Hz) applied during preprocessing.
FILTER_LOW_CUT = 200.0
FILTER_HIGH_CUT = 10000.0
FILTER_ORDER = 6
# ====================== Locate all .mat files ======================
def find_mat_files(root_dir):
    """Recursively collect the paths of all .mat files under root_dir.

    macOS resource-fork artifacts ('._*' files) are skipped.
    """
    matches = [
        os.path.join(folder, fname)
        for folder, _, fnames in os.walk(root_dir)
        for fname in fnames
        if fname.endswith('.mat') and not fname.startswith('._')
    ]
    print(f"🔍 在 {root_dir} 中发现 {len(matches)} 个 .mat 文件")
    return matches
# ====================== Load a .mat file ======================
def load_cwru_data(filepath):
    """Load one CWRU-style .mat file.

    Returns
    -------
    (signal, rpm) : (np.ndarray, float) or (None, None)
        Drive-end (DE) vibration channel flattened to 1-D, and the shaft
        speed in RPM (defaults to 1797 when no RPM key exists in the file).

    Fixes vs. the original: the two bare ``except:`` clauses are narrowed,
    and the redundant try/except around ``.flatten()`` (which cannot fail
    on an ndarray) is removed.
    """
    try:
        mat = sio.loadmat(filepath)
    except Exception as e:  # best-effort: skip unreadable files, keep going
        print(f"[错误] 无法读取 {filepath}: {e}")
        return None, None

    de_key = rpm_key = None
    for k in mat.keys():
        if not k.startswith('__') and 'DE' in k:  # drive-end channel
            de_key = k
        if 'RPM' in k.upper():
            rpm_key = k

    if de_key is None:
        # Fallback strategy: take the first sufficiently long numeric array.
        for k, v in mat.items():
            if isinstance(v, np.ndarray) and v.size >= 10000:
                de_key = k
                break
    if de_key is None:
        print(f"[警告] 未在 {filepath} 中找到驱动端(DE)信号")
        return None, None

    signal = np.asarray(mat[de_key]).flatten()

    rpm = 1797  # CWRU default shaft speed when the file omits RPM
    if rpm_key is not None and rpm_key in mat:
        try:
            rpm_arr = mat[rpm_key]
            rpm = float(rpm_arr.item()) if rpm_arr.size == 1 else float(rpm_arr[0, 0])
        except (TypeError, ValueError, IndexError, AttributeError):
            pass  # keep the default on malformed RPM entries
    return signal, rpm
# ====================== Band-pass filtering ======================
def bandpass_filter(signal, lowcut, highcut, fs, order=5):
    """Zero-phase Butterworth band-pass filter.

    Returns the input unchanged when either cutoff reaches the Nyquist
    frequency, where the filter design would be invalid.
    """
    nyq = 0.5 * fs
    low_norm = lowcut / nyq
    high_norm = highcut / nyq
    if low_norm >= 1.0 or high_norm >= 1.0:
        print(f"[警告] 截止频率超出范围(fs={fs}),跳过滤波")
        return signal
    b, a = butter(order, [low_norm, high_norm], btype='band', analog=False)
    # filtfilt runs the filter forward and backward -> zero phase shift.
    return filtfilt(b, a, signal)
# ====================== Signal preprocessing ======================
def preprocess_signal(x, fs_original, fs_target=32000, target_len=320000):
    """Resample to fs_target, band-pass filter, and pad/crop to target_len.

    An empty input yields an all-zero segment of target_len samples.
    """
    if len(x) == 0:
        return np.zeros(target_len)

    # Linear-interpolation resampling onto the target rate.
    if fs_original != fs_target:
        n_resampled = int(len(x) * fs_target / fs_original)
        sample_positions = np.linspace(0, len(x) - 1, n_resampled)
        x = np.interp(sample_positions, np.arange(len(x)), x)

    # Band-pass filter with the module-level FILTER_* settings.
    x = bandpass_filter(x, FILTER_LOW_CUT, FILTER_HIGH_CUT, fs_target, FILTER_ORDER)

    # Crop or zero-pad so every sample vector has identical length.
    if len(x) > target_len:
        return x[:target_len]
    if len(x) < target_len:
        return np.pad(x, (0, target_len - len(x)), mode='constant', constant_values=0)
    return x
# ====================== Time-domain features ======================
def time_domain_features(x):
    """Return 12 classic time-domain statistics of a vibration segment.

    Returns all zeros for empty or (near-)constant signals, where the
    ratio-based factors would be meaningless.
    """
    if len(x) == 0 or np.var(x) < 1e-10:
        return [0.0] * 12

    abs_x = np.abs(x)
    squared = x ** 2
    eps = 1e-8
    mu = np.mean(x)
    sigma = np.std(x)
    peak = np.max(abs_x)
    rms = np.sqrt(np.mean(squared))
    mean_abs = np.mean(abs_x)
    mean_sqrt_abs = np.mean(np.sqrt(abs_x))

    return [
        mu,                                 # mean
        sigma,                              # standard deviation
        peak,                               # peak amplitude
        rms,                                # root mean square
        peak / (rms + eps),                 # crest factor
        skew(x),                            # skewness
        kurtosis(x),                        # kurtosis
        peak / (mean_abs + eps),            # impulse factor
        peak / (mean_sqrt_abs ** 2 + eps),  # margin factor
        rms / (mean_abs + eps),             # shape factor
        np.sum(squared),                    # signal energy
        peak / (mean_sqrt_abs + eps),       # clearance factor
    ]
# ====================== Frequency-domain features ======================
def freq_domain_features(x, fs=32000, fr=None):
    """Return 8 spectral features, including bearing fault-band energies.

    Parameters
    ----------
    x : 1-D signal; fs : sampling rate in Hz; fr : shaft frequency in Hz
    (defaults to the 1797 RPM CWRU baseline when omitted).
    """
    n = len(x)
    spectrum = np.abs(fft(x))[:n // 2]
    freqs = np.linspace(0, fs / 2, n // 2)

    if fr is None:
        fr = 1797 / 60  # default shaft frequency, Hz

    # Theoretical fault frequencies from the module-level bearing geometry.
    bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle))
    bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle))
    bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \
        (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2)

    def band_power(center, width=10):
        # Power of the spectrum within +/- width Hz of `center`.
        in_band = (freqs >= center - width) & (freqs <= center + width)
        return np.sum(spectrum[in_band] ** 2) if np.any(in_band) else 0.0

    total_power = np.sum(spectrum ** 2) + 1e-8
    return [
        np.max(spectrum),                                        # spectral peak
        np.mean(spectrum),                                       # spectral mean
        np.sum(freqs * spectrum) / (np.sum(spectrum) + 1e-8),    # spectral centroid
        np.sqrt(np.var(freqs[:len(spectrum)] * spectrum)),       # spread of f*|X|
        band_power(bpfo) / total_power,                          # outer-race band energy
        band_power(bpfi) / total_power,                          # inner-race band energy
        band_power(bsf) / total_power,                           # ball-spin band energy
        np.sum(spectrum[(freqs > 1000) & (freqs < 3000)]**2) / total_power,  # 1-3 kHz energy
    ]
# ====================== Wavelet-packet energy entropy (optimized) ======================
def wavelet_packet_features(x, level=4, wavelet='db4'):
    """Wavelet-packet energy entropy plus the first 8 band-energy ratios.

    Always returns 9 values; on decomposition failure a zero vector is
    returned so downstream feature stacking keeps a fixed width.
    """
    try:
        packet = pywt.WaveletPacket(data=x, wavelet=wavelet, maxlevel=level)
        band_energies = np.array(
            [np.sum(np.square(node.data)) for node in packet.get_level(level, 'natural')]
        )
        ratios = band_energies / (band_energies.sum() + 1e-8)
        energy_entropy = -np.sum(ratios * np.log(ratios + 1e-8))
        return [energy_entropy] + ratios[:8].tolist()
    except Exception as e:
        print(f"[警告] 小波包分解失败: {e}")
        return [0.0] * 9
# ====================== Envelope-spectrum feature extraction (new! core improvement) ======================
def envelope_spectrum_features(x, fs=32000, fr=None):
    """
    Extract Hilbert envelope-spectrum energy features (highly sensitive to
    early bearing faults).

    Returns 12 values: normalized energies at the theoretical fault
    frequencies (BPFO/BPFI/BSF) and shaft frequency, low/mid/high band
    energies and their ratios, plus envelope kurtosis and the envelope
    spectrum peak and mean. All zeros for empty or constant input.
    """
    if len(x) == 0 or np.var(x) < 1e-10:
        return [0.0] * 12

    # 1. Hilbert transform -> analytic signal -> amplitude envelope.
    #    (The original wrapped this in a bare try/except whose handler
    #    re-ran the exact same hilbert() call; that was a no-op and has
    #    been removed.)
    envelope = np.abs(hilbert(x))

    # 2. FFT of the envelope (the demodulated spectrum).
    N = len(envelope)
    env_fft = np.abs(fft(envelope))[:N // 2]
    freqs = np.linspace(0, fs / 2, N // 2)
    total_power = np.sum(env_fft**2) + 1e-8

    if fr is None:
        fr = 1797 / 60  # default shaft frequency, Hz

    # Theoretical fault frequencies from the module-level bearing geometry.
    bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle))
    bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle))
    bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \
        (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2)

    def band_energy(center, width=50):
        # Envelope-spectrum power within +/- width Hz of `center`.
        idx = (freqs >= center - width) & (freqs <= center + width)
        return np.sum(env_fft[idx]**2) if np.any(idx) else 0.0

    # Normalized fault-band energies.
    energy_bpfo = band_energy(bpfo) / total_power
    energy_bpfi = band_energy(bpfi) / total_power
    energy_bsf = band_energy(bsf) / total_power
    energy_fr = band_energy(fr, width=20) / total_power

    # Low / mid / high frequency-band energies of the envelope spectrum.
    high_freq_mask = (freqs >= 2000) & (freqs <= 8000)
    mid_freq_mask = (freqs >= 500) & (freqs < 2000)
    low_freq_mask = (freqs < 500)
    energy_high = np.sum(env_fft[high_freq_mask]**2) / total_power
    energy_mid = np.sum(env_fft[mid_freq_mask]**2) / total_power
    energy_low = np.sum(env_fft[low_freq_mask]**2) / total_power

    return [
        energy_bpfo, energy_bpfi, energy_bsf, energy_fr,
        energy_high, energy_mid, energy_low,
        energy_high / (energy_low + 1e-8),   # high/low band energy ratio
        energy_bpfo / (energy_bpfi + 1e-8),  # outer vs inner race energy ratio
        kurtosis(envelope),                  # envelope kurtosis (impulsiveness)
        np.max(env_fft),                     # envelope-spectrum peak
        np.mean(env_fft),                    # envelope-spectrum mean
    ]
# ====================== Label parsing ======================
def parse_label_from_path(filepath):
    """Infer the fault label from a file's name / directory.

    Returns: 0 = Normal, 1 = Outer Race, 2 = Inner Race, 3 = Ball.

    Bug fix: the string 'NORMAL' contains the substring 'OR', so normal
    files such as 'Normal_0.mat' used to be mislabeled as outer-race
    faults. Normal files are now recognized before the outer-race check.
    """
    name = os.path.basename(filepath).upper()
    dirname = os.path.dirname(filepath).upper()
    # Normal baseline data: the canonical CWRU file numbers, or an
    # explicit NORMAL marker in the file or directory name.
    if any(n in name for n in ['97.MAT', '98.MAT', '99.MAT', '100.MAT']):
        return 0
    if 'NORMAL' in name or 'NORMAL' in dirname:
        return 0
    if 'OR' in name or 'OUTER' in dirname or ('O' in name and any(s in name for s in ['007', '014', '021', '028'])):
        return 1  # outer-race fault
    if 'IR' in name or 'INNER' in dirname:
        return 2  # inner-race fault
    if 'B007' in name or 'B014' in name or 'BALL' in dirname or ('B' in name and 'IR' not in name and 'OR' not in name):
        return 3  # ball fault
    return 0  # fall back to Normal when nothing matches
# ====================== Plot before/after filtering comparison ======================
def plot_filter_comparison(original_signal, filtered_signal, fs=32000, title="Filtering Comparison"):
    """Plot time-domain and log-spectrum views of a signal before/after filtering.

    Saves the figure to 'filter_comparison.png' and shows it.
    """
    N = len(original_signal)
    freqs = np.linspace(0, fs / 2, N // 2)
    fft_orig = np.abs(fft(original_signal))[:N // 2]
    fft_filt = np.abs(fft(filtered_signal))[:N // 2]
    plot_len = 4096  # only the first 4096 samples are drawn in the time plot
    t = np.arange(plot_len) / fs
    plt.figure(figsize=(12, 8))
    # Top panel: time-domain overlay of original vs filtered signal.
    plt.subplot(2, 1, 1)
    plt.plot(t, original_signal[:plot_len], label='Original Signal', color='gray', alpha=0.8)
    plt.plot(t, filtered_signal[:plot_len], label='Filtered Signal', color='red', linewidth=1.2)
    plt.title(f'{title} - Time Domain')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    # Bottom panel: magnitude spectra on a log scale, with the filter
    # cutoff frequencies marked as vertical lines.
    plt.subplot(2, 1, 2)
    plt.semilogy(freqs, fft_orig, label='Original Spectrum', color='gray', alpha=0.8)
    plt.semilogy(freqs, fft_filt, label='Filtered Spectrum', color='blue', linewidth=1.2)
    plt.axvline(FILTER_LOW_CUT, color='green', linestyle='--', linewidth=1.2, label=f'{FILTER_LOW_CUT} Hz')
    plt.axvline(FILTER_HIGH_CUT, color='green', linestyle='--', linewidth=1.2, label=f'{FILTER_HIGH_CUT} Hz')
    plt.title('Frequency Domain (Log Scale)')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Magnitude')
    plt.xlim(0, fs / 2)
    plt.legend()
    plt.grid(True, which='both', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig('filter_comparison.png', dpi=150)
    plt.show()
# ====================== Plot an envelope-spectrum example ======================
def plot_envelope_spectrum_example(signal_proc, fs=32000, rpm=1797, title="Envelope Spectrum"):
    """Plot the Hilbert envelope spectrum with theoretical fault frequencies marked.

    Saves the figure to 'envelope_spectrum_<title>.png' and shows it.
    """
    analytic = hilbert(signal_proc)
    envelope = np.abs(analytic)
    N = len(envelope)
    env_fft = np.abs(fft(envelope))[:N//2]
    freqs = np.linspace(0, fs/2, N//2)
    plt.figure(figsize=(10, 4))
    plt.semilogy(freqs, env_fft, color='red', linewidth=1.2, label='Envelope Spectrum')
    # Theoretical fault frequencies from the shaft speed and the
    # module-level bearing geometry constants.
    fr = rpm / 60
    bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle))
    bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle))
    bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \
        (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2)
    # Mark each characteristic frequency with a labeled vertical line.
    for f, name in [(bpfo, 'BPFO'), (bpfi, 'BPFI'), (bsf, 'BSF'), (fr, 'FR')]:
        plt.axvline(f, color='blue', linestyle='--', alpha=0.7, linewidth=1)
        plt.text(f, np.max(env_fft)*0.8, name, rotation=90, va='top', fontsize=9)
    plt.xlim(0, 6000)
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("Magnitude (log)")
    plt.title(f"Envelope Spectrum - {title}")
    plt.grid(True, which='both', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(f"envelope_spectrum_{title.replace(' ', '_')}.png", dpi=150)
    plt.show()
# ====================== Plot source vs target feature-distribution comparison ======================
def plot_domain_comparison(X, y, d, feature_names):
    """Generate five source-vs-target distribution comparison figures.

    Figures: (1) boxplots of the first 8 features, (2) overlapped
    histograms of two key features, (3) a t-SNE embedding colored by
    domain, (4) KDE curves of the first 6 features, (5) a bar chart of
    the top features ranked by JS divergence. Each figure is saved as a
    PNG and shown.

    Parameters: X standardized feature matrix; y labels (unused here);
    d per-row domain array ('source'/'target'); feature_names column names.
    """
    X_source = X[d == 'source']
    X_target = X[d == 'target']
    if len(X_source) == 0 or len(X_target) == 0:
        print("❌ 源域或目标域无数据,无法绘图")
        return
    print("📊 正在生成源域与目标域特征分布对比图...")
    # ---------------------- Figure 1: boxplots (first 8 features) ----------------------
    plt.figure(figsize=(14, 8))
    n_show = min(8, X.shape[1])
    data_to_plot = [X_source[:, i] for i in range(n_show)] + [X_target[:, i] for i in range(n_show)]
    labels = [f'{name}\nSource' for name in feature_names[:n_show]] + \
        [f'{name}\nTarget' for name in feature_names[:n_show]]
    bp = plt.boxplot(data_to_plot, labels=labels, patch_artist=True)
    # Blue boxes = source features, red boxes = target features.
    colors = ['lightblue'] * n_show + ['lightcoral'] * n_show
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
    plt.xticks(rotation=45)
    plt.title('Boxplot: Feature Distribution Comparison (Source vs Target)')
    plt.ylabel('Standardized Value')
    plt.grid(True, axis='y', linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.savefig('boxplot_source_vs_target.png', dpi=150)
    plt.show()
    # ---------------------- Figure 2: overlapped histograms ----------------------
    idx1 = 0   # first column (TD_0, the time-domain mean)
    # NOTE(review): the axis title calls this ENV_BPFO, but with the
    # 41-column layout built in main(), index -3 is ENV_Kurtosis;
    # confirm which column was intended.
    idx2 = -3
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    axes[0].hist(X_source[:, idx1], bins=50, alpha=0.7, label='Source', color='blue', density=True)
    axes[0].hist(X_target[:, idx1], bins=50, alpha=0.7, label='Target', color='orange', density=True)
    axes[0].set_title(f'Histogram: {feature_names[idx1]}')
    axes[0].set_xlabel('Value')
    axes[0].set_ylabel('Density')
    axes[0].legend()
    axes[0].grid(True, linestyle='--', alpha=0.5)
    axes[1].hist(X_source[:, idx2], bins=50, alpha=0.7, label='Source', color='blue', density=True)
    axes[1].hist(X_target[:, idx2], bins=50, alpha=0.7, label='Target', color='orange', density=True)
    axes[1].set_title(f'Histogram: {feature_names[idx2]} (ENV_BPFO)')
    axes[1].set_xlabel('Value')
    axes[1].legend()
    axes[1].grid(True, linestyle='--', alpha=0.5)
    plt.suptitle('Overlapped Histograms of Key Features')
    plt.tight_layout()
    plt.savefig('histogram_overlap.png', dpi=150)
    plt.show()
    # ---------------------- Figure 3: t-SNE visualization ----------------------
    X_concat = np.vstack((X_source, X_target))
    d_concat = ['Source'] * len(X_source) + ['Target'] * len(X_target)  # NOTE(review): unused below
    tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42)
    X_tsne = tsne.fit_transform(X_concat)
    plt.figure(figsize=(10, 8))
    # Source rows come first in X_concat, so slicing at len(X_source)
    # separates the two domains in the embedding.
    plt.scatter(X_tsne[:len(X_source), 0], X_tsne[:len(X_source), 1],
                c='tab:blue', label='Source Domain', alpha=0.7, s=60)
    plt.scatter(X_tsne[len(X_source):, 0], X_tsne[len(X_source):, 1],
                c='tab:orange', label='Target Domain', alpha=0.7, s=60)
    plt.title('t-SNE: Source vs Target Feature Space')
    plt.xlabel('t-SNE Component 1')
    plt.ylabel('t-SNE Component 2')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig('tsne_source_vs_target.png', dpi=150)
    plt.show()
    # ---------------------- Figure 4: KDE density curves (first 6 features) ----------------------
    n_kde = min(6, X.shape[1])
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.ravel()
    for i in range(n_kde):
        kde_src = gaussian_kde(X_source[:, i])
        kde_tar = gaussian_kde(X_target[:, i])
        # Evaluate both KDEs over the combined range of the feature.
        x_range = np.linspace(X[:, i].min(), X[:, i].max(), 200)
        axes[i].plot(x_range, kde_src(x_range), label='Source', color='blue')
        axes[i].fill_between(x_range, kde_src(x_range), alpha=0.3, color='blue')
        axes[i].plot(x_range, kde_tar(x_range), label='Target', color='orange')
        axes[i].fill_between(x_range, kde_tar(x_range), alpha=0.3, color='orange')
        axes[i].set_title(feature_names[i])
        axes[i].legend()
        axes[i].grid(True, linestyle='--', alpha=0.5)
    plt.suptitle('KDE: Probability Density of Features')
    plt.tight_layout()
    plt.savefig('kde_feature_density.png', dpi=150)
    plt.show()
    # ---------------------- Figure 5: JS-divergence ranking ----------------------
    js_distances = []
    for i in range(X.shape[1]):
        # Histogram each feature, then measure the squared JS distance
        # between the source and target histograms.
        hist_src, _ = np.histogram(X_source[:, i], bins=50, density=True)
        hist_tar, _ = np.histogram(X_target[:, i], bins=50, density=True)
        js_dist = jenshannon(hist_src, hist_tar)**2
        js_distances.append(js_dist)
    top_n = min(20, len(js_distances))
    indices = np.argsort(js_distances)[-top_n:]  # indices of the largest domain shifts
    labels = [feature_names[i] for i in indices]
    values = [js_distances[i] for i in indices]
    plt.figure(figsize=(10, 6))
    bars = plt.barh(labels, values, color='purple', alpha=0.7)
    plt.xlabel('Jensen-Shannon Divergence (Squared)')
    plt.title(f'Top {top_n} Features by Domain Shift (JS Divergence)')
    plt.grid(True, axis='x', linestyle='--', alpha=0.5)
    # Annotate each bar with its JS value.
    for bar, val in zip(bars, values):
        plt.text(bar.get_width() + np.max(values)*0.01, bar.get_y() + bar.get_height()/2,
                 f'{val:.3f}', va='center', fontsize=9)
    plt.tight_layout()
    plt.savefig('js_divergence_heatmap.png', dpi=150)
    plt.show()
# ====================== Main function ======================
def main():
    """End-to-end feature extraction for both domains.

    Per-file pipeline: load -> resample to 32 kHz -> fix length ->
    band-pass filter -> extract time/frequency/wavelet/envelope features
    -> save everything (with labels and domain tags) to a CSV, then draw
    the source-vs-target comparison figures.

    Fixes vs. the original:
    - the per-file progress print and the filter-plot title contained a
      literal "(unknown)" placeholder; both now interpolate `filename`;
    - `preprocess_signal` was called with `fs_orig` on a signal already
      resampled to `fs_target`, which resampled it a second time and
      stretched it; it now receives `fs_target`.
    """
    print("🚀 开始加载源域与目标域数据...")
    feature_list = []
    label_list = []
    file_list = []
    domain_list = []
    count_per_class = {0: 0, 1: 0, 2: 0, 3: 0}
    max_per_class = 40  # cap samples per class to keep the classes balanced
    show_plot = True  # draw the filter/envelope example plots only once

    # Abort early if either dataset directory is missing.
    if not os.path.exists(source_data_dir):
        print(f"❌ 源域路径不存在: {source_data_dir}")
        return
    if not os.path.exists(target_data_dir):
        print(f"❌ 目标域路径不存在: {target_data_dir}")
        return

    source_files = find_mat_files(source_data_dir)
    target_files = find_mat_files(target_data_dir)
    all_files = [(f, 'source') for f in source_files] + [(f, 'target') for f in target_files]
    print(f"📌 总共将处理 {len(all_files)} 个文件(源域: {len(source_files)}, 目标域: {len(target_files)})")

    processed_count = 0
    for filepath, domain in all_files:
        filename = os.path.basename(filepath)
        print(f"📌 正在处理 [{domain}]: {filename}")
        signal, rpm = load_cwru_data(filepath)
        if signal is None or len(signal) == 0:
            continue

        # Infer the original sampling rate from the file path.
        if '48K' in filepath.upper() or '48000' in filepath:
            fs_orig = 48000
        else:
            fs_orig = 12000

        # Resample to the common target rate via linear interpolation.
        if fs_orig != fs_target:
            num_new = int(len(signal) * fs_target / fs_orig)
            new_time_idx = np.linspace(0, len(signal)-1, num_new)
            signal_resampled = np.interp(new_time_idx, np.arange(len(signal)), signal)
        else:
            signal_resampled = signal.copy()

        # Crop or zero-pad to a fixed segment length.
        if len(signal_resampled) > target_length:
            signal_resampled = signal_resampled[:target_length]
        elif len(signal_resampled) < target_length:
            pad_len = target_length - len(signal_resampled)
            signal_resampled = np.pad(signal_resampled, (0, pad_len), mode='constant', constant_values=0)

        # Keep the unfiltered version for the before/after plot.
        original_for_plot = signal_resampled.copy()
        # Band-pass filter. The signal is already at fs_target, so pass
        # fs_target (passing fs_orig here caused a second resampling).
        signal_proc = preprocess_signal(signal_resampled, fs_target, fs_target, target_length)

        # Draw the filter-comparison and envelope-spectrum examples once.
        if show_plot:
            plot_filter_comparison(original_for_plot, signal_proc, fs_target, title=f"{domain.upper()}: {filename}")
            plot_envelope_spectrum_example(signal_proc, fs_target, rpm, title=filename)
            show_plot = False

        # Label from the file path; enforce the per-class sample cap.
        label = parse_label_from_path(filepath)
        if count_per_class[label] >= max_per_class:
            continue

        # === Feature extraction (12 + 8 + 9 + 12 = 41 dimensions) ===
        td_feat = time_domain_features(signal_proc)
        ff_feat = freq_domain_features(signal_proc, fs=fs_target, fr=rpm/60)
        wp_feat = wavelet_packet_features(signal_proc)
        env_feat = envelope_spectrum_features(signal_proc, fs=fs_target, fr=rpm/60)
        all_features = td_feat + ff_feat + wp_feat + env_feat

        feature_list.append(all_features)
        label_list.append(label)
        file_list.append(filename)
        domain_list.append(domain)
        count_per_class[label] += 1
        processed_count += 1

    # ====================== Save the features and plot ======================
    if len(feature_list) == 0:
        print("❌ 未提取到任何有效特征")
        return
    print(f"✅ 共提取 {processed_count} 个样本")
    print("📊 各类统计:", count_per_class)

    X = np.array(feature_list)
    X = np.real(X).astype(np.float64)  # guard against stray complex values
    y = np.array(label_list)
    d = np.array(domain_list)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    # Column names for all 41 features (must match the extractor order).
    feature_names = (
        [f"TD_{i}" for i in range(12)] +
        [f"FF_{i}" for i in range(8)] +
        ["WP_Entropy"] + [f"WP_Band_{i}" for i in range(8)] +
        ["ENV_BPFO", "ENV_BPFI", "ENV_BSF", "ENV_FR",
         "ENV_High", "ENV_Mid", "ENV_Low", "ENV_HL_Ratio", "ENV_OI_Ratio",
         "ENV_Kurtosis", "ENV_Peak", "ENV_Mean"]
    )

    df = pd.DataFrame(X_scaled, columns=feature_names)
    df.insert(0, 'filename', file_list)
    df.insert(1, 'label', y)
    df.insert(2, 'domain', d)
    df.to_csv(output_csv, index=False)
    print(f"📁 特征已保存至: {output_csv}")

    # Draw all source-vs-target comparison figures.
    plot_domain_comparison(X_scaled, y, d, feature_names)

    # Finally: a PCA scatter colored by domain.
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_scaled)
    plt.figure(figsize=(12, 8))
    domains = ['source', 'target']
    colors = ['tab:blue', 'tab:orange']
    for i, dom in enumerate(domains):
        idx = d == dom
        plt.scatter(X_pca[idx, 0], X_pca[idx, 1], c=colors[i], label=f'{dom.capitalize()} Domain', alpha=0.7, s=60)
    plt.title('PCA: Source vs Target Domain Feature Distribution', fontsize=14)
    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig('pca_source_vs_target.png', dpi=150)
    plt.show()
    print("🎉 所有任务完成!")
# Run the full feature-extraction pipeline when executed as a script.
if __name__ == '__main__':
    main()
任务二的代码为:
# -*- coding: utf-8 -*-
"""
📌 任务二:基于41维特征的多分类器对比分析(使用10折交叉验证)
📊 新评价指标:Accuracy, Precision, Recall, F1-Score(宏平均)
🔧 使用源域数据 + 10-Fold CV,避免划分偏差
📦 新增功能:保存训练好的随机森林模型供任务三使用
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import warnings
import os
import joblib # 新增导入
warnings.filterwarnings('ignore')
# Matplotlib setup: SimHei enables Chinese glyphs in figure text; the
# unicode_minus flag keeps minus signs rendering correctly with it.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
sns.set_style("whitegrid")
# ====================== Main program starts ======================
print("🚀 开始加载并处理特征数据...")
# Features produced by task 1; each row carries a 'domain' column
# marking it as a source or target sample.
df = pd.read_csv('extracted_features_with_domain.csv')
source_data = df[df['domain'] == 'source'].copy()
print(f"✅ 源域样本数: {len(source_data)}")
X = source_data.drop(columns=['filename', 'label', 'domain']).values
y = source_data['label'].values
labels = sorted(np.unique(y))
class_names = ['Normal', 'Outer Race', 'Inner Race', 'Ball']
# Standardization (used only by the scale-sensitive models below).
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# Candidate classifiers compared on the source domain.
models = {
    "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42),
    "SVM_RBF": SVC(kernel='rbf', C=1.0, gamma='scale', probability=True, random_state=42),
    "MLP": MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, alpha=1e-4,
                         batch_size=32, early_stopping=True, random_state=42)
}
# 10-fold stratified cross-validation keeps class proportions per fold.
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
results_summary = []
print("\n" + "="*60)
print("🔥 正在进行 10-Fold Cross Validation 评估各分类器...")
print("="*60)
for name, model in models.items():
    print(f"\n📌 正在评估 {name}...")
    # Scale-sensitive models (SVM, MLP) get standardized features; the
    # tree ensemble uses the raw feature values.
    X_use = X_scaled if name in ["SVM_RBF", "MLP"] else X
    all_y_true, all_y_pred = [], []
    # Pool out-of-fold predictions over all 10 folds, then score once.
    for train_idx, test_idx in cv.split(X_use, y):
        X_train_fold, X_test_fold = X_use[train_idx], X_use[test_idx]
        y_train_fold, y_test_fold = y[train_idx], y[test_idx]
        model.fit(X_train_fold, y_train_fold)
        y_pred_fold = model.predict(X_test_fold)
        all_y_true.extend(y_test_fold)
        all_y_pred.extend(y_pred_fold)
    # Overall accuracy on the pooled predictions.
    acc = accuracy_score(all_y_true, all_y_pred)
    # Macro-averaged precision / recall / F1 from the report dict.
    report = classification_report(all_y_true, all_y_pred, target_names=class_names,
                                   labels=labels, output_dict=True)
    precision_macro = report['macro avg']['precision']
    recall_macro = report['macro avg']['recall']
    f1_macro = report['macro avg']['f1-score']
    # One summary row per model.
    results_summary.append({
        'Model': name,
        'Accuracy': acc,
        'Precision': precision_macro,
        'Recall': recall_macro,
        'F1-Score': f1_macro
    })
    # Headline metrics.
    print(f"🎯 {name} 10-Fold CV 结果:")
    print(f" Accuracy = {acc:.3f}")
    print(f" Precision = {precision_macro:.3f}")
    print(f" Recall = {recall_macro:.3f}")
    print(f" F1-Score = {f1_macro:.3f}")
    # Full per-class classification report.
    print("\n📋 分类报告 (Classification Report):")
    print(classification_report(all_y_true, all_y_pred, target_names=class_names))
    # Confusion-matrix heatmap for this model.
    cm = confusion_matrix(all_y_true, all_y_pred, labels=labels)
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(f'Confusion Matrix - {name} (10-Fold CV)')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig(f'cm_{name.replace(" ", "_")}_10fold.png', dpi=150)
    plt.show()
# ====================== Combined comparison table ======================
print("\n" + "="*60)
print("🏆 所有模型综合性能对比(10-Fold CV)")
print("="*60)
summary_df = pd.DataFrame(results_summary).round(3)
print(summary_df.to_string(index=False))
# Persist the comparison table as CSV.
summary_df.to_csv('model_comparison_summary_10fold_updated.csv', index=False)
print("💾 综合结果已保存至: model_comparison_summary_10fold_updated.csv")
# ====================== Grouped bar chart of the four metrics ======================
metrics_plot = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
# NOTE(review): only 3 colors for 4 metrics -- the 4th metric reuses
# the first color via the modulo below; confirm this is intended.
colors = ['#4E79A7', '#F28E2B', '#E15759']
x_pos = np.arange(len(models))
fig, ax = plt.subplots(figsize=(10, 6))
width = 0.2
for i, metric in enumerate(metrics_plot):
    values = summary_df[metric].values
    bars = ax.bar(x_pos + i*width, values, width, label=metric, color=colors[i % len(colors)], alpha=0.8)
    # Annotate each bar with its numeric value.
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f"{val:.3f}",
                ha='center', va='bottom', fontsize=9, fontweight='bold')
ax.set_xlabel('Model')
ax.set_ylabel('Score')
ax.set_title('Model Comparison on Source Domain (41D Features) - 10-Fold CV\nMetrics: Accuracy, Precision, Recall, F1-Score')
ax.set_xticks(x_pos + width * 1.5)
ax.set_xticklabels(summary_df['Model'])
ax.set_ylim(0, 1.0)
ax.legend()
ax.grid(True, axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.savefig('model_performance_comparison_4metrics.png', dpi=150)
plt.show()
# ====================== Feature importance (RF only) ======================
if 'Random Forest' in models:
    # Retrain a random forest on the full source-domain data.
    rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
    rf_model.fit(X, y)  # trained on all source samples (unscaled features)
    # New: save the model and its metadata for task 3.
    model_dir = "saved_models"
    os.makedirs(model_dir, exist_ok=True)  # ensure the output directory exists
    # Core model file.
    joblib.dump(
        rf_model,
        os.path.join(model_dir, "random_forest_model.pkl")
    )
    # Metadata: feature names, class names, label set, preprocessing note.
    metadata = {
        'feature_names': df.columns.drop(['filename', 'label', 'domain']).tolist(),
        'class_names': class_names,
        'labels': sorted(np.unique(y)),
        'preprocessing': 'none'  # the model was trained on unstandardized features
    }
    joblib.dump(
        metadata,
        os.path.join(model_dir, "random_forest_metadata.pkl")
    )
    print("📁 模型及元数据已保存至 saved_models/ 目录")
    # Plot the top-20 Gini importances.
    feat_importance = rf_model.feature_importances_
    feature_names = df.columns.drop(['filename', 'label', 'domain'])
    indices = np.argsort(feat_importance)[::-1][:20]
    plt.figure(figsize=(10, 6))
    plt.barh([feature_names[i] for i in indices[::-1]], feat_importance[indices][::-1], color='steelblue')
    plt.xlabel('Feature Importance (Gini)')
    plt.title('Top 20 Important Features - Random Forest (Full Training Set)')
    plt.gca().invert_yaxis()
    plt.grid(True, axis='x', linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig('rf_feature_importance_10fold.png', dpi=150)
    plt.show()
print("\n🎉 所有任务完成!请查看生成的图表与CSV文件。")
给出一个基于任务一和任务二的任务三代码,可以解决任务三的要求,并且给出多个可视化对比图