# core/training_manager.py
import os
import gc
import time
import json
import numpy as np
import tensorflow as tf
import h5py
from datetime import datetime
from sklearn.model_selection import train_test_split
# Import fix: fall back to importing directly from utils when the relative import fails
try:
from ..utils.callbacks import AdvancedMemoryMonitor, TrainingProgressLogger
except (ImportError, ValueError):
import sys
import os
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, current_dir)
from utils.callbacks import AdvancedMemoryMonitor, TrainingProgressLogger
class TrainingManager:
"""训练流程管理器"""
def __init__(self, hdf5_path, url, data_version, model_config_name, logger):
self.hdf5_path = hdf5_path
self.url = url
        self.data_version = data_version  # data version identifier
        self.model_config_name = model_config_name  # model configuration name
self.logger = logger
self.results = {}
self.start_time = None
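        # ModelEvaluator is defined further down in this module; the name is
        # resolved at instantiation time, after the whole file has been imported.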
self.evaluator = ModelEvaluator(logger)
        # Try to import the enhanced evaluator (optional dependency)
try:
from core.enhanced_evaluator import EnhancedEvaluator
self.enhanced_evaluator = EnhancedEvaluator(self.evaluator)
self.enhanced_available = True
self.logger.info("增强评估器可用")
except ImportError:
self.enhanced_evaluator = None
self.enhanced_available = False
self.logger.warning("增强评估器不可用,使用基础评估")
        # Initialize the report generator
try:
from core.report_generator import WordReportGenerator
self.report_generator = WordReportGenerator(logger, None)
self.report_available = True
self.logger.info("Word报告功能可用")
except ImportError:
self.report_available = False
self.logger.warning("Word报告功能不可用,请安装python-docx")
def validate_data_integrity(self):
"""验证数据完整性"""
from config.training_config import TrainingConfig
try:
with h5py.File(self.hdf5_path, 'r') as f:
num_samples = len(f['trdata'])
input_shape = f['trdata'].shape[1:]
                # Check that all required label datasets are present
required_keys = []
for hl in TrainingConfig.LHSS:
for i in TrainingConfig.ISS:
required_keys.extend([f'y_trdata_{hl}_c_{i}', f'y_trdata_{hl}_r_{i}'])
missing_keys = [key for key in required_keys if key not in f]
if missing_keys:
raise ValueError(f"缺失标签: {missing_keys}")
return num_samples, input_shape
except Exception as e:
self.logger.error(f"数据完整性检查失败: {e}")
raise
def get_model_structure_info(self, model):
"""获取模型结构信息 - 增强版"""
if not model:
return None
        # Basic structure information
structure_info = {
'layers': [],
'total_params': model.count_params() if hasattr(model, 'count_params') else 0,
'trainable_params': 0,
'non_trainable_params': 0
}
        # Extract per-layer information
for i, layer in enumerate(model.layers):
layer_info = LayerInfoExtractor.get_layer_info(layer, i)
structure_info['layers'].append(layer_info)
            # Accumulate parameter counts
if layer.trainable:
structure_info['trainable_params'] += layer_info.get('params', 0)
else:
structure_info['non_trainable_params'] += layer_info.get('params', 0)
        # Add model complexity analysis
try:
complexity = ModelAnalyzer.analyze_model_complexity(model)
structure_info['complexity'] = complexity
except Exception as e:
self.logger.warning(f"模型复杂度分析失败: {e}")
        # Add memory usage estimation
try:
memory_usage = ModelAnalyzer.analyze_memory_usage(model)
structure_info['memory_usage'] = memory_usage
except Exception as e:
self.logger.warning(f"内存使用分析失败: {e}")
        # Add computational complexity analysis
try:
computational_complexity = ModelAnalyzer.analyze_computational_complexity(model)
structure_info['computational_complexity'] = computational_complexity
except Exception as e:
self.logger.warning(f"计算复杂度分析失败: {e}")
        # Add a model structure summary
try:
structure_summary = self._generate_structure_summary(structure_info)
structure_info['summary'] = structure_summary
except Exception as e:
self.logger.warning(f"模型结构总结生成失败: {e}")
return structure_info
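    # The dict returned above contains 'layers', 'total_params', 'trainable_params'
    # and 'non_trainable_params', plus the optional 'complexity', 'memory_usage',
    # 'computational_complexity' and 'summary' sections added when analysis succeeds.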
def _generate_structure_summary(self, structure_info):
"""生成模型结构总结"""
total_layers = len(structure_info.get('layers', []))
total_params = structure_info.get('total_params', 0)
trainable_ratio = structure_info.get('trainable_params', 0) / total_params if total_params > 0 else 0
        # Count layers by type
        layer_types = {}
        for layer in structure_info.get('layers', []):
            layer_type = layer.get('type', 'Unknown')
            layer_types[layer_type] = layer_types.get(layer_type, 0) + 1
        # Compute model depth (excluding non-computational layers)
excluded_types = ['InputLayer', 'Dropout', 'BatchNormalization', 'Activation']
model_depth = len([layer for layer in structure_info.get('layers', [])
if layer.get('type') not in excluded_types])
return {
'total_layers': total_layers,
'total_params_formatted': f"{total_params:,}",
'trainable_ratio': trainable_ratio,
'trainable_ratio_percent': f"{trainable_ratio:.1%}",
'layer_type_distribution': layer_types,
'model_depth': model_depth,
'has_rnn': any(layer.get('type') in ['LSTM', 'GRU', 'SimpleRNN'] for layer in structure_info.get('layers', [])),
'has_cnn': any(layer.get('type') in ['Conv1D', 'Conv2D'] for layer in structure_info.get('layers', [])),
'has_attention': any('attention' in layer.get('type', '').lower() for layer in structure_info.get('layers', []))
}
def create_callbacks(self, model_name, model_type, hl, i, monitor='val_loss',
checkpoint_freq=10, initial_epoch=0):
"""创建回调函数 - 增强版,支持检查点保存"""
base_dir = os.path.join(self.url, 'data', 'modle')
os.makedirs(base_dir, exist_ok=True)
        # 1. Best-model checkpoint (descriptive file name)
best_model_path = os.path.join(
base_dir,
f'{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_best.keras'
)
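        # Example file name with illustrative values:
        # v1_baseline_yield_5d_classification_best.keras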
        # 2. Path for per-epoch checkpoints
checkpoint_dir = os.path.join(base_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(
checkpoint_dir,
f'{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_epoch_{{epoch:03d}}.keras'
)
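        # Example checkpoint name with illustrative values:
        # v1_baseline_yield_5d_classification_epoch_001.keras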
callbacks = [
            # Save the best model (original behaviour)
tf.keras.callbacks.ModelCheckpoint(
filepath=best_model_path,
monitor=monitor,
save_best_only=True,
save_weights_only=False,
verbose=1
),
            # Periodically save checkpoints
tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
monitor=monitor,
save_best_only=False,
save_weights_only=False,
                save_freq='epoch',  # save at the end of every epoch
verbose=0
),
tf.keras.callbacks.EarlyStopping(
monitor=monitor,
patience=50,
restore_best_weights=True,
verbose=1
),
tf.keras.callbacks.ReduceLROnPlateau(
monitor=monitor,
factor=0.5,
patience=20,
min_lr=1e-7,
verbose=1
),
AdvancedMemoryMonitor(),
TrainingProgressLogger(self.logger)
]
        # 3. Training-progress persistence callback
        progress_file = os.path.join(
            base_dir, 'training_progress.json'
        )
        # Custom callback that writes per-epoch progress to a JSON file
class TrainingProgressCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
try:
progress = {
'model': {
'name': model_name,
'type': model_type,
'hl': hl,
'i': i
},
                        'current_epoch': epoch + 1,  # epoch numbering is zero-based
'last_epoch_time': time.strftime('%Y-%m-%d %H:%M:%S'),
'metrics': logs
}
                    # Load any existing progress file
                    all_progress = {}
                    if os.path.exists(progress_file):
                        with open(progress_file, 'r') as f:
                            all_progress = json.load(f)
                    # Update the entry for the current model
                    model_key = f"{model_name}_{hl}_{i}_{model_type}"
                    all_progress[model_key] = progress
                    # Write the combined progress back to disk
                    with open(progress_file, 'w') as f:
                        json.dump(all_progress, f, indent=2)
except Exception as e:
self.logger.warning(f"保存训练进度失败: {e}")
progress_callback = TrainingProgressCallback()
progress_callback.logger = self.logger
callbacks.append(progress_callback)
return callbacks, best_model_path
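    # Illustrative sketch of what TrainingProgressCallback writes to
    # training_progress.json (all values below are made-up examples):
    # {
    #   "my_model_yield_5_classification": {
    #     "model": {"name": "my_model", "type": "classification", "hl": "yield", "i": 5},
    #     "current_epoch": 42,
    #     "last_epoch_time": "2024-01-01 12:00:00",
    #     "metrics": {"loss": 0.93, "val_loss": 1.01}
    #   }
    # }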
def check_resume_training(self, model_name, model_type, hl, i):
"""检查是否可以继续训练"""
try:
            # 1. Best-model file (new naming scheme)
            best_model_path = os.path.join(
                self.url, 'data', 'modle',
                f'{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_best.keras'
            )
            # 2. Checkpoint directory
            checkpoint_dir = os.path.join(self.url, 'data', 'modle', 'checkpoints')
            # 3. Training-progress file
            progress_file = os.path.join(self.url, 'data', 'modle', 'training_progress.json')
results = {
'can_resume': False,
'best_model_exists': os.path.exists(best_model_path),
'best_model_path': best_model_path,
'checkpoints_exist': os.path.exists(checkpoint_dir),
'latest_checkpoint': None,
'latest_epoch': 0,
'progress_exists': os.path.exists(progress_file)
}
            # If a best model exists, verify that it can actually be loaded
            if results['best_model_exists']:
                try:
                    # Load the model once to check file integrity
test_model = tf.keras.models.load_model(best_model_path)
test_model.summary()
del test_model
tf.keras.backend.clear_session()
                    # Find the most recent checkpoint
                    if results['checkpoints_exist']:
                        # Checkpoints are written with model_config_name (see create_callbacks),
                        # so search with the same prefix here instead of model_name.
                        checkpoint_prefix = f'{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_epoch_'
                        checkpoint_files = []
                        for fname in os.listdir(checkpoint_dir):
                            if fname.startswith(checkpoint_prefix):
                                try:
                                    # Extract the epoch number from the file name
                                    epoch_num = int(fname.split('_epoch_')[1].split('.')[0])
                                    checkpoint_files.append((epoch_num, os.path.join(checkpoint_dir, fname)))
                                except (IndexError, ValueError):
                                    continue
if checkpoint_files:
                            # Sort by epoch number and take the most recent
checkpoint_files.sort(key=lambda x: x[0], reverse=True)
latest_epoch, latest_checkpoint = checkpoint_files[0]
results['latest_checkpoint'] = latest_checkpoint
results['latest_epoch'] = latest_epoch
results['can_resume'] = True
except Exception as e:
self.logger.warning(f"模型文件可能损坏: {e}")
results['best_model_exists'] = False
            # If a training-progress file exists, read the recorded progress
if results['progress_exists']:
try:
with open(progress_file, 'r') as f:
progress_data = json.load(f)
model_key = f"{model_name}_{hl}_{i}_{model_type}"
if model_key in progress_data:
progress = progress_data[model_key]
results['progress_epoch'] = progress.get('current_epoch', 0)
results['progress_time'] = progress.get('last_epoch_time', '未知')
results['progress_metrics'] = progress.get('metrics', {})
except Exception as e:
self.logger.warning(f"读取训练进度失败: {e}")
return results
except Exception as e:
self.logger.error(f"检查继续训练状态失败: {e}")
return {'can_resume': False, 'error': str(e)}
def load_model_for_resume(self, checkpoint_path, initial_epoch=None):
"""加载模型继续训练"""
try:
self.logger.info(f"加载模型继续训练: {os.path.basename(checkpoint_path)}")
            # Load the model
model = tf.keras.models.load_model(checkpoint_path)
            # If initial_epoch was provided use it; otherwise infer it from the file name
            if initial_epoch is None:
                # Try to parse the epoch number from the file name
                filename = os.path.basename(checkpoint_path)
                try:
                    # Expected format: ..._epoch_001.keras
                    epoch_part = filename.split('_epoch_')[1].split('.')[0]
                    initial_epoch = int(epoch_part)
                except (IndexError, ValueError):
                    initial_epoch = 0
self.logger.info(f"模型已加载,将从第 {initial_epoch} 个 epoch 继续训练")
return model, initial_epoch
except Exception as e:
self.logger.error(f"加载模型失败: {e}")
return None, 0
def get_data_splits(self, num_samples, train_size=0.8, random_state=42):
"""获取数据划分"""
indices = np.arange(num_samples)
train_indices, test_indices = train_test_split(
indices,
train_size=train_size,
random_state=random_state,
shuffle=True
)
return train_indices, test_indices
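    # Example (illustrative): with num_samples=10 and the default train_size=0.8,
    # this returns two disjoint index arrays of length 8 and 2; the fixed
    # random_state makes the split reproducible across runs.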
def get_test_returns(self, test_indices, hl, i):
"""获取测试集的收益率数据"""
try:
            # Use the return label y_trdata_yield_r_{i}
returns_key = f'y_trdata_yield_r_{i}'
with h5py.File(self.hdf5_path, 'r') as f:
if returns_key in f:
all_returns = f[returns_key][:]
test_returns = all_returns[test_indices]
self.logger.info(f"获取收益率数据成功: {returns_key}, 形状: {test_returns.shape}")
return test_returns
else:
                    # If the expected key is missing, look for likely alternatives
possible_keys = [key for key in f.keys() if 'yield' in key or 'return' in key]
if possible_keys:
self.logger.warning(f"未找到{returns_key},使用替代键: {possible_keys[0]}")
all_returns = f[possible_keys[0]][:]
test_returns = all_returns[test_indices]
return test_returns
else:
                        # As a last resort, fall back to synthetic placeholder data
self.logger.warning("未找到任何收益率数据,使用默认数据")
return np.random.randn(len(test_indices)) * 0.02
except Exception as e:
self.logger.error(f"获取收益率数据失败: {e}")
            # Return placeholder data so the pipeline does not crash
return np.random.randn(len(test_indices)) * 0.02
def evaluate_and_report(self, model, test_dataset, test_steps, model_type, hl, i,
history, input_shape, num_samples, test_indices=None):
"""评估模型并生成报告 - 增强版"""
from config.training_config import TrainingConfig
try:
self.logger.info(f"开始评估模型 - {hl} {model_type} {i}日")
            # 1. Run predictions
y_pred = model.predict(test_dataset, steps=test_steps, verbose=0)
            # 2. Load the ground-truth labels
if model_type == 'classification':
label_key = f'y_trdata_{hl}_c_{i}'
            else:  # regression
label_key = f'y_trdata_{hl}_r_{i}'
with h5py.File(self.hdf5_path, 'r') as f:
all_labels = f[label_key][:]
if test_indices is not None:
y_true = all_labels[test_indices]
else:
                    # Without test_indices, fall back to random placeholder labels
if model_type == 'classification':
y_true = np.random.randint(0, 9, size=len(y_pred))
else:
y_true = np.random.randn(len(y_pred))
            # 3. Fetch return data (needed for both classification and regression)
returns = None
if test_indices is not None:
returns = self.get_test_returns(test_indices, hl, i)
else:
                # Without real return data, fall back to random placeholders
                returns = np.random.randn(len(y_pred)) * 0.02  # ~2% volatility
            # 4. Evaluate according to the model type
            if model_type == 'classification':
                # Convert raw predictions into class labels
                if y_pred.ndim > 1 and y_pred.shape[1] > 1:  # multi-class output
                    y_pred_classes = np.argmax(y_pred, axis=1)
                else:  # binary / single-output
                    y_pred_classes = (y_pred > 0.5).astype(int).flatten()
                # Make sure y_true is integer-typed
y_true_classes = y_true.astype(int)
                # Truncate to a common number of samples
n_samples = min(len(y_pred_classes), len(y_true_classes))
y_pred_classes = y_pred_classes[:n_samples]
y_true_classes = y_true_classes[:n_samples]
returns_subset = returns[:n_samples] if returns is not None else None
                # Use the enhanced evaluator when available
if self.enhanced_available and returns_subset is not None:
evaluation_results = self.enhanced_evaluator.evaluate_classification(
y_true_classes, y_pred_classes, returns_subset, f"{hl}_{i}_classification"
)
else:
                    # Fall back to the basic evaluator
evaluation_results = self.evaluator.evaluate_classification(
y_true_classes, y_pred_classes, model_type, f"{hl}_{i}_classification"
)
            else:  # regression task
y_pred_flat = y_pred.flatten()
y_true_flat = y_true.flatten()
                # Truncate to a common length
n_samples = min(len(y_pred_flat), len(y_true_flat), len(returns))
y_pred_flat = y_pred_flat[:n_samples]
y_true_flat = y_true_flat[:n_samples]
returns_subset = returns[:n_samples] if returns is not None else None
evaluation_results = self.evaluator.evaluate_regression(
y_true_flat, y_pred_flat, model_type, f"{hl}_{i}_regression"
)
                # Enhanced analysis when return data is available
if returns_subset is not None:
try:
                        # Import the enhanced regression analyzer
from core.regression_analyzer import EnhancedRegressionAnalyzer
score_analyzer = EnhancedRegressionAnalyzer()
                        # Analyse how predicted scores relate to realised returns
score_analysis = score_analyzer.analyze_score_return_relationship(
y_pred_flat, returns_subset
)
                        # Merge the analysis into the evaluation results
if 'enhanced_analysis' not in evaluation_results:
evaluation_results['enhanced_analysis'] = {}
evaluation_results['enhanced_analysis'].update(score_analysis)
except ImportError as e:
self.logger.warning(f"无法导入增强回归分析器: {e}")
except Exception as e:
self.logger.warning(f"分析评价分数与收益关系失败: {e}")
            # 5. Assemble model information
            # Read optimizer and loss-function details from the model object
optimizer_info = "Unknown"
loss_info = "Unknown"
if model and hasattr(model, 'optimizer'):
optimizer = model.optimizer
if hasattr(optimizer, '__class__'):
optimizer_info = optimizer.__class__.__name__
                # Try to read the learning rate
                if hasattr(optimizer, 'learning_rate'):
                    try:
                        lr = optimizer.learning_rate
                        if hasattr(lr, 'numpy'):
                            lr_value = lr.numpy()
                        else:
                            lr_value = lr
                        optimizer_info = f"{optimizer_info} (lr={lr_value:.2e})"
                    except Exception:
                        pass
if model and hasattr(model, 'loss'):
loss = model.loss
if isinstance(loss, str):
loss_info = loss
elif hasattr(loss, '__name__'):
loss_info = loss.__name__
elif hasattr(loss, '__class__'):
loss_info = loss.__class__.__name__
model_info = {
                'data_version': self.data_version,  # data version identifier
                'model_config_name': self.model_config_name,  # model configuration name
'model_type': model_type,
'target': f"{hl}未来{i}日",
'input_shape': input_shape,
'num_samples': num_samples,
'train_ratio': TrainingConfig.TRAIN_SIZE,
'test_ratio': 1 - TrainingConfig.TRAIN_SIZE,
'project_name': '股票价格预测系统',
'hl': hl,
'i': i,
'training_params': {
'batch_size': TrainingConfig.BATCH_SIZE,
'epochs': TrainingConfig.EPOCHS,
'learning_rate': TrainingConfig.LEARNING_RATE,
'optimizer': optimizer_info,
'loss_function': loss_info
},
'model_structure': self.get_model_structure_info(model) if model else None
}
            # 6. Generate the Word report
if self.report_available:
report_dir = os.path.join(self.url, 'data', 'reports')
os.makedirs(report_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
                # Descriptive report file name
report_filename = f"report_{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_{timestamp}.docx"
report_path = os.path.join(report_dir, report_filename)
                # Prepare the training history
training_history = {
'loss': history.history.get('loss', []),
'val_loss': history.history.get('val_loss', []),
'accuracy': history.history.get('accuracy', history.history.get('acc', [])),
'val_accuracy': history.history.get('val_accuracy', history.history.get('val_acc', [])),
'training_time': time.time() - self.start_time
}
success = self.report_generator.create_report(
model_info=model_info,
evaluation_results=evaluation_results,
training_history=training_history,
output_path=report_path
)
if success:
self.logger.info(f"Word报告已生成: {report_filename}")
else:
self.logger.warning("Word报告生成失败")
            # 7. Save the evaluation results as JSON
eval_dir = os.path.join(self.url, 'data', 'evaluations')
os.makedirs(eval_dir, exist_ok=True)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            # Descriptive evaluation file name
eval_filename = f"eval_{self.data_version}_{self.model_config_name}_{hl}_{i}d_{model_type}_{timestamp}.json"
eval_path = os.path.join(eval_dir, eval_filename)
            # Persist the full evaluation results
with open(eval_path, 'w', encoding='utf-8') as f:
json.dump(evaluation_results, f, indent=2, ensure_ascii=False)
self.logger.info(f"评估结果已保存: {eval_path}")
return evaluation_results
except Exception as e:
self.logger.error(f"评估和报告生成失败: {e}")
import traceback
traceback.print_exc()
return None
def log_training_results(self, history, test_results, model_type, hl, i,
num_samples, input_shape, model=None,
test_dataset=None, test_steps=None, test_indices=None):
"""记录训练结果并生成评估报告"""
from config.training_config import TrainingConfig
separator = "=" * 60
if model_type == 'classification':
test_loss, test_acc = test_results
            # Resolve the accuracy key name across Keras versions ('accuracy' vs 'acc')
acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
val_acc_key = f'val_{acc_key}'
train_acc = history.history[acc_key][-1] if acc_key in history.history else 0.0
val_acc = history.history[val_acc_key][-1] if val_acc_key in history.history else 0.0
best_val_acc = max(history.history[val_acc_key]) if val_acc_key in history.history else 0.0
self.logger.info(f'''
{separator}
数据{self.data_version}-【总量】: {num_samples}; 【形状】: {input_shape}
对{hl}未来{i}日【分类】模型训练完成
测试集 - 损失: {test_loss:.4f}, 准确率: {test_acc:.4f}
{separator}
训练历史:
- 最终训练准确率: {train_acc:.4f}
- 最终验证准确率: {val_acc:.4f}
- 最佳验证准确率: {best_val_acc:.4f}
{separator}
''')
else:
test_loss, test_mae, test_mse = test_results
self.logger.info(f'''
{separator}
数据{self.data_version}-【总量】: {num_samples}; 【形状】: {input_shape}
对{hl}未来{i}日【回归】模型训练完成
测试集 - 损失: {test_loss:.4f}, MAE: {test_mae:.4f}, MSE: {test_mse:.4f}
{separator}
训练历史:
- 最终训练损失: {history.history['loss'][-1]:.4f}
- 最终验证损失: {history.history['val_loss'][-1]:.4f}
- 最佳验证损失: {min(history.history['val_loss']):.4f}
{separator}
''')
        # With the model and test data provided, run the detailed evaluation and reporting
if model is not None and test_dataset is not None and test_steps is not None:
evaluation_results = self.evaluate_and_report(
model, test_dataset, test_steps, model_type, hl, i,
history, input_shape, num_samples, test_indices
)
if evaluation_results:
self.logger.info(f"模型评估完成,详细结果已保存")
def start_training_session(self):
"""开始训练会话"""
self.start_time = time.time()
self.logger.info(f"开始训练会话 - 版本: {self.data_version}, 时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
def end_training_session(self):
"""结束训练会话"""
if self.start_time:
total_time = time.time() - self.start_time
self.logger.info(f"训练会话结束 - 总用时: {total_time/3600:.2f} 小时")
def cleanup(self):
"""清理资源"""
tf.keras.backend.clear_session()
gc.collect()
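# --- Evaluation and analysis helpers ---
# The classes below (ModelEvaluator, LayerInfoExtractor, ModelAnalyzer) are
# referenced by TrainingManager above and rely on the scikit-learn metrics
# imported here.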
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
confusion_matrix, classification_report,
mean_absolute_error, mean_squared_error, r2_score,
mean_absolute_percentage_error
)
class ModelEvaluator:
"""模型评估器"""
def __init__(self, logger):
self.logger = logger
self.results = {}
def evaluate_classification(self, y_true, y_pred, model_type, model_name):
"""评估分类模型"""
# 计算基础指标
accuracy = accuracy_score(y_true, y_pred)
        precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
        # Confusion matrix
        cm = confusion_matrix(y_true, y_pred)
        # Per-class classification report
        report = classification_report(y_true, y_pred, output_dict=True, zero_division=0)
results = {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1_score': float(f1),
'confusion_matrix': cm.tolist(),
'classification_report': report,
'num_classes': len(np.unique(y_true)),
'num_samples': len(y_true)
}
self.results[model_name] = results
return results
def evaluate_regression(self, y_true, y_pred, model_type, model_name):
"""评估回归模型"""
# 计算基础指标
mae = mean_absolute_error(y_true, y_pred)
mse = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)
mape = mean_absolute_percentage_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)
        # Residual (prediction error) statistics
residuals = y_pred - y_true
residual_stats = {
'mean_residual': float(np.mean(residuals)),
'std_residual': float(np.std(residuals)),
'max_residual': float(np.max(np.abs(residuals))),
'median_residual': float(np.median(residuals))
}
results = {
            'mae': float(mae),
            'mse': float(mse),
            'rmse': float(rmse),
            'mape': float(mape),
            'r2_score': float(r2),
'residual_stats': residual_stats,
'num_samples': len(y_true)
}
self.results[model_name] = results
return results
def save_results(self, filepath):
"""保存评估结果到JSON文件"""
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(self.results, f, indent=2, ensure_ascii=False)
self.logger.info(f"评估结果已保存: {filepath}")
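# Usage sketch (illustrative only): ModelEvaluator can also be used on its own.
#   evaluator = ModelEvaluator(logger)
#   cls_metrics = evaluator.evaluate_classification(y_true, y_pred, 'classification', 'demo_model')
#   reg_metrics = evaluator.evaluate_regression(y_true, y_pred, 'regression', 'demo_model')
#   evaluator.save_results('data/evaluations/demo.json')  # example path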
class LayerInfoExtractor:
"""层信息提取器 - 安全地提取各种层的属性"""
@staticmethod
def get_layer_info(layer, index):
"""获取单层信息"""
layer_info = {
'index': index,
'name': layer.name,
'type': layer.__class__.__name__,
'trainable': layer.trainable
}
        # Output shape
        layer_info['output_shape'] = LayerInfoExtractor._get_layer_shape(layer)
        # Parameter count
        layer_info['params'] = LayerInfoExtractor._get_layer_params(layer)
        # Layer-type-specific attributes
        specific_attrs = LayerInfoExtractor._get_layer_specific_attrs(layer)
        layer_info.update(specific_attrs)
return layer_info
@staticmethod
def _get_layer_shape(layer):
"""安全地获取层形状"""
# 对于InputLayer,使用input_shape
if layer.__class__.__name__ == 'InputLayer':
shape_attrs = ['input_shape', 'batch_input_shape', '_batch_input_shape']
for attr in shape_attrs:
if hasattr(layer, attr):
shape = getattr(layer, attr)
if shape is not None:
return str(shape)
return "InputLayer (形状未知)"
        # For other layers, try several strategies
        try:
            # Preferred: the output_shape attribute
if hasattr(layer, 'output_shape'):
shape = layer.output_shape
if shape is not None:
return str(shape)
except Exception as e:
pass
try:
            # Alternative: the shape of the output tensor
if hasattr(layer, 'output') and hasattr(layer.output, 'shape'):
shape = layer.output.shape
if shape is not None:
return str(shape)
except Exception as e:
pass
try:
            # Last resort: the get_output_shape_at method
if hasattr(layer, 'get_output_shape_at'):
shape = layer.get_output_shape_at(0)
if shape is not None:
return str(shape)
except Exception as e:
pass
return "无法获取形状"
@staticmethod
def _get_layer_params(layer):
"""安全地获取层参数数量"""
try:
return layer.count_params()
except Exception:
return 0
@staticmethod
def _get_layer_specific_attrs(layer):
"""获取层特定属性"""
attrs = {}
# 常用属性列表
common_attrs = [
'units', 'filters', 'kernel_size', 'strides', 'padding',
'activation', 'dropout', 'rate', 'pool_size', 'dilation_rate',
'recurrent_activation', 'return_sequences', 'num_heads', 'key_dim'
]
for attr in common_attrs:
try:
if hasattr(layer, attr):
value = getattr(layer, attr)
                    # Skip unset values
if value is None:
continue
                    # Activation functions need special formatting
if attr == 'activation':
attrs[attr] = LayerInfoExtractor._format_activation(value)
else:
attrs[attr] = value
except Exception:
continue
        # Extra handling for specific layer types
layer_type = layer.__class__.__name__
if layer_type == 'BatchNormalization':
try:
attrs['momentum'] = layer.momentum
attrs['epsilon'] = layer.epsilon
except Exception:
pass
elif layer_type == 'Dropout':
try:
attrs['rate'] = layer.rate
except Exception:
pass
elif layer_type in ['LSTM', 'GRU', 'SimpleRNN']:
try:
attrs['dropout'] = layer.dropout
attrs['recurrent_dropout'] = layer.recurrent_dropout
except Exception:
pass
elif layer_type == 'MultiHeadAttention':
try:
attrs['num_heads'] = layer.num_heads
attrs['key_dim'] = layer.key_dim
except Exception:
pass
return attrs
@staticmethod
def _format_activation(activation):
"""格式化激活函数信息"""
if activation is None:
return 'linear'
if hasattr(activation, '__name__'):
return activation.__name__
elif hasattr(activation, '__class__'):
return activation.__class__.__name__
elif isinstance(activation, str):
return activation
else:
return str(activation)
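# Illustrative example of LayerInfoExtractor.get_layer_info output for a Dense
# layer (values are made up):
# {'index': 2, 'name': 'dense_1', 'type': 'Dense', 'trainable': True,
#  'output_shape': '(None, 64)', 'params': 4160, 'units': 64, 'activation': 'relu'}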
class ModelAnalyzer:
"""模型分析器 - 提供更详细的模型分析功能"""
@staticmethod
def analyze_model_complexity(model):
"""分析模型复杂度"""
total_params = model.count_params()
# 按层类型统计参数
param_by_type = {}
layer_types = []
trainable_params = 0
non_trainable_params = 0
for layer in model.layers:
layer_type = layer.__class__.__name__
layer_types.append(layer_type)
try:
params = layer.count_params()
if layer_type not in param_by_type:
param_by_type[layer_type] = 0
param_by_type[layer_type] += params
if layer.trainable:
trainable_params += params
else:
non_trainable_params += params
except Exception:
pass
        # Compute model depth (excluding input/utility layers)
depth = len([layer for layer in model.layers
if layer.__class__.__name__ not in ['InputLayer', 'Dropout', 'BatchNormalization']])
        # Top 3 layer types by parameter count
top_layer_types = sorted(param_by_type.items(), key=lambda x: x[1], reverse=True)[:3]
return {
'total_params': total_params,
'trainable_params': trainable_params,
'non_trainable_params': non_trainable_params,
'param_by_type': param_by_type,
'top_layer_types': dict(top_layer_types),
'layer_types': list(set(layer_types)),
'num_layers': len(model.layers),
'model_depth': depth,
'trainable_ratio': trainable_params / total_params if total_params > 0 else 0,
'param_distribution': {
'param_by_type': param_by_type,
'top_3_layers': dict(top_layer_types)
}
}
@staticmethod
def analyze_memory_usage(model, batch_size=32):
"""估算模型内存使用"""
# 计算模型参数占用的内存
total_params = model.count_params()
param_memory = total_params * 4 # 假设使用float32
# 计算激活内存(估算)
activation_memory = 0
activation_by_layer = {}
for layer in model.layers:
try:
if hasattr(layer, 'output_shape'):
shape = layer.output_shape
if shape:
                        # Number of elements in this layer's output
num_elements = 1
                        for dim in shape[1:]:  # skip the batch dimension
if dim is not None:
num_elements *= dim
layer_activation_memory = num_elements * 4 * batch_size # float32
activation_memory += layer_activation_memory
activation_by_layer[layer.name] = {
'shape': shape,
'memory_mb': layer_activation_memory / (1024 * 1024)
}
            except Exception:
pass
        # Gradient memory (usually about the same as parameter memory)
gradient_memory = param_memory
total_memory = param_memory + activation_memory + gradient_memory
return {
'parameter_memory_mb': param_memory / (1024 * 1024),
'activation_memory_mb': activation_memory / (1024 * 1024),
'gradient_memory_mb': gradient_memory / (1024 * 1024),
'total_memory_mb': total_memory / (1024 * 1024),
'batch_size': batch_size,
'activation_by_layer': activation_by_layer
}
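    # Rule of thumb implemented above: total memory is roughly parameters +
    # activations + gradients. For example, a 1M-parameter float32 model needs
    # about 4 MB for weights and another ~4 MB for gradients, plus whatever the
    # batch activations require.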
@staticmethod
def analyze_computational_complexity(model):
"""分析计算复杂度"""
# 估算FLOPs(浮点运算次数)
total_flops = 0
flops_by_layer = {}
for layer in model.layers:
layer_type = layer.__class__.__name__
flops = 0
try:
if layer_type == 'Dense':
                    # Dense: input_dim * output_dim * 2 (one multiply and one add)
if hasattr(layer, 'input_shape') and hasattr(layer, 'output_shape'):
input_dim = layer.input_shape[-1] if layer.input_shape else 0
output_dim = layer.output_shape[-1] if layer.output_shape else 0
flops = input_dim * output_dim * 2
elif layer_type == 'Conv1D':
                    # Conv1D: input length * kernel size * input channels * output channels * 2
if hasattr(layer, 'input_shape') and hasattr(layer, 'output_shape'):
input_shape = layer.input_shape
output_shape = layer.output_shape
if len(input_shape) >= 3 and len(output_shape) >= 2:
input_length = input_shape[1]
input_channels = input_shape[2]
output_channels = output_shape[-1]
if hasattr(layer, 'kernel_size'):
kernel_size = layer.kernel_size[0] if isinstance(layer.kernel_size, (list, tuple)) else layer.kernel_size
flops = input_length * kernel_size * input_channels * output_channels * 2
elif layer_type in ['LSTM', 'GRU']:
                    # RNN layers: computation is more involved; use a simplified estimate
if hasattr(layer, 'units'):
units = layer.units
if hasattr(layer, 'input_shape'):
input_dim = layer.input_shape[-1] if layer.input_shape else 0
                            # Simplified LSTM estimate: 4 * units * (input_dim + units)
flops = 4 * units * (input_dim + units)
if flops > 0:
total_flops += flops
flops_by_layer[layer.name] = {
'type': layer_type,
'flops': flops,
                        'flops_m': flops / 1e6  # millions of FLOPs
}
except Exception:
continue
return {
'total_flops': total_flops,
            'total_flops_g': total_flops / 1e9,  # billions of FLOPs (GFLOPs)
'flops_by_layer': flops_by_layer,
            'inference_time_estimate_ms': total_flops / 1e9 * 100  # very rough time estimate
}
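# --- Usage sketch (illustrative only) ---
# A minimal, hedged example of how TrainingManager might be wired up. The
# paths, logger setup and argument values below are assumptions made for
# illustration; they are not defined by this module.
if __name__ == "__main__":
    import logging
    logging.basicConfig(level=logging.INFO)
    _logger = logging.getLogger("training_manager_demo")
    _manager = TrainingManager(
        hdf5_path="data/train_data.h5",   # assumed HDF5 file location
        url=".",                          # assumed project root
        data_version="v1",                # assumed data version tag
        model_config_name="baseline",     # assumed model configuration name
        logger=_logger,
    )
    _manager.start_training_session()
    try:
        n_samples, input_shape = _manager.validate_data_integrity()
        train_idx, test_idx = _manager.get_data_splits(n_samples)
        _logger.info("samples=%d, input_shape=%s, train=%d, test=%d",
                     n_samples, input_shape, len(train_idx), len(test_idx))
    finally:
        _manager.end_training_session()
        _manager.cleanup()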