已知y=m*m*x+3*n,推理出参数m和n

该文章演示了如何在TensorFlow中创建一个自定义层,并用它来构建一个简单的模型。通过定义输入和输出数据,模型基于这些数据进行训练,以学习到权重参数m和n。使用MeanSquaredError作为损失函数,Adam优化器进行优化,并在100个epoch内进行训练。最后,文章打印出训练得到的m和n的值。
import tensorflow as tf
import numpy as np

# Training inputs: 100 random samples drawn by np.random.rand, cast to float32.
x = np.random.rand(100).astype("float32")

# Targets follow y = m*m*x + 3*n with ground truth m=3, n=4
# (so m*m = 9 and 3*n = 12); training is expected to recover these values.
y = 9 * x + 12

# 定义自定义层
class MyLayer(tf.keras.layers.Layer):
    """Custom Keras layer computing ``m * m * inputs + 3 * n``.

    The layer owns two trainable weights, ``m`` and ``n``, each of shape
    ``(1,)`` and initialised with the ``'random_normal'`` initializer.

    Note that ``m`` only enters the output as ``m * m``, so ``m`` and ``-m``
    yield exactly the same predictions: training can determine the magnitude
    of ``m`` but not its sign.

    Args:
        units: Stored on the instance as ``self.units``. It is not read by
            ``build`` or ``call`` in this file; kept for interface
            compatibility.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer.__init__``
            (for example ``name`` or ``dtype``).
    """

    def __init__(self, units=1, **kwargs):
        # Forward base-layer keyword arguments instead of rejecting them.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create the trainable weights ``m`` and ``n``.

        Args:
            input_shape: Shape of the incoming tensor. Unused, because both
                weights have the fixed shape ``(1,)``.
        """
        self.m = self.add_weight(
            name='m',
            shape=(1,),
            initializer='random_normal',
            trainable=True)
        self.n = self.add_weight(
            name='n',
            shape=(1,),
            initializer='random_normal',
            trainable=True)

    def call(self, inputs):
        """Return ``m * m * inputs + 3 * n`` for the given inputs."""
        return self.m * self.m * inputs + 3 * self.n

    def get_config(self):
        """Return the layer config, including the ``units`` constructor argument."""
        config = super().get_config()
        config.update({"units": self.units})
        return config

# Build the model: a Sequential container holding the single custom layer.
model = tf.keras.Sequential([MyLayer()])

# Loss function: mean squared error.
loss_fn = tf.keras.losses.MeanSquaredError()

# Optimizer: Adam with a learning rate of 0.01.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

# Attach the loss and the optimizer to the model.
model.compile(loss=loss_fn, optimizer=optimizer)

# Train for 100 epochs with 100 steps per epoch.
# NOTE(review): no batch_size is passed; Keras presumably derives it from
# steps_per_epoch and the number of samples -- confirm for the Keras version
# in use.
model.fit(x=x, y=y, steps_per_epoch=100, epochs=100)

# Print the learned weights of every layer in the model.
for layer in model.layers:
    print(f"m = {layer.m.numpy()}, n = {layer.n.numpy()}")

import numpy as np # 定义 Sigmoid 函数 def sigmoid(x): return 1 / (1 + np.exp(-x)) # 分段参数 N_SEGMENTS = 512 # 总段数 X_MIN, X_MAX = -4,4 # Sigmoid 有效定义域[-8,8] DX = (X_MAX - X_MIN) / N_SEGMENTS # 每段长度 ≈0.062 def calculate_coefficients(): # 存储系数:每行[a, b, c]对应一个分段 coefficients = np.zeros((N_SEGMENTS, 3)) seg = [] for i in range(N_SEGMENTS): # 当前分段区间 [start, end] start = X_MIN + i * DX end = start + DX seg.append([start, end]) # 在区间内均匀采样5个点(含端点) x_samples = np.linspace(start, end, 20) y_samples = sigmoid(x_samples) # 构建最小二乘矩阵:f(x)=a*x^2 - b*x + c A = np.vstack([ x_samples**2, # a 系数 -x_samples, # b 系数(注意负号) np.ones_like(x_samples) # c 系数 ]).T # 求解最小二乘问题 coef, _, _, _ = np.linalg.lstsq(A, y_samples, rcond=None) coefficients[i] = coef return coefficients, seg class SigmoidLUT: def __init__(self): self.coeffs, self.seg = calculate_coefficients() self.x_min = X_MIN self.dx = DX def __call__(self, x): # 计算所属分段索引 idx = np.clip(((x - self.x_min) / self.dx).astype(int), 0, N_SEGMENTS-1) # 获取系数 a, b, c = self.coeffs[idx] # 计算二次多项式 return a*x**2 - b*x + c def calculate_symmetric_coeffs(): # 只计算正半轴区间 (128个分段) half_segments = N_SEGMENTS // 2 coeffs_positive = np.zeros((half_segments, 3)) # 计算正半轴系数 (x>0) for i in range(half_segments): start = -DX + i * DX end = start + 2*DX x_samples = np.linspace(start, end, 1000) y_samples = sigmoid(x_samples) A = np.vstack([x_samples**2, -x_samples, np.ones_like(x_samples)]).T coeffs_positive[i] = np.linalg.lstsq(A, y_samples, rcond=None)[0] # 构建完整系数表 (257段) full_coeffs = np.zeros((N_SEGMENTS, 3)) # 负半轴通过对称性获得 for i in range(half_segments): # 正半轴系数: [a, b, c] a, b, c = coeffs_positive[i] # 对称负半轴系数: f(-x) = 1 - f(x) # => a(-x)^2 - b(-x) + c = ax^2 + bx + c # 但需满足: σ(-x) ≈ 1 - (ax^2 + bx + c) # 所以新系数: [-a, -b, 1-c] full_coeffs[i] = [-a, -b, 1 - c] # 正半轴直接使用 full_coeffs[N_SEGMENTS - 1 - i] = [a, b, c] # 处理中心分段 (x=0附近) center_idx = half_segments start, end = -DX/2, DX/2 x_samples = np.linspace(start, end, 5) y_samples = sigmoid(x_samples) A = 
np.vstack([x_samples**2, -x_samples, np.ones(5)]).T full_coeffs[center_idx] = np.linalg.lstsq(A, y_samples, rcond=None)[0] return full_coeffs def evaluate_accuracy(lut, n_samples=N_SEGMENTS*16): # 在[-10,10]区间测试 x_test = np.linspace(0,1, n_samples) y_true = sigmoid(x_test) y_pred = np.array([lut(x) for x in x_test]) # 计算最大绝对误差均方根误差 abs_error = np.abs(y_true - y_pred) max_error = np.max(abs_error) rmse = np.sqrt(np.mean(abs_error**2)) # 统计超门限的误差 threshold = 1e-4 over_threshold = np.sum(abs_error > threshold) / n_samples return max_error, rmse, over_threshold # 测试标准实现 standard_lut = SigmoidLUT() max_err, rmse, over_th = evaluate_accuracy(standard_lut) print(f"标准实现: MaxError={max_err:.2e}, RMSE={rmse:.2e}, >1e-4比例={over_th*100:.2f}%") # 测试对称优化实现 symmetric_lut = SigmoidLUT() symmetric_lut.coeffs = calculate_symmetric_coeffs() max_err_sym, rmse_sym, over_th_sym = evaluate_accuracy(symmetric_lut) print(f"对称优化: MaxError={max_err_sym:.2e}, RMSE={rmse_sym:.2e}, >1e-4比例={over_th_sym*100:.2f}%") 能否帮我这段代码提高精度
12-09
继续生成新的任务三代码,如题:、 迁移诊断:在任务2设计的诊断模型基础上,充分考虑源域与目标域的共性与差异特征,设计合适的迁移学习方法,构建目标域诊断模型,对目标域未知标签的数据进行分类标定,给迁移结果的可视化展示分析,并给数据对应的标签。已知任务一的代码为: # -*- coding: utf-8 -*- """ 📌 任务一:改进版特征提取(含包络谱 + 故障频率敏感特征) 🔧 目标:解决任务二准确率低的问题 💡 核心改进:加入 Hilbert 包络解调特征,极大增强对 Outer/Ball 故障的识别能力 """ import os import numpy as np import scipy.io as sio from scipy.fft import fft from scipy.stats import skew, kurtosis from scipy.signal import butter, filtfilt, hilbert import pywt from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA from sklearn.manifold import TSNE from scipy.stats import gaussian_kde import matplotlib.pyplot as plt import pandas as pd # ====================== 自定义 JS 散度(兼容低版本 SciPy)====================== def jenshannon(p, q): """ 手动实现 Jensen-Shannon Divergence 的平方根(即 JS距离) 输入:两个数组 p q(概率分布) 输:JS距离(非负值) """ p = np.asarray(p) q = np.asarray(q) p = p / (p.sum() + 1e-8) q = q / (q.sum() + 1e-8) m = 0.5 * (p + q) def entropy(x): return -np.sum(x * np.log(x + 1e-8)) js_divergence = entropy(m) - 0.5 * entropy(p) - 0.5 * entropy(q) return np.sqrt(max(js_divergence, 0)) # 返回 JS距离 # ====================== 路径设置 ====================== source_data_dir = r&#39;C:\Users\1\Desktop\新建文件夹\数据集\源域数据集&#39; target_data_dir = r&#39;C:\Users\1\Desktop\新建文件夹\数据集\目标域数据集&#39; output_csv = &#39;extracted_features_with_domain.csv&#39; fs_target = 32000 # 统一重采样到32kHz target_length = 320000 # 每段信号长度(约10秒) # 轴承参数(SKF6205) n_rollers = 9 diameter_ball = 0.3126 * 0.0254 pitch_diameter = 1.537 * 0.0254 contact_angle = np.radians(15) # 滤波参数 FILTER_LOW_CUT = 200.0 FILTER_HIGH_CUT = 10000.0 FILTER_ORDER = 6 # ====================== 查找所有.mat文件 ====================== def find_mat_files(root_dir): mat_files = [] for dirpath, _, filenames in os.walk(root_dir): for f in filenames: if f.endswith(&#39;.mat&#39;) and not f.startswith(&#39;._&#39;): mat_files.append(os.path.join(dirpath, f)) print(f"🔍 在 {root_dir} 中发现 {len(mat_files)} 个 .mat 文件") return mat_files # ====================== 加载.mat文件 
====================== def load_cwru_data(filepath): try: mat = sio.loadmat(filepath) except Exception as e: print(f"[错误] 无法读取 {filepath}: {e}") return None, None de_key = rpm_key = None for k in mat.keys(): if not k.startswith(&#39;__&#39;) and &#39;DE&#39; in k: # 驱动端信号 de_key = k if &#39;RPM&#39; in k.upper(): rpm_key = k if de_key is None: # 回退策略:找最长的一维数组 for k in mat.keys(): if isinstance(mat[k], np.ndarray): tmp = mat[k].flatten() if len(tmp) >= 10000: de_key = k break if de_key is None: print(f"[警告] 未在 {filepath} 中找到驱动端(DE)信号") return None, None try: signal = mat[de_key].flatten() except: return None, None rpm = 1797 if rpm_key in mat: try: rpm = float(mat[rpm_key].item()) if mat[rpm_key].size == 1 else float(mat[rpm_key][0, 0]) except: pass return signal, rpm # ====================== 带通滤波 ====================== def bandpass_filter(signal, lowcut, highcut, fs, order=5): nyquist = 0.5 * fs low = lowcut / nyquist high = highcut / nyquist if low >= 1.0 or high >= 1.0: print(f"[警告] 截止频率超范围(fs={fs}),跳过滤波") return signal b, a = butter(order, [low, high], btype=&#39;band&#39;, analog=False) return filtfilt(b, a, signal) # ====================== 预处理信号 ====================== def preprocess_signal(x, fs_original, fs_target=32000, target_len=320000): if len(x) == 0: return np.zeros(target_len) # 重采样 if fs_original != fs_target: num_new = int(len(x) * fs_target / fs_original) new_time_idx = np.linspace(0, len(x)-1, num_new) x = np.interp(new_time_idx, np.arange(len(x)), x) # 滤波 x = bandpass_filter(x, FILTER_LOW_CUT, FILTER_HIGH_CUT, fs_target, FILTER_ORDER) # 统一长度 if len(x) > target_len: x = x[:target_len] elif len(x) < target_len: pad_len = target_len - len(x) x = np.pad(x, (0, pad_len), mode=&#39;constant&#39;, constant_values=0) return x # ====================== 时域特征 ====================== def time_domain_features(x): if len(x) == 0 or np.var(x) < 1e-10: return [0.0] * 12 x_abs = np.abs(x) x_sq = x ** 2 mean_val = np.mean(x) std_val = np.std(x) peak_val = 
np.max(x_abs) rms_val = np.sqrt(np.mean(x_sq)) return [ mean_val, std_val, peak_val, rms_val, peak_val / (rms_val + 1e-8), # crest factor skew(x), kurtosis(x), peak_val / (np.mean(x_abs) + 1e-8), # impulse factor peak_val / ((np.mean(np.sqrt(x_abs)))**2 + 1e-8), # margin factor rms_val / (np.mean(x_abs) + 1e-8), # shape factor np.sum(x_sq), # energy peak_val / (np.mean(np.sqrt(x_abs)) + 1e-8), # clearance factor ] # ====================== 频域特征 ====================== def freq_domain_features(x, fs=32000, fr=None): N = len(x) X_fft = np.abs(fft(x))[:N // 2] freqs = np.linspace(0, fs / 2, N // 2) if fr is None: fr = 1797 / 60 # 默认转频 Hz bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle)) bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle)) bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \ (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2) def band_energy(center, width=10): idx = (freqs >= center - width) & (freqs <= center + width) return np.sum(X_fft[idx]**2) if np.any(idx) else 0.0 total_power = np.sum(X_fft**2) + 1e-8 return [ np.max(X_fft), np.mean(X_fft), np.sum(freqs * X_fft) / (np.sum(X_fft) + 1e-8), np.sqrt(np.var(freqs[:len(X_fft)] * X_fft)), band_energy(bpfo) / total_power, band_energy(bpfi) / total_power, band_energy(bsf) / total_power, np.sum(X_fft[(freqs > 1000) & (freqs < 3000)]**2) / total_power ] # ====================== 小波包能量熵(优化版)====================== def wavelet_packet_features(x, level=4, wavelet=&#39;db4&#39;): try: wp = pywt.WaveletPacket(data=x, wavelet=wavelet, maxlevel=level) energies = [np.sum(np.square(node.data)) for node in wp.get_level(level, &#39;natural&#39;)] energies = np.array(energies) p_i = energies / (energies.sum() + 1e-8) entropy = -np.sum(p_i * np.log(p_i + 1e-8)) return [entropy] + p_i[:8].tolist() # 固定返回 9 维 except Exception as e: print(f"[警告] 小波包分解失败: {e}") return [0.0] * 9 # ====================== 
包络谱特征提取(新增!核心改进)====================== def envelope_spectrum_features(x, fs=32000, fr=None): """ 提取 Hilbert 包络谱中的关键能量特征(对早期故障极其敏感) """ if len(x) == 0 or np.var(x) < 1e-10: return [0.0] * 12 # 1. Hilbert 变换获取包络 try: analytic_signal = hilbert(x) envelope = np.abs(analytic_signal) except: envelope = np.abs(hilbert(x)) # 2. 对包络做 FFT N = len(envelope) env_fft = np.abs(fft(envelope))[:N // 2] freqs = np.linspace(0, fs / 2, N // 2) total_power = np.sum(env_fft**2) + 1e-8 if fr is None: fr = 1797 / 60 # 转频 Hz # 计算理论故障频率 bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle)) bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle)) bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \ (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2) def band_energy(center, width=50): idx = (freqs >= center - width) & (freqs <= center + width) return np.sum(env_fft[idx]**2) if np.any(idx) else 0.0 energy_bpfo = band_energy(bpfo) energy_bpfi = band_energy(bpfi) energy_bsf = band_energy(bsf) energy_fr = band_energy(fr, width=20) # 归一化能量 energy_bpfo /= total_power energy_bpfi /= total_power energy_bsf /= total_power energy_fr /= total_power # 高中低频段能量 high_freq_mask = (freqs >= 2000) & (freqs <= 8000) mid_freq_mask = (freqs >= 500) & (freqs < 2000) low_freq_mask = (freqs < 500) energy_high = np.sum(env_fft[high_freq_mask]**2) / total_power energy_mid = np.sum(env_fft[mid_freq_mask]**2) / total_power energy_low = np.sum(env_fft[low_freq_mask]**2) / total_power # 特征组合 return [ energy_bpfo, energy_bpfi, energy_bsf, energy_fr, energy_high, energy_mid, energy_low, energy_high / (energy_low + 1e-8), # 高/低频能量比 energy_bpfo / (energy_bpfi + 1e-8), # 外圈 vs 内圈能量比 kurtosis(envelope), # 包络峭度(冲击性) np.max(env_fft), # 包络谱峰值 np.mean(env_fft) # 包络谱均值 ] # ====================== 标签解析 ====================== def parse_label_from_path(filepath): name = os.path.basename(filepath).upper() dirname = os.path.dirname(filepath).upper() if 
any(n in name for n in [&#39;97.MAT&#39;, &#39;98.MAT&#39;, &#39;99.MAT&#39;, &#39;100.MAT&#39;]): return 0 # Normal if &#39;OR&#39; in name or &#39;OUTER&#39; in dirname or (&#39;O&#39; in name and any(s in name for s in [&#39;007&#39;, &#39;014&#39;, &#39;021&#39;, &#39;028&#39;])): return 1 # Outer Race Fault if &#39;IR&#39; in name or &#39;INNER&#39; in dirname: return 2 # Inner Race Fault if &#39;B007&#39; in name or &#39;B014&#39; in name or &#39;BALL&#39; in dirname or (&#39;B&#39; in name and &#39;IR&#39; not in name and &#39;OR&#39; not in name): return 3 # Ball Fault return 0 # ====================== 绘制滤波前后对比图 ====================== def plot_filter_comparison(original_signal, filtered_signal, fs=32000, title="Filtering Comparison"): N = len(original_signal) freqs = np.linspace(0, fs / 2, N // 2) fft_orig = np.abs(fft(original_signal))[:N // 2] fft_filt = np.abs(fft(filtered_signal))[:N // 2] plot_len = 4096 t = np.arange(plot_len) / fs plt.figure(figsize=(12, 8)) plt.subplot(2, 1, 1) plt.plot(t, original_signal[:plot_len], label=&#39;Original Signal&#39;, color=&#39;gray&#39;, alpha=0.8) plt.plot(t, filtered_signal[:plot_len], label=&#39;Filtered Signal&#39;, color=&#39;red&#39;, linewidth=1.2) plt.title(f&#39;{title} - Time Domain&#39;) plt.xlabel(&#39;Time (s)&#39;) plt.ylabel(&#39;Amplitude&#39;) plt.legend() plt.grid(True, linestyle=&#39;--&#39;, alpha=0.5) plt.subplot(2, 1, 2) plt.semilogy(freqs, fft_orig, label=&#39;Original Spectrum&#39;, color=&#39;gray&#39;, alpha=0.8) plt.semilogy(freqs, fft_filt, label=&#39;Filtered Spectrum&#39;, color=&#39;blue&#39;, linewidth=1.2) plt.axvline(FILTER_LOW_CUT, color=&#39;green&#39;, linestyle=&#39;--&#39;, linewidth=1.2, label=f&#39;{FILTER_LOW_CUT} Hz&#39;) plt.axvline(FILTER_HIGH_CUT, color=&#39;green&#39;, linestyle=&#39;--&#39;, linewidth=1.2, label=f&#39;{FILTER_HIGH_CUT} Hz&#39;) plt.title(&#39;Frequency Domain (Log Scale)&#39;) plt.xlabel(&#39;Frequency (Hz)&#39;) plt.ylabel(&#39;Magnitude&#39;) 
plt.xlim(0, fs / 2) plt.legend() plt.grid(True, which=&#39;both&#39;, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(&#39;filter_comparison.png&#39;, dpi=150) plt.show() # ====================== 绘制包络谱示例 ====================== def plot_envelope_spectrum_example(signal_proc, fs=32000, rpm=1797, title="Envelope Spectrum"): analytic = hilbert(signal_proc) envelope = np.abs(analytic) N = len(envelope) env_fft = np.abs(fft(envelope))[:N//2] freqs = np.linspace(0, fs/2, N//2) plt.figure(figsize=(10, 4)) plt.semilogy(freqs, env_fft, color=&#39;red&#39;, linewidth=1.2, label=&#39;Envelope Spectrum&#39;) fr = rpm / 60 bpfo = fr * n_rollers / 2 * (1 - diameter_ball / pitch_diameter * np.cos(contact_angle)) bpfi = fr * (1 + n_rollers / 2 * diameter_ball / pitch_diameter * np.cos(contact_angle)) bsf = (pitch_diameter / (2 * diameter_ball)) * fr * \ (1 - (diameter_ball / pitch_diameter)**2 * np.cos(contact_angle)**2) for f, name in [(bpfo, &#39;BPFO&#39;), (bpfi, &#39;BPFI&#39;), (bsf, &#39;BSF&#39;), (fr, &#39;FR&#39;)]: plt.axvline(f, color=&#39;blue&#39;, linestyle=&#39;--&#39;, alpha=0.7, linewidth=1) plt.text(f, np.max(env_fft)*0.8, name, rotation=90, va=&#39;top&#39;, fontsize=9) plt.xlim(0, 6000) plt.xlabel("Frequency (Hz)") plt.ylabel("Magnitude (log)") plt.title(f"Envelope Spectrum - {title}") plt.grid(True, which=&#39;both&#39;, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(f"envelope_spectrum_{title.replace(&#39; &#39;, &#39;_&#39;)}.png", dpi=150) plt.show() # ====================== 绘制源域 vs 目标域特征分布对比图 ====================== def plot_domain_comparison(X, y, d, feature_names): X_source = X[d == &#39;source&#39;] X_target = X[d == &#39;target&#39;] if len(X_source) == 0 or len(X_target) == 0: print("❌ 源域或目标域无数据,无法绘图") return print("📊 正在生成源域与目标域特征分布对比图...") # ---------------------- 图1:箱型图(前8个特征)---------------------- plt.figure(figsize=(14, 8)) n_show = min(8, X.shape[1]) data_to_plot = [X_source[:, i] for i in range(n_show)] + 
[X_target[:, i] for i in range(n_show)] labels = [f&#39;{name}\nSource&#39; for name in feature_names[:n_show]] + \ [f&#39;{name}\nTarget&#39; for name in feature_names[:n_show]] bp = plt.boxplot(data_to_plot, labels=labels, patch_artist=True) colors = [&#39;lightblue&#39;] * n_show + [&#39;lightcoral&#39;] * n_show for patch, color in zip(bp[&#39;boxes&#39;], colors): patch.set_facecolor(color) plt.xticks(rotation=45) plt.title(&#39;Boxplot: Feature Distribution Comparison (Source vs Target)&#39;) plt.ylabel(&#39;Standardized Value&#39;) plt.grid(True, axis=&#39;y&#39;, linestyle=&#39;--&#39;, alpha=0.6) plt.tight_layout() plt.savefig(&#39;boxplot_source_vs_target.png&#39;, dpi=150) plt.show() # ---------------------- 图2:叠加直方图(TD_Mean ENV_BPFO)---------------------- idx1 = 0 # TD_0 Mean idx2 = -3 # ENV_BPFO(倒数第3个) fig, axes = plt.subplots(1, 2, figsize=(12, 5)) axes[0].hist(X_source[:, idx1], bins=50, alpha=0.7, label=&#39;Source&#39;, color=&#39;blue&#39;, density=True) axes[0].hist(X_target[:, idx1], bins=50, alpha=0.7, label=&#39;Target&#39;, color=&#39;orange&#39;, density=True) axes[0].set_title(f&#39;Histogram: {feature_names[idx1]}&#39;) axes[0].set_xlabel(&#39;Value&#39;) axes[0].set_ylabel(&#39;Density&#39;) axes[0].legend() axes[0].grid(True, linestyle=&#39;--&#39;, alpha=0.5) axes[1].hist(X_source[:, idx2], bins=50, alpha=0.7, label=&#39;Source&#39;, color=&#39;blue&#39;, density=True) axes[1].hist(X_target[:, idx2], bins=50, alpha=0.7, label=&#39;Target&#39;, color=&#39;orange&#39;, density=True) axes[1].set_title(f&#39;Histogram: {feature_names[idx2]} (ENV_BPFO)&#39;) axes[1].set_xlabel(&#39;Value&#39;) axes[1].legend() axes[1].grid(True, linestyle=&#39;--&#39;, alpha=0.5) plt.suptitle(&#39;Overlapped Histograms of Key Features&#39;) plt.tight_layout() plt.savefig(&#39;histogram_overlap.png&#39;, dpi=150) plt.show() # ---------------------- 图3:t-SNE 可视化 ---------------------- X_concat = np.vstack((X_source, X_target)) d_concat = [&#39;Source&#39;] * 
len(X_source) + [&#39;Target&#39;] * len(X_target) tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42) X_tsne = tsne.fit_transform(X_concat) plt.figure(figsize=(10, 8)) plt.scatter(X_tsne[:len(X_source), 0], X_tsne[:len(X_source), 1], c=&#39;tab:blue&#39;, label=&#39;Source Domain&#39;, alpha=0.7, s=60) plt.scatter(X_tsne[len(X_source):, 0], X_tsne[len(X_source):, 1], c=&#39;tab:orange&#39;, label=&#39;Target Domain&#39;, alpha=0.7, s=60) plt.title(&#39;t-SNE: Source vs Target Feature Space&#39;) plt.xlabel(&#39;t-SNE Component 1&#39;) plt.ylabel(&#39;t-SNE Component 2&#39;) plt.legend() plt.grid(True, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(&#39;tsne_source_vs_target.png&#39;, dpi=150) plt.show() # ---------------------- 图4:KDE 核密度估计图(前6个特征)---------------------- n_kde = min(6, X.shape[1]) fig, axes = plt.subplots(2, 3, figsize=(15, 8)) axes = axes.ravel() for i in range(n_kde): kde_src = gaussian_kde(X_source[:, i]) kde_tar = gaussian_kde(X_target[:, i]) x_range = np.linspace(X[:, i].min(), X[:, i].max(), 200) axes[i].plot(x_range, kde_src(x_range), label=&#39;Source&#39;, color=&#39;blue&#39;) axes[i].fill_between(x_range, kde_src(x_range), alpha=0.3, color=&#39;blue&#39;) axes[i].plot(x_range, kde_tar(x_range), label=&#39;Target&#39;, color=&#39;orange&#39;) axes[i].fill_between(x_range, kde_tar(x_range), alpha=0.3, color=&#39;orange&#39;) axes[i].set_title(feature_names[i]) axes[i].legend() axes[i].grid(True, linestyle=&#39;--&#39;, alpha=0.5) plt.suptitle(&#39;KDE: Probability Density of Features&#39;) plt.tight_layout() plt.savefig(&#39;kde_feature_density.png&#39;, dpi=150) plt.show() # ---------------------- 图5:JS散度热力图 ---------------------- js_distances = [] for i in range(X.shape[1]): hist_src, _ = np.histogram(X_source[:, i], bins=50, density=True) hist_tar, _ = np.histogram(X_target[:, i], bins=50, density=True) js_dist = jenshannon(hist_src, hist_tar)**2 js_distances.append(js_dist) top_n = min(20, 
len(js_distances)) indices = np.argsort(js_distances)[-top_n:] labels = [feature_names[i] for i in indices] values = [js_distances[i] for i in indices] plt.figure(figsize=(10, 6)) bars = plt.barh(labels, values, color=&#39;purple&#39;, alpha=0.7) plt.xlabel(&#39;Jensen-Shannon Divergence (Squared)&#39;) plt.title(f&#39;Top {top_n} Features by Domain Shift (JS Divergence)&#39;) plt.grid(True, axis=&#39;x&#39;, linestyle=&#39;--&#39;, alpha=0.5) for bar, val in zip(bars, values): plt.text(bar.get_width() + np.max(values)*0.01, bar.get_y() + bar.get_height()/2, f&#39;{val:.3f}&#39;, va=&#39;center&#39;, fontsize=9) plt.tight_layout() plt.savefig(&#39;js_divergence_heatmap.png&#39;, dpi=150) plt.show() # ====================== 主函数 ====================== def main(): print("🚀 开始加载源域与目标域数据...") feature_list = [] label_list = [] file_list = [] domain_list = [] count_per_class = {0: 0, 1: 0, 2: 0, 3: 0} max_per_class = 40 show_plot = True # 控制只画一次滤波图包络谱图 # 检查路径 if not os.path.exists(source_data_dir): print(f"❌ 源域路径不存在: {source_data_dir}") return if not os.path.exists(target_data_dir): print(f"❌ 目标域路径不存在: {target_data_dir}") return source_files = find_mat_files(source_data_dir) target_files = find_mat_files(target_data_dir) all_files = [(f, &#39;source&#39;) for f in source_files] + [(f, &#39;target&#39;) for f in target_files] print(f"📌 总共将处理 {len(all_files)} 个文件(源域: {len(source_files)}, 目标域: {len(target_files)})") processed_count = 0 for filepath, domain in all_files: filename = os.path.basename(filepath) print(f"📌 正在处理 [{domain}]: {filename}") signal, rpm = load_cwru_data(filepath) if signal is None or len(signal) == 0: continue # 判断原始采样率 if &#39;48K&#39; in filepath.upper() or &#39;48000&#39; in filepath: fs_orig = 48000 else: fs_orig = 12000 # 重采样 if fs_orig != fs_target: num_new = int(len(signal) * fs_target / fs_orig) new_time_idx = np.linspace(0, len(signal)-1, num_new) signal_resampled = np.interp(new_time_idx, np.arange(len(signal)), signal) else: signal_resampled 
= signal.copy() # 统一长度 if len(signal_resampled) > target_length: signal_resampled = signal_resampled[:target_length] elif len(signal_resampled) < target_length: pad_len = target_length - len(signal_resampled) signal_resampled = np.pad(signal_resampled, (0, pad_len), mode=&#39;constant&#39;, constant_values=0) # 保存原始信号用于绘图 original_for_plot = signal_resampled.copy() # 滤波 signal_proc = preprocess_signal(signal_resampled, fs_orig, fs_target, target_length) # 绘制滤波对比图(仅一次) if show_plot: plot_filter_comparison(original_for_plot, signal_proc, fs_target, title=f"{domain.upper()}: {filename}") plot_envelope_spectrum_example(signal_proc, fs_target, rpm, title=filename) show_plot = False # 提取标签 label = parse_label_from_path(filepath) if count_per_class[label] >= max_per_class: continue # === 特征提取(全部升级)=== td_feat = time_domain_features(signal_proc) ff_feat = freq_domain_features(signal_proc, fs=fs_target, fr=rpm/60) wp_feat = wavelet_packet_features(signal_proc) env_feat = envelope_spectrum_features(signal_proc, fs=fs_target, fr=rpm/60) # 新增! 
all_features = td_feat + ff_feat + wp_feat + env_feat # 总 ~41 维 feature_list.append(all_features) label_list.append(label) file_list.append(filename) domain_list.append(domain) count_per_class[label] += 1 processed_count += 1 # ====================== 保存特征并绘图 ====================== if len(feature_list) == 0: print("❌ 未提取到任何有效特征") return print(f"✅ 共提取 {processed_count} 个样本") print("📊 各类统计:", count_per_class) X = np.array(feature_list) X = np.real(X).astype(np.float64) y = np.array(label_list) d = np.array(domain_list) scaler = StandardScaler() X_scaled = scaler.fit_transform(X) # 更新特征名称(包含新特征) feature_names = ( [f"TD_{i}" for i in range(12)] + [f"FF_{i}" for i in range(8)] + ["WP_Entropy"] + [f"WP_Band_{i}" for i in range(8)] + ["ENV_BPFO", "ENV_BPFI", "ENV_BSF", "ENV_FR", "ENV_High", "ENV_Mid", "ENV_Low", "ENV_HL_Ratio", "ENV_OI_Ratio", "ENV_Kurtosis", "ENV_Peak", "ENV_Mean"] ) df = pd.DataFrame(X_scaled, columns=feature_names) df.insert(0, &#39;filename&#39;, file_list) df.insert(1, &#39;label&#39;, y) df.insert(2, &#39;domain&#39;, d) df.to_csv(output_csv, index=False) print(f"📁 特征已保存至: {output_csv}") # === 绘制所有对比图 === plot_domain_comparison(X_scaled, y, d, feature_names) # === 最后绘制 PCA 图(按 domain 区分)=== pca = PCA(n_components=2) X_pca = pca.fit_transform(X_scaled) plt.figure(figsize=(12, 8)) domains = [&#39;source&#39;, &#39;target&#39;] colors = [&#39;tab:blue&#39;, &#39;tab:orange&#39;] for i, dom in enumerate(domains): idx = d == dom plt.scatter(X_pca[idx, 0], X_pca[idx, 1], c=colors[i], label=f&#39;{dom.capitalize()} Domain&#39;, alpha=0.7, s=60) plt.title(&#39;PCA: Source vs Target Domain Feature Distribution&#39;, fontsize=14) plt.xlabel(f&#39;PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)&#39;) plt.ylabel(f&#39;PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)&#39;) plt.legend() plt.grid(True, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(&#39;pca_source_vs_target.png&#39;, dpi=150) plt.show() print("🎉 所有任务完成!") if __name__ 
== &#39;__main__&#39;: main() 任务二的代码为:# -*- coding: utf-8 -*- """ 📌 任务二:基于41维特征的多分类器对比分析(使用10折交叉验证) 📊 新评价指标:Accuracy, Precision, Recall, F1-Score(宏平均) 🔧 使用源域数据 + 10-Fold CV,避免划分偏差 📦 新增功能:保存训练好的随机森林模型供任务三使用 """ import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.model_selection import StratifiedKFold from sklearn.ensemble import RandomForestClassifier from sklearn.svm import SVC from sklearn.neural_network import MLPClassifier from sklearn.preprocessing import StandardScaler from sklearn.metrics import classification_report, confusion_matrix, accuracy_score import warnings import os import joblib # 新增导入 warnings.filterwarnings(&#39;ignore&#39;) plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] plt.rcParams[&#39;axes.unicode_minus&#39;] = False sns.set_style("whitegrid") # ====================== 主程序开始 ====================== print("🚀 开始加载并处理特征数据...") df = pd.read_csv(&#39;extracted_features_with_domain.csv&#39;) source_data = df[df[&#39;domain&#39;] == &#39;source&#39;].copy() print(f"✅ 源域样本数: {len(source_data)}") X = source_data.drop(columns=[&#39;filename&#39;, &#39;label&#39;, &#39;domain&#39;]).values y = source_data[&#39;label&#39;].values labels = sorted(np.unique(y)) class_names = [&#39;Normal&#39;, &#39;Outer Race&#39;, &#39;Inner Race&#39;, &#39;Ball&#39;] # 标准化(仅用于需要标准化的模型) scaler = StandardScaler() X_scaled = scaler.fit_transform(X) # 定义分类器 models = { "Random Forest": RandomForestClassifier(n_estimators=100, random_state=42), "SVM_RBF": SVC(kernel=&#39;rbf&#39;, C=1.0, gamma=&#39;scale&#39;, probability=True, random_state=42), "MLP": MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=500, alpha=1e-4, batch_size=32, early_stopping=True, random_state=42) } # 十折分层交叉验证 cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) results_summary = [] print("\n" + "="*60) print("🔥 正在进行 10-Fold Cross Validation 评估各分类器...") print("="*60) for name, model in models.items(): print(f"\n📌 正在评估 
{name}...") # 是否需要标准化? X_use = X_scaled if name in ["SVM_RBF", "MLP"] else X all_y_true, all_y_pred = [], [] for train_idx, test_idx in cv.split(X_use, y): X_train_fold, X_test_fold = X_use[train_idx], X_use[test_idx] y_train_fold, y_test_fold = y[train_idx], y[test_idx] model.fit(X_train_fold, y_train_fold) y_pred_fold = model.predict(X_test_fold) all_y_true.extend(y_test_fold) all_y_pred.extend(y_pred_fold) # 计算总体 Accuracy acc = accuracy_score(all_y_true, all_y_pred) # 获取 classification report(宏平均) report = classification_report(all_y_true, all_y_pred, target_names=class_names, labels=labels, output_dict=True) precision_macro = report[&#39;macro avg&#39;][&#39;precision&#39;] recall_macro = report[&#39;macro avg&#39;][&#39;recall&#39;] f1_macro = report[&#39;macro avg&#39;][&#39;f1-score&#39;] # 存储结果 results_summary.append({ &#39;Model&#39;: name, &#39;Accuracy&#39;: acc, &#39;Precision&#39;: precision_macro, &#39;Recall&#39;: recall_macro, &#39;F1-Score&#39;: f1_macro }) # 输详细报告 print(f"🎯 {name} 10-Fold CV 结果:") print(f" Accuracy = {acc:.3f}") print(f" Precision = {precision_macro:.3f}") print(f" Recall = {recall_macro:.3f}") print(f" F1-Score = {f1_macro:.3f}") # 显示分类报告 print("\n📋 分类报告 (Classification Report):") print(classification_report(all_y_true, all_y_pred, target_names=class_names)) # 绘制混淆矩阵热力图 cm = confusion_matrix(all_y_true, all_y_pred, labels=labels) plt.figure(figsize=(6, 5)) sns.heatmap(cm, annot=True, fmt=&#39;d&#39;, cmap=&#39;Blues&#39;, xticklabels=class_names, yticklabels=class_names) plt.title(f&#39;Confusion Matrix - {name} (10-Fold CV)&#39;) plt.xlabel(&#39;Predicted&#39;) plt.ylabel(&#39;True&#39;) plt.tight_layout() plt.savefig(f&#39;cm_{name.replace(" ", "_")}_10fold.png&#39;, dpi=150) plt.show() # ====================== 综合对比表格 ====================== print("\n" + "="*60) print("🏆 所有模型综合性能对比(10-Fold CV)") print("="*60) summary_df = pd.DataFrame(results_summary).round(3) print(summary_df.to_string(index=False)) # 保存为 CSV 
summary_df.to_csv(&#39;model_comparison_summary_10fold_updated.csv&#39;, index=False) print("💾 综合结果已保存至: model_comparison_summary_10fold_updated.csv") # ====================== 可视化四项指标柱状图 ====================== metrics_plot = [&#39;Accuracy&#39;, &#39;Precision&#39;, &#39;Recall&#39;, &#39;F1-Score&#39;] colors = [&#39;#4E79A7&#39;, &#39;#F28E2B&#39;, &#39;#E15759&#39;] x_pos = np.arange(len(models)) fig, ax = plt.subplots(figsize=(10, 6)) width = 0.2 for i, metric in enumerate(metrics_plot): values = summary_df[metric].values bars = ax.bar(x_pos + i*width, values, width, label=metric, color=colors[i % len(colors)], alpha=0.8) # 添加数值标签 for bar, val in zip(bars, values): ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f"{val:.3f}", ha=&#39;center&#39;, va=&#39;bottom&#39;, fontsize=9, fontweight=&#39;bold&#39;) ax.set_xlabel(&#39;Model&#39;) ax.set_ylabel(&#39;Score&#39;) ax.set_title(&#39;Model Comparison on Source Domain (41D Features) - 10-Fold CV\nMetrics: Accuracy, Precision, Recall, F1-Score&#39;) ax.set_xticks(x_pos + width * 1.5) ax.set_xticklabels(summary_df[&#39;Model&#39;]) ax.set_ylim(0, 1.0) ax.legend() ax.grid(True, axis=&#39;y&#39;, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(&#39;model_performance_comparison_4metrics.png&#39;, dpi=150) plt.show() # ====================== 特征重要性(仅RF)====================== if &#39;Random Forest&#39; in models: # 训练完整随机森林模型 rf_model = RandomForestClassifier(n_estimators=100, random_state=42) rf_model.fit(X, y) # 使用全部源域训练 # ✅ 新增:保存模型及元数据 model_dir = "saved_models" os.makedirs(model_dir, exist_ok=True) # 创建保存目录 # 保存核心模型文件 joblib.dump( rf_model, os.path.join(model_dir, "random_forest_model.pkl") ) # 保存元数据(特征名、类别名、标签等) metadata = { &#39;feature_names&#39;: df.columns.drop([&#39;filename&#39;, &#39;label&#39;, &#39;domain&#39;]).tolist(), &#39;class_names&#39;: class_names, &#39;labels&#39;: sorted(np.unique(y)), &#39;preprocessing&#39;: &#39;none&#39; # 表示模型训练时未使用标准化 } joblib.dump( 
metadata, os.path.join(model_dir, "random_forest_metadata.pkl") ) print("📁 模型及元数据已保存至 saved_models/ 目录") # 绘制特征重要性 feat_importance = rf_model.feature_importances_ feature_names = df.columns.drop([&#39;filename&#39;, &#39;label&#39;, &#39;domain&#39;]) indices = np.argsort(feat_importance)[::-1][:20] plt.figure(figsize=(10, 6)) plt.barh([feature_names[i] for i in indices[::-1]], feat_importance[indices][::-1], color=&#39;steelblue&#39;) plt.xlabel(&#39;Feature Importance (Gini)&#39;) plt.title(&#39;Top 20 Important Features - Random Forest (Full Training Set)&#39;) plt.gca().invert_yaxis() plt.grid(True, axis=&#39;x&#39;, linestyle=&#39;--&#39;, alpha=0.5) plt.tight_layout() plt.savefig(&#39;rf_feature_importance_10fold.png&#39;, dpi=150) plt.show() print("\n🎉 所有任务完成!请查看生成的图表与CSV文件。") 给一个基于任务一任务二的任务三代码,可以解决任务三的要求,并且给多个可视化对比图
09-24
评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值