InvalidArgumentError: seq_lens() > input.dims()[[Node: hidden/bidirectional_rnn/bw/ReverseSequence

博客讲述了TensorFlow使用中遇到的报错问题,问题出在tf.reverse_sequence函数,原因是对seq反转时未满足条件。检查发现限制文本最长长度为400后,保留每句话长度的数组未做相应限制。解决办法是将数组中最大值限制到400。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

报错信息:

Google 了一下没有找到对应的解决方案,于是阅读了 bidirectional_rnn 的源码,发现问题出在其中调用的 tf.reverse_sequence 函数。

https://www.tensorflow.org/api_docs/python/tf/reverse_sequence

错误的主要原因是:对序列(seq)进行反转时,没有满足以下条件:

The elements of seq_lengths must obey seq_lengths[i] <= input.dims[seq_dim], and seq_lengths must be a vector of length input.dims[batch_dim]

在 tf.reverse_sequence() 中,seq_lengths 表示每条序列需要反转的长度。报错信息说明:传入该函数的 seq_lengths(即要反转的长度)超过了 input 在序列维度上的实际长度。

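我检查了代码,发现在把文本最大长度限制为 400 之后,输入矩阵的维度为 (?, 400, hidden_size),但是保存每句话长度的数组(即传给 seq_lengths 的参数)并没有做相应的限制,其中仍有大于 400 的值。解决办法是把该数组中的值也截断到不超过 400,例如 np.minimum(seq_lengths, 400)。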

# Import required libraries
import copy
import json
import os
import shutil
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import (classification_report, confusion_matrix,
                             f1_score, recall_score)
from torch.cuda import amp
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import (DataLoader, Subset, WeightedRandomSampler,
                              random_split)
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Silence noisy library warnings.
# NOTE(review): this also hides torchvision/torch deprecation notices
# (e.g. for ``pretrained=`` and ``torch.cuda.amp``).
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")
warnings.filterwarnings("ignore", category=FutureWarning)


class ResidualBlock(nn.Module):
    """Two Linear-ReLU-BatchNorm1d stages wrapped in an identity skip connection."""

    def __init__(self, in_features):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.BatchNorm1d(in_features),
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.BatchNorm1d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)


class AlzheimerCNN(nn.Module):
    """EfficientNet-B0 backbone + channel attention + residual MLP classifier."""

    def __init__(self, num_classes=4, pretrained=True):
        super().__init__()
        # EfficientNet-B0 feature extractor (ImageNet weights when pretrained).
        self.backbone = models.efficientnet_b0(pretrained=pretrained)
        # Channel width of the backbone's last feature map (1280 for B0).
        in_features = self.backbone.classifier[1].in_features

        # Squeeze-and-excitation style channel attention.
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(in_features, in_features // 16),
            nn.ReLU(),
            nn.Linear(in_features // 16, in_features),
            nn.Sigmoid(),
        )

        # Freeze the whole backbone, then unfreeze its last 8 feature stages.
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.backbone.features[-8:].parameters():
            param.requires_grad = True

        # Classification head with a residual block.
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            ResidualBlock(512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        features = self.backbone.features(x)
        # Per-channel weights in (0, 1), broadcast over the spatial dimensions.
        attn_weights = self.attention(features).view(-1, features.size(1), 1, 1)
        attended_features = features * attn_weights
        pooled = nn.functional.adaptive_avg_pool2d(attended_features, (1, 1))
        flattened = torch.flatten(pooled, 1)
        return self.classifier(flattened)


class AlzheimerLSTM(nn.Module):
    """Small CNN whose spatial positions are fed as a sequence to an LSTM."""

    def __init__(self, num_classes=4, hidden_size=128, num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Convolutional feature extractor (three conv/pool stages).
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Channel count of the last conv layer = per-step LSTM input size.
        self.conv_output_size = 128

        # Dropout between stacked LSTM layers to reduce overfitting.
        self.lstm = nn.LSTM(
            input_size=self.conv_output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
            dropout=0.2,
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            ResidualBlock(256),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        conv_out = self.conv_layers(x)
        # (B, C, H, W) -> (B, H*W, C): every spatial position is one time step.
        conv_out = conv_out.flatten(2).permute(0, 2, 1)
        lstm_out, _ = self.lstm(conv_out)
        # Classify from the output at the last time step.
        return self.fc(lstm_out[:, -1, :])


class AlzheimerRNN(nn.Module):
    """Small CNN whose spatial positions are fed as a sequence to a plain RNN."""

    def __init__(self, num_classes=4, hidden_size=128, num_layers=1):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Channel count of the last conv layer = per-step RNN input size.
        self.conv_output_size = 64

        self.rnn = nn.RNN(
            input_size=self.conv_output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )

        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        conv_out = self.conv_layers(x)
        # (B, C, H, W) -> (B, H*W, C): every spatial position is one time step.
        conv_out = conv_out.flatten(2).permute(0, 2, 1)
        rnn_out, _ = self.rnn(conv_out)
        return self.fc(rnn_out[:, -1, :])


class FocalLoss(nn.Module):
    """Focal loss: cross-entropy scaled by (1 - p_t) ** gamma, optional per-class alpha."""

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # optional tensor of per-class weights, indexed by target
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            focal_loss = self.alpha[targets] * focal_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


def get_data_paths():
    """Return (train_path, test_path); raise FileNotFoundError if either is missing."""
    # Adjust to your actual dataset location.
    train_path = 'D:/Alzheimer_s Dataset/train'
    test_path = 'D:/Alzheimer_s Dataset/test'
    if not os.path.exists(train_path):
        raise FileNotFoundError(f"训练集路径不存在: {train_path}")
    if not os.path.exists(test_path):
        raise FileNotFoundError(f"测试集路径不存在: {test_path}")
    return train_path, test_path


def get_transforms():
    """Return (train_transform, test_transform); only the training one augments."""
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),  # larger than the crop so RandomResizedCrop has room
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.RandomAffine(degrees=0, translate=(0.15, 0.15)),
        transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return train_transform, test_transform


def stratified_split_indices(dataset, num_classes, train_fraction=0.8):
    """Split dataset indices per class so every class keeps the same train/val ratio.

    Returns:
        (train_indices, val_indices) as lists of ints.
    """
    per_class = [[] for _ in range(num_classes)]
    for idx, (_, label) in enumerate(dataset.samples):
        per_class[label].append(idx)
    train_indices, val_indices = [], []
    for indices in per_class:
        shuffled = np.random.permutation(indices).tolist()
        cut = int(train_fraction * len(shuffled))
        train_indices.extend(shuffled[:cut])
        val_indices.extend(shuffled[cut:])
    return train_indices, val_indices


def apply_hybrid_sampling(dataset, class_counts, indices=None):
    """Resample every class to the size of the largest class.

    Smaller classes are oversampled (whole copies plus a random remainder
    drawn without replacement). The largest class keeps its size (shuffled).
    Classes with no samples are skipped instead of dividing by zero.

    Args:
        dataset: ImageFolder-like dataset exposing ``samples`` as (path, label).
        class_counts: per-class counts; its length defines the number of classes.
        indices: optional subset of dataset indices to balance (e.g. the
            training split only). Defaults to the whole dataset.

    Returns:
        list of dataset indices (with repeats).
    """
    if indices is None:
        indices = range(len(dataset.samples))
    # Group the candidate indices by label in a single pass.
    by_class = [[] for _ in class_counts]
    for idx in indices:
        by_class[dataset.samples[idx][1]].append(idx)
    # Balance towards the largest class among the indices actually sampled.
    max_count = max((len(c) for c in by_class), default=0)

    resampled_indices = []
    for class_indices in by_class:
        count = len(class_indices)
        if count == 0:
            continue
        if count < max_count:
            repeat_times, remaining = divmod(max_count, count)
            resampled_indices.extend(class_indices * repeat_times)
            if remaining > 0:
                resampled_indices.extend(
                    np.random.choice(class_indices, remaining, replace=False).tolist())
        else:
            resampled_indices.extend(
                np.random.choice(class_indices, max_count, replace=False).tolist())
    return resampled_indices


def train_epoch(model, loader, optimizer, scaler, device, criterion, epoch):
    """Run one training epoch with AMP and gradient clipping; return (loss, acc)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(loader, desc=f'Epoch {epoch+1} Training', leave=False)
    for inputs, labels in progress_bar:
        inputs = inputs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Mixed precision only on CUDA.
        with amp.autocast(enabled=(device.type == 'cuda')):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient norm.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'Loss': running_loss / total,
            'Acc': correct / total,
        })

    return running_loss / total, correct / total


def validate_epoch(model, loader, device, criterion):
    """Evaluate without gradients; return (loss, acc, weighted F1, labels, preds)."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        progress_bar = tqdm(loader, desc='Validating', leave=False)
        for inputs, labels in progress_bar:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix({
                'Loss': running_loss / total,
                'Acc': correct / total,
            })

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    f1 = f1_score(all_labels, all_preds, average='weighted')
    return epoch_loss, epoch_acc, f1, all_labels, all_preds


def train_model(model, model_name, train_loader, val_loader, optimizer,
                scheduler, scaler, device, criterion, epochs=25):
    """Train with best-F1 checkpointing and early stopping.

    The best weights (by weighted validation F1) are saved to
    ``best_{model_name}_model.pth`` in the working directory and loaded
    back into ``model`` before returning.

    Returns:
        (model, history): history maps metric names to per-epoch lists.
    """
    best_f1 = 0.0
    best_model = None
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [],
               'val_acc': [], 'val_f1': []}
    patience_counter = 0
    patience = 5  # epochs without F1 improvement before stopping

    for epoch in range(epochs):
        start_time = time.time()

        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, scaler, device, criterion, epoch)
        val_loss, val_acc, val_f1, _, _ = validate_epoch(
            model, val_loader, device, criterion)

        # Per-epoch LR schedule. ReduceLROnPlateau is fed val_f1 (higher is
        # better), so it must be constructed with mode='max'.
        if isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(val_f1)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)

        if val_f1 > best_f1:
            best_f1 = val_f1
            best_model = copy.deepcopy(model.state_dict())
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': val_f1,
            }, f'best_{model_name}_model.pth')
            print(f"Saved best {model_name} model with F1: {val_f1:.4f}")
        else:
            patience_counter += 1

        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} - {epoch_time:.0f}s')
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} F1: {val_f1:.4f}')
        print(f'Current LR: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 50)

        # Checked after printing so the stopping epoch's metrics are reported too.
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

    if best_model is not None:
        model.load_state_dict(best_model)
    return model, history


def plot_history(history, model_name):
    """Save loss / accuracy / F1 curves to ../results/{model_name}_training_history.png."""
    os.makedirs('../results', exist_ok=True)
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title(f'{model_name} Loss Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(2, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title(f'{model_name} Accuracy Curve')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(2, 2, 3)
    plt.plot(history['val_f1'], label='Validation F1 Score', color='green')
    plt.title(f'{model_name} F1 Score Curve')
    plt.xlabel('Epochs')
    plt.ylabel('F1 Score')
    plt.legend()

    plt.tight_layout()
    plt.savefig(f'../results/{model_name}_training_history.png')
    plt.close()


def convert_numpy(obj):
    """Recursively convert NumPy arrays/scalars to native Python types for JSON."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, np.generic):
        return obj.item()
    elif isinstance(obj, dict):
        return {k: convert_numpy(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_numpy(list(obj)))
    else:
        return obj


def evaluate_model(model, loader, device, class_names, model_name):
    """Evaluate on the test set, plot the confusion matrix and save JSON results.

    Every class in ``class_names`` is always reported, even if it is absent
    from the labels and predictions.
    """
    os.makedirs('../results', exist_ok=True)
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f'Testing {model_name}'):
            inputs = inputs.to(device, non_blocking=True)
            outputs = model(inputs)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.numpy())
            all_probs.extend(probs.cpu().numpy())

    label_ids = list(range(len(class_names)))
    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=class_names, output_dict=True,
                                   zero_division=0)
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Per-class sensitivity (recall) = diagonal / row sum.
    class_recall = {}
    for i, cls_name in enumerate(class_names):
        support = cm[i].sum()
        class_recall[cls_name] = cm[i, i] / support if support > 0 else 0.0

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'{model_name} Confusion Matrix')
    # Per-class recall, written at y=-0.5 (just above the first heatmap row).
    for i, cls_name in enumerate(class_names):
        plt.text(i + 0.5, -0.5, f"Recall={class_recall[cls_name]:.2f}",
                 ha='center', va='top', fontsize=10, color='red')
    plt.tight_layout()
    plt.savefig(f'../results/{model_name}_confusion_matrix.png')
    plt.close()

    results = {
        'report': report,
        'confusion_matrix': cm,
        'class_recall': class_recall,
        'preds': all_preds,
        'labels': all_labels,
        'probs': all_probs,
    }
    serializable_results = convert_numpy(results)
    with open(f'../results/{model_name}_test_results.json', 'w') as f:
        json.dump(serializable_results, f, indent=2)
    return results


def save_test_results(results, class_names, model_name):
    """Write the classification report, confusion matrix and recalls as CSV files."""
    os.makedirs('../results', exist_ok=True)
    report_df = pd.DataFrame(results['report']).transpose()
    report_df.to_csv(f'../results/{model_name}_classification_report.csv')

    cm_df = pd.DataFrame(results['confusion_matrix'], index=class_names, columns=class_names)
    cm_df.to_csv(f'../results/{model_name}_confusion_matrix.csv')

    recall_df = pd.DataFrame.from_dict(results['class_recall'], orient='index', columns=['Recall'])
    recall_df.to_csv(f'../results/{model_name}_recall_scores.csv')


def copy_sample_images():
    """Copy one training image per class to ../static/images/{class}_sample.jpg."""
    os.makedirs('../static/images', exist_ok=True)
    train_path, _ = get_data_paths()
    for class_name in os.listdir(train_path):
        class_path = os.path.join(train_path, class_name)
        if not os.path.isdir(class_path):
            continue
        # Regular files only, sorted so the chosen sample is deterministic.
        files = sorted(name for name in os.listdir(class_path)
                       if os.path.isfile(os.path.join(class_path, name)))
        if files:
            src = os.path.join(class_path, files[0])
            dest = os.path.join('../static/images', f'{class_name}_sample.jpg')
            shutil.copy(src, dest)


def plot_hyperparam_effects():
    """Draw hyper-parameter effect charts.

    NOTE: every value below is hard-coded, not measured by this script, so
    the charts are titled "(illustrative)".
    Replace them with real experiment results before reporting.
    """
    os.makedirs('../results', exist_ok=True)

    # Training epochs (synthetic saturation curve).
    epochs = np.arange(1, 31)
    train_acc = np.clip(0.7 + (1 - np.exp(-epochs / 6)) * 0.3, 0, 0.96)
    val_acc = np.clip(0.65 + (1 - np.exp(-epochs / 6)) * 0.3, 0, 0.93)
    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_acc, 'o-', label='Train Accuracy')
    plt.plot(epochs, val_acc, 's-', label='Validation Accuracy')
    plt.title('Effect of Training Epochs on Model Accuracy (illustrative)')
    plt.xlabel('Number of Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(0, 31, 3))
    plt.legend()
    plt.grid(True)
    plt.savefig('../results/epoch_effect.png')
    plt.close()

    # Number of hidden layers (hard-coded values).
    hidden_layers = [1, 2, 3, 4, 5]
    accuracies = [0.82, 0.88, 0.91, 0.90, 0.87]
    plt.figure(figsize=(10, 5))
    plt.plot(hidden_layers, accuracies, 'o-')
    plt.title('Effect of Hidden Layers on CNN Accuracy (illustrative)')
    plt.xlabel('Number of Hidden Layers')
    plt.ylabel('Validation Accuracy')
    plt.xticks(hidden_layers)
    plt.grid(True)
    plt.savefig('../results/hidden_layers_effect.png')
    plt.close()

    # Learning rate (hard-coded values, not actual training data).
    lrs = [0.0001, 0.0005, 0.001, 0.005, 0.01]
    accuracies = [0.85, 0.88, 0.92, 0.90, 0.87]
    losses = [0.8, 0.7, 0.5, 0.65, 0.75]
    plt.figure(figsize=(10, 5))
    plt.plot(lrs, accuracies, 'o-', label='Validation Accuracy')
    plt.plot(lrs, losses, 's-', label='Training Loss')
    plt.title('Effect of Learning Rate on Model Performance (illustrative)')
    plt.xlabel('Learning Rate')
    plt.ylabel('Metric')
    plt.xscale('log')
    plt.legend()
    plt.grid(True)
    plt.savefig('../results/lr_effect.png')
    plt.close()


def main():
    """Train, evaluate and compare the CNN, LSTM and RNN models."""
    # cuDNN autotuning; AMP is enabled inside train_epoch when running on CUDA.
    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    os.makedirs('../results', exist_ok=True)
    os.makedirs('../models', exist_ok=True)
    os.makedirs('../static/images', exist_ok=True)

    train_path, test_path = get_data_paths()
    train_transform, test_transform = get_transforms()

    # The same training folder is loaded twice: once with augmentation (for
    # training) and once without (for validation).
    full_dataset = datasets.ImageFolder(train_path, transform=train_transform)
    val_source = datasets.ImageFolder(train_path, transform=test_transform)
    test_dataset = datasets.ImageFolder(test_path, transform=test_transform)

    class_counts = [0] * len(full_dataset.classes)
    for _, label in full_dataset.samples:
        class_counts[label] += 1

    print("原始类别分布:")
    for i, cls_name in enumerate(full_dataset.classes):
        print(f"{cls_name}: {class_counts[i]} 样本")

    # Split BEFORE resampling. Oversampling first would place copies of the
    # same image in both train and val, inflating validation scores.
    train_indices, val_indices = stratified_split_indices(
        full_dataset, len(class_counts), train_fraction=0.8)
    train_class_counts = [0] * len(class_counts)
    for idx in train_indices:
        train_class_counts[full_dataset.samples[idx][1]] += 1

    # Balance the training split only.
    resampled_indices = apply_hybrid_sampling(
        full_dataset, train_class_counts, indices=train_indices)
    train_subset = Subset(full_dataset, resampled_indices)
    val_subset = Subset(val_source, val_indices)

    balanced_class_counts = [0] * len(class_counts)
    for idx in resampled_indices:
        balanced_class_counts[full_dataset.samples[idx][1]] += 1

    print("\n平衡后类别分布:")
    for i, cls_name in enumerate(full_dataset.classes):
        print(f"{cls_name}: {balanced_class_counts[i]} 样本")

    # os.cpu_count() may return None.
    num_workers = min(4, os.cpu_count() or 1)
    batch_size = 32
    num_epochs = 30

    # drop_last avoids a trailing batch of size 1, which BatchNorm1d rejects
    # in training mode.
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True,
                              drop_last=len(train_subset) > batch_size)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)

    class_names = full_dataset.classes

    # Inverse-frequency class weights for Focal Loss, normalised to sum to 1.
    # NOTE(review): they come from the ORIGINAL (imbalanced) counts while the
    # training split is already balanced, so minority classes are effectively
    # compensated twice. Keep only if that extra recall bias is intended.
    class_weights = 1.0 / (torch.tensor(class_counts, dtype=torch.float) + 1e-6)
    class_weights = class_weights / class_weights.sum()
    class_weights = class_weights.to(device)

    models_config = [
        {'name': 'CNN', 'class': AlzheimerCNN,
         'params': {'num_classes': len(class_names), 'pretrained': True}},
        {'name': 'LSTM', 'class': AlzheimerLSTM,
         'params': {'num_classes': len(class_names), 'hidden_size': 128, 'num_layers': 2}},
        {'name': 'RNN', 'class': AlzheimerRNN,
         'params': {'num_classes': len(class_names), 'hidden_size': 128, 'num_layers': 1}},
    ]

    all_results = {}
    for config in models_config:
        model_name = config['name']
        print(f"\n{'='*50}")
        print(f"Training {model_name} model")
        print(f"{'='*50}")

        model = config['class'](**config['params']).to(device)
        criterion = FocalLoss(alpha=class_weights, gamma=2.0)
        optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-4)
        # T_max matches the epoch budget so the LR decays monotonically
        # instead of rising again after T_max epochs.
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
        scaler = torch.cuda.amp.GradScaler(enabled=(device.type == 'cuda'))

        print(f"Starting training for {model_name}...")
        start_time = time.time()
        trained_model, history = train_model(
            model, model_name, train_loader, val_loader, optimizer,
            scheduler, scaler, device, criterion, epochs=num_epochs,
        )
        training_time = time.time() - start_time
        print(f"{model_name} training completed in {training_time//60:.0f}m {training_time%60:.0f}s")

        plot_history(history, model_name)

        print(f"Evaluating {model_name} on test set...")
        test_results = evaluate_model(trained_model, test_loader, device, class_names, model_name)
        all_results[model_name] = test_results

        torch.save(trained_model.state_dict(), f'../models/alzheimer_{model_name.lower()}_model.pth')
        print(f"{model_name} model saved to ../models/alzheimer_{model_name.lower()}_model.pth")

        save_test_results(test_results, class_names, model_name)

    copy_sample_images()
    plot_hyperparam_effects()

    comparison = {}
    for model_name, results in all_results.items():
        report = results['report']
        recall = results['class_recall']
        comparison[model_name] = {
            'accuracy': report['accuracy'],
            'f1_score': report['weighted avg']['f1-score'],
            'recall': {cls: recall[cls] for cls in class_names},
        }
        # Recall of the moderate-dementia class, when that class exists.
        moderate_recall = recall.get('ModerateDemented')
        if moderate_recall is not None:
            print(f"{model_name} 中度痴呆召回率: {moderate_recall:.4f}")

    serializable_comparison = convert_numpy(comparison)
    with open('../results/model_comparison.json', 'w') as f:
        json.dump(serializable_comparison, f, indent=2)

    print("\n模型性能比较:")
    for model_name, metrics in comparison.items():
        print(f"\n{model_name}模型:")
        print(f"  准确率: {metrics['accuracy']:.4f}")
        print(f"  F1分数: {metrics['f1_score']:.4f}")
        print("  召回率:")
        for cls, rec in metrics['recall'].items():
            print(f"    {cls}: {rec:.4f}")


if __name__ == '__main__':
    main()

# Request that accompanied this snippet in the original post: "check the code and fix it".
06-14
class KeyWordSpotter(torch.nn.Module):
    """Streaming CTC keyword spotter.

    Overall flow of one ``forward(wave_chunk)`` call:

    1. ``accept_wave`` does the following:
       - decodes the raw little-endian 16-bit PCM bytes;
       - prepends samples left over from the previous chunk;
       - computes fbank features with ``kaldi.fbank``;
       - optionally splices left/right context frames and applies frame
         skipping (``frame_skip`` from the dataset config).
    2. The acoustic model runs on the features. ``self.in_cache`` is passed in
       and replaced on every call, and logits are softmaxed into per-frame
       token probabilities.
    3. For every output frame:
       - ``decode_keywords`` advances the CTC prefix beam search
         (``ctc_prefix_beam_search``, given the keyword token-id set);
       - ``execute_detection`` looks for a keyword's token sequence inside
         the current hypotheses and applies the score / duration / interval
         checks.
    4. After an activation the decoder state is reset. It is also reset when
       the best hypothesis has lasted longer than ``max_frames``.

    ``set_keywords`` must be called before the first ``forward``; it is the
    only place ``keywords_idxset`` / ``keywords_token`` are set.
    """

    def __init__(
        self,
        ckpt_path,
        config_path,
        token_path,
        lexicon_path,
        threshold,
        min_frames=5,
        max_frames=250,
        interval_frames=50,
        score_beam=3,
        path_beam=20,
        gpu=-1,
        is_jit_model=False,
    ):
        super().__init__()
        # Process-wide side effect: restricts which GPUs CUDA can see.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

        with open(config_path, 'r') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
        dataset_conf = configs['dataset_conf']

        # Feature extraction settings.
        self.sample_rate = 16000
        self.wave_remained = np.array([])  # samples not yet consumed by fbank
        self.num_mel_bins = dataset_conf['feature_extraction_conf'][
            'num_mel_bins']
        self.frame_length = dataset_conf['feature_extraction_conf'][
            'frame_length']  # in ms
        self.frame_shift = dataset_conf['feature_extraction_conf'][
            'frame_shift']  # in ms
        self.downsampling = dataset_conf.get('frame_skip', 1)
        self.resolution = self.frame_shift / 1000  # in second

        # FSMN splice (context expansion) settings.
        self.context_expansion = dataset_conf.get('context_expansion', False)
        self.left_context = 0
        self.right_context = 0
        if self.context_expansion:
            self.left_context = dataset_conf['context_expansion_conf']['left']
            self.right_context = dataset_conf['context_expansion_conf'][
                'right']
        self.feature_remained = None  # context frames carried to the next chunk
        self.feats_ctx_offset = 0  # frame-skip phase carried across chunks

        # Model.
        if is_jit_model:
            model = torch.jit.load(ckpt_path)
            # For script model, only cpu is supported.
            device = torch.device('cpu')
        else:
            # Init model from configs
            model = init_model(configs['model'])
            load_checkpoint(model, ckpt_path)
            use_cuda = gpu >= 0 and torch.cuda.is_available()
            device = torch.device('cuda' if use_cuda else 'cpu')
        self.device = device
        self.model = model.to(device)
        self.model.eval()
        logging.info(f'model {ckpt_path} loaded.')
        self.token_table = read_token(token_path)
        logging.info(f'tokens {token_path} with '
                     f'{len(self.token_table)} units loaded.')
        self.lexicon_table = read_lexicon(lexicon_path)
        logging.info(f'lexicons {lexicon_path} with '
                     f'{len(self.lexicon_table)} units loaded.')
        # Model cache, passed to and returned by the model on every call.
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)

        # Decoding and detection settings / state.
        self.score_beam = score_beam
        self.path_beam = path_beam
        self.threshold = threshold
        self.min_frames = min_frames
        self.max_frames = max_frames
        self.interval_frames = interval_frames
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.hit_score = 1.0
        self.hit_keyword = None
        self.activated = False
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}

    def set_keywords(self, keywords):
        """Parse a comma-separated keyword string into token-id sequences.

        Spaces are removed and empty entries (e.g. from a trailing comma) are
        ignored. Sets ``self.keywords_token`` (keyword -> {'token_id',
        'token_str'}) and ``self.keywords_idxset`` (every keyword token id,
        plus 0 for '<blk>').
        """
        assert keywords is not None, \
            'at least one keyword is needed, ' \
            'multiple keywords should be splitted with comma(,)'
        keywords_list = [kw for kw in keywords.strip().replace(' ', '').split(',')
                         if kw]

        keywords_token = {}
        keywords_idxset = {0}
        keywords_tokenmap = {'<blk>': 0}
        for keyword in keywords_list:
            strs, indexes = query_token_set(keyword, self.token_table,
                                            self.lexicon_table)
            keywords_token[keyword] = {
                'token_id': indexes,
                'token_str': ''.join('%s ' % str(i) for i in indexes),
            }
            keywords_idxset.update(indexes)
            for txt, idx in zip(strs, indexes):
                keywords_tokenmap.setdefault(txt, idx)

        token_print = ''.join(f'{txt}({idx}) '
                              for txt, idx in keywords_tokenmap.items())
        logging.info(f'Token set is: {token_print}')
        self.keywords_idxset = keywords_idxset
        self.keywords_token = keywords_token

    @staticmethod
    def _splice_context(feats_pad, left_context, right_context):
        """Stack each frame with its left/right neighbours.

        Args:
            feats_pad: (num_padded_frames, dim) features that already include
                ``left_context`` frames before and ``right_context`` frames
                after the frames to emit.

        Returns:
            float32 tensor of shape
            (num_padded_frames - left - right, dim * (left + right + 1)).
        """
        win = left_context + right_context + 1
        # Number of complete windows: num_padded - win + 1.
        num_frames = feats_pad.shape[0] - (left_context + right_context)
        spliced = torch.zeros(num_frames, feats_pad.shape[1] * win,
                              dtype=torch.float32)
        for i in range(num_frames):
            spliced[i] = feats_pad[i:i + win].reshape(-1)
        return spliced

    @staticmethod
    def _keyword_score(prefix_nodes, offset, length):
        """Return sqrt of the product of the matched tokens' 'prob' values."""
        score = 1.0
        for node in prefix_nodes[offset:offset + length]:
            score *= node['prob']
        return math.sqrt(score)

    def accept_wave(self, wave):
        """Turn a chunk of raw PCM bytes into model-ready features.

        Returns None when too few samples have accumulated; they are buffered
        for the next call.

        Raises:
            ValueError: if ``wave`` has an odd number of bytes.
        """
        assert isinstance(wave, bytes), \
            "please make sure the input format is bytes(raw PCM)"
        if len(wave) % 2:
            raise ValueError('raw PCM input must contain whole 16-bit samples')
        # Little-endian int16 samples, decoded in one call. They are not
        # divided by 32768.0, because kaldi.fbank accepts the original input.
        wave = np.frombuffer(wave, dtype='<i2')
        wave = np.append(self.wave_remained, wave)
        if wave.size < (self.frame_length * self.sample_rate / 1000) \
                * self.right_context:
            self.wave_remained = wave
            return None
        wave_tensor = torch.from_numpy(wave).float().to(self.device)
        wave_tensor = wave_tensor.unsqueeze(0)  # add a channel dimension
        feats = kaldi.fbank(wave_tensor,
                            num_mel_bins=self.num_mel_bins,
                            frame_length=self.frame_length,
                            frame_shift=self.frame_shift,
                            dither=0,
                            energy_floor=0.0,
                            sample_frequency=self.sample_rate)
        # Keep the samples that were not consumed by any full frame shift.
        feat_len = len(feats)
        frame_shift = int(self.frame_shift / 1000 * self.sample_rate)
        self.wave_remained = wave[feat_len * frame_shift:]

        if self.context_expansion:
            assert feat_len > self.right_context, \
                "make sure each chunk feat length is large than right context."
            if self.feature_remained is None:  # first chunk
                # Pad by replicating the first frame. Replicate padding only
                # works on the last dimension, hence the transposes.
                feats_pad = F.pad(feats.T, (self.left_context, 0),
                                  mode='replicate').T
            else:
                feats_pad = torch.cat((self.feature_remained, feats))

            feats_ctx = self._splice_context(feats_pad, self.left_context,
                                             self.right_context)
            # Carry the last left+right padded frames to the next chunk.
            # Explicit start index so that keep == 0 yields an empty tensor
            # (``x[-0:]`` would return everything).
            keep = self.left_context + self.right_context
            self.feature_remained = feats_pad[feats_pad.shape[0] - keep:]
            feats = feats_ctx.to(self.device)

        if self.downsampling > 1:
            last_remainder = 0 if self.feats_ctx_offset == 0 \
                else self.downsampling - self.feats_ctx_offset
            remainder = (feats.size(0) + last_remainder) % self.downsampling
            feats = feats[self.feats_ctx_offset::self.downsampling, :]
            self.feats_ctx_offset = remainder \
                if remainder == 0 else self.downsampling - remainder
        return feats

    def decode_keywords(self, t, probs):
        """Advance the prefix beam search by one frame and keep the top hyps."""
        absolute_time = t + self.total_frames
        next_hyps = ctc_prefix_beam_search(absolute_time, probs, self.cur_hyps,
                                           self.keywords_idxset,
                                           self.score_beam)
        # Hyps are ordered by path score (pnb + pb), not keyword probability.
        self.cur_hyps = next_hyps[:self.path_beam]

    def execute_detection(self, t):
        """Check the current hyps for a keyword and update ``self.result``."""
        absolute_time = t + self.total_frames
        hit_keyword = None
        start = 0
        end = 0

        # Each hyp as (token ids, path score, per-token nodes).
        hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in self.cur_hyps]

        for prefix_ids, _path_score, prefix_nodes in hyps:
            assert len(prefix_ids) == len(prefix_nodes)
            for word, info in self.keywords_token.items():
                lab = info['token_id']
                offset = is_sublist(prefix_ids, lab)
                if offset != -1:
                    hit_keyword = word
                    start = prefix_nodes[offset]['frame']
                    end = prefix_nodes[offset + len(lab) - 1]['frame']
                    # Fresh score per hit; previously it was multiplied onto
                    # the score left over from earlier frames.
                    self.hit_score = self._keyword_score(prefix_nodes, offset,
                                                         len(lab))
                    break
            if hit_keyword is not None:
                break

        duration = end - start
        if hit_keyword is not None:
            if self.hit_score >= self.threshold and \
                    self.min_frames <= duration <= self.max_frames \
                    and (self.last_active_pos == -1 or
                         end - self.last_active_pos >= self.interval_frames):
                self.activated = True
                self.last_active_pos = end
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"duration {duration}, score {self.hit_score}, Activated.")

            elif self.last_active_pos != -1 and \
                    end - self.last_active_pos < self.interval_frames:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but interval {end-self.last_active_pos} "
                    f"is lower than {self.interval_frames}, Deactivated. ")

            elif self.hit_score < self.threshold:
                logging.info(f"Frame {absolute_time} detect {hit_keyword} "
                             f"from {start} to {end} frame. "
                             f"but {self.hit_score} "
                             f"is lower than {self.threshold}, Deactivated. ")

            elif self.min_frames > duration or duration > self.max_frames:
                logging.info(
                    f"Frame {absolute_time} detect {hit_keyword} "
                    f"from {start} to {end} frame. "
                    f"but {duration} beyond range"
                    f"({self.min_frames}~{self.max_frames}), Deactivated. ")

        self.result = {
            "state": 1 if self.activated else 0,
            "keyword": hit_keyword if self.activated else None,
            "start": start * self.resolution if self.activated else None,
            "end": end * self.resolution if self.activated else None,
            "score": self.hit_score if self.activated else None
        }

    def forward(self, wave_chunk):
        """Process one PCM chunk; return {} if no features, else ``self.result``."""
        feature = self.accept_wave(wave_chunk)
        if feature is None or feature.size(0) < 1:
            return {}  # the feature is not enough to get result.
        feature = feature.unsqueeze(0)  # add a batch dimension
        logits, self.in_cache = self.model(feature, self.in_cache)
        probs = logits.softmax(2)  # (batch_size, maxlen, vocab_size)
        probs = probs[0].cpu()  # remove batch dimension
        for (t, prob) in enumerate(probs):
            t *= self.downsampling
            self.decode_keywords(t, prob)
            self.execute_detection(t)

            if self.activated:
                self.reset()
                # A chunk holds roughly 30 frames; once activated, the rest
                # of the chunk is skipped.
                # TODO: update the result some other way so that
                # self.result is not cleared.
                break

        # update frame offset
        self.total_frames += len(probs) * self.downsampling

        # For streaming KWS, reset cur_hyps when a possible keyword has
        # lasted longer than max_frames.
        # see this issue:https://github.com/duj12/kws_demo/issues/2
        if len(self.cur_hyps) > 0 and len(self.cur_hyps[0][0]) > 0:
            keyword_may_start = int(self.cur_hyps[0][1][2][0]['frame'])
            if (self.total_frames - keyword_may_start) > self.max_frames:
                self.reset()

        return self.result

    def reset(self):
        """Reset the decoder hypotheses and the activation state."""
        self.cur_hyps = [(tuple(), (1.0, 0.0, []))]
        self.activated = False
        self.hit_score = 1.0

    def reset_all(self):
        """Reset decoding plus all streaming buffers, caches and frame counters."""
        self.reset()
        self.wave_remained = np.array([])
        self.feature_remained = None
        self.feats_ctx_offset = 0  # after downsample, offset exist.
        self.in_cache = torch.zeros(0, 0, 0, dtype=torch.float)
        self.total_frames = 0  # frame offset, for absolute time
        self.last_active_pos = -1  # the last frame of being activated
        self.result = {}

# Request that accompanied this snippet in the original post: "please help me untangle the overall flow".
最新发布
07-10
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值