tf.segment_sum: Segmented Summation Functions

This article covers tf.segment_sum and tf.unsorted_segment_sum. tf.segment_sum partitions the input data along its first dimension and sums each segment; the values in segment_ids act as indices into the first dimension of data. tf.unsorted_segment_sum also computes sums along segments of a tensor, but its segment_ids need not be sorted; if a segment ID has no entries, output[i] = 0, and negative segment IDs are dropped.

tf.segment_sum

Definition

tf.segment_sum(
    data,
    segment_ids,
    name=None
)

Computes the sum along segments of a tensor: the input data is partitioned according to segment_ids, and the entries of each segment are summed.

Arguments

data: the tensor to be segmented.
segment_ids: the segment indices. Its length must equal the size k of the first dimension of data, and its values must be non-negative and sorted in non-decreasing order.
name: an optional name for the operation.

Details

Computes a tensor such that output_i = \sum_j data_j, where the sum is over all j such that segment_ids[j] == i.
If the sum is empty for a given segment ID i, output[i] = 0.
In other words, the values in segment_ids index the first dimension of data: output[0] is the sum of every entry of data whose corresponding segment_ids value is 0.
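Because segmentation runs along the first dimension, rows of a 2-D tensor that share a segment ID are summed element-wise. A minimal sketch, assuming TensorFlow 2.x (where the op lives at tf.math.segment_sum):

import tensorflow as tf

# Rows 0 and 1 share segment ID 0, so they are summed element-wise;
# row 2 forms segment 1 on its own.
m = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
print(tf.math.segment_sum(m, tf.constant([0, 0, 1])).numpy())
# [[5 7 9]
#  [7 8 9]]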

Example

data = [5 1 7 2 3 4 1 3]
segment_ids = [0 0 1 1 2 2 4 4]
output = tf.segment_sum(
    data,
    segment_ids,
    name=None
)
  • segment_ids has the same length as the first dimension of data (= 8), and its values are sorted in non-decreasing order, as tf.segment_sum requires.
  • output[0] = data[segment_ids == 0] = 5 + 1 = 6
  • output[2] = data[segment_ids == 2] = 3 + 4 = 7
  • For a segment ID that never appears (here 3), the output is simply 0: output[3] = 0
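The same example as runnable code — a minimal sketch, assuming TensorFlow 2.x, where the op is tf.math.segment_sum (tf.segment_sum is the TF 1.x name):

import tensorflow as tf

data = tf.constant([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = tf.constant([0, 0, 1, 1, 2, 2, 4, 4])  # sorted and non-negative

output = tf.math.segment_sum(data, segment_ids)
print(output.numpy())  # [6 9 7 0 4] -- segment 3 is empty, so output[3] = 0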


tf.unsorted_segment_sum

tf.unsorted_segment_sum(
    data,
    segment_ids,
    num_segments,
    name=None
)

Computes the sum along segments of a tensor.

Computes a tensor such that output[i] = \sum_{j...} data[j...], where the sum is over all tuples j... such that segment_ids[j...] == i. Unlike tf.segment_sum (SegmentSum), segment_ids need not be sorted and need not cover every value in the full valid range.

If the sum for a given segment ID i is empty, output[i] = 0. If a segment ID is negative, the corresponding value is dropped and is not added to the sum of any segment.

num_segments sets the size of the output's first dimension and must be greater than the largest segment ID; typically it equals the number of distinct segment IDs.


Arguments:

  • data: A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64.
  • segment_ids: A Tensor. Must be one of the following types: int32, int64. Its shape must be a prefix of data.shape.
  • num_segments: A Tensor. Must be one of the following types: int32, int64.
  • name: A name for the operation (optional).

Return value:

tf.unsorted_segment_sum returns a Tensor of the same type as data.
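A minimal runnable sketch, assuming TensorFlow 2.x, where the op is tf.math.unsorted_segment_sum (tf.unsorted_segment_sum is the TF 1.x name). The segment IDs below are unsorted, one segment is left empty, and one ID is set to -1 to show that negative IDs are dropped:

import tensorflow as tf

data = tf.constant([5, 1, 7, 2, 3, 4, 1, 3])
segment_ids = tf.constant([0, 2, 4, 1, -1, 2, 4, 1])  # unsorted; -1 is dropped
num_segments = 6  # must exceed the largest segment ID (here 4)

output = tf.math.unsorted_segment_sum(data, segment_ids, num_segments)
print(output.numpy())
# [5 5 5 0 8 0]
# output[0] = 5          (index 0)
# output[1] = 2 + 3 = 5  (indices 3 and 7)
# output[2] = 1 + 4 = 5  (indices 1 and 5)
# output[3] = 0          (empty segment)
# output[4] = 7 + 1 = 8  (indices 2 and 6)
# data[4] = 3 is dropped because its segment ID is negative

Note that num_segments is required here because, unlike tf.segment_sum, the output size cannot be inferred from sorted IDs.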
