**ResNet-SE + MFCC** 训练框架,包括 **数据加载、训练流程**,以及 **混淆矩阵** 可视化示例


1. 依赖库安装

如果你还没安装相关库,请先执行:

pip install torch torchaudio torchvision scikit-learn matplotlib tqdm

2. 数据加载

这里假设你有一个音频分类数据集(每个类别一个子目录,目录名即类别名),其文件结构如下:

dataset/
│── train/
│   ├── class_0/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│   ├── class_1/
│   │   ├── audio_0.wav
│   │   ├── audio_1.wav
│── val/
│   ├── class_0/
│   ├── class_1/

实现数据加载器:

import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Audio dataset class
class AudioDataset(Dataset):
    """Audio classification dataset that yields MFCC features.

    ``root_dir`` must contain one sub-directory per class. The sorted directory
    names become ``self.classes`` and a directory's index in that list is its
    integer label. Every regular, non-hidden file inside a class directory is
    treated as one sample.

    Each item is ``(features, label)``. ``features`` has shape
    ``(3, n_mfcc, frames)``: the MFCC map replicated to 3 channels so an
    image-style ResNet stem can consume it. ``frames`` depends on clip length.
    """

    def __init__(self, root_dir, sample_rate=16000, n_mfcc=40):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.classes, self.file_paths, self.labels = self._scan(root_dir)

        # n_fft=400 / hop_length=160 are 25 ms windows with a 10 ms hop at the
        # default 16 kHz sample rate.
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=self.n_mfcc,
            melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 64},
        )
        # Resample transforms keyed by source sample rate, built lazily and
        # reused instead of being recreated for every item.
        self._resamplers = {}

    @staticmethod
    def _scan(root_dir):
        """Index a class-per-directory tree; return (classes, file_paths, labels)."""
        # Only visible sub-directories are classes. Stray files such as
        # .DS_Store would otherwise become bogus classes and make the
        # os.listdir(class_dir) call below fail.
        classes = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith(".") and os.path.isdir(os.path.join(root_dir, name))
        )
        file_paths, labels = [], []
        for label, class_name in enumerate(classes):
            class_dir = os.path.join(root_dir, class_name)
            # Sorting makes sample order reproducible across file systems.
            for file_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file_name)
                if file_name.startswith(".") or not os.path.isfile(path):
                    continue
                file_paths.append(path)
                labels.append(label)
        return classes, file_paths, labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.file_paths[idx])
        return self._to_features(waveform, sr), self.labels[idx]

    def _to_features(self, waveform, sr):
        """Turn a (channels, samples) waveform into a (3, n_mfcc, frames) tensor."""
        # Downmix multi-channel audio to mono. Otherwise MFCC returns
        # (channels, n_mfcc, frames) and the 3-channel repeat below fails.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
                self._resamplers[sr] = resampler
            waveform = resampler(waveform)
        mfcc = self.mfcc_transform(waveform).squeeze(0)  # (n_mfcc, frames)
        # Replicate to 3 channels to match a ResNet's RGB-style input.
        return mfcc.unsqueeze(0).repeat(3, 1, 1)


# Create the data loaders
def _pad_collate(batch):
    """Collate ``(features, label)`` pairs, zero-padding features along time.

    The MFCC frame count depends on clip duration, and the default collate
    function can only stack tensors of identical shape. Clips of different
    lengths would therefore fail to batch. Features are right-padded with zeros
    on the last (time) axis to the longest item in the batch; when all items
    already have the same length this is a plain stack.
    """
    features, labels = zip(*batch)
    max_frames = max(f.size(-1) for f in features)
    padded = [torch.nn.functional.pad(f, (0, max_frames - f.size(-1))) for f in features]
    return torch.stack(padded), torch.tensor(labels, dtype=torch.long)


def get_dataloaders(train_dir, val_dir, batch_size=32, num_workers=2):
    """Build ``(train_loader, val_loader)`` over class-per-directory audio folders.

    The training loader shuffles; the validation loader keeps file order.
    Both loaders pad variable-length clips within a batch.

    Raises:
        ValueError: if ``train_dir`` and ``val_dir`` do not contain the same
            class folders. Each dataset numbers its classes independently, so
            a mismatch would silently give validation samples wrong labels.
    """
    train_dataset = AudioDataset(train_dir)
    val_dataset = AudioDataset(val_dir)
    if val_dataset.classes != train_dataset.classes:
        raise ValueError(
            f"class folders differ between {train_dir!r} ({train_dataset.classes}) "
            f"and {val_dir!r} ({val_dataset.classes}); label indices would not match"
        )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=_pad_collate,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_pad_collate,
    )

    return train_loader, val_loader

3. 训练和验证

import torch.optim as optim
from tqdm import tqdm

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; update weights only if ``optimizer`` is given.

    Returns:
        ``(mean loss per sample, accuracy)`` for the pass, or ``(nan, nan)``
        if the loader yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()
    total_loss, correct, total = 0.0, 0, 0
    # Autograd is only needed while training; disabling it for validation
    # saves memory and time.
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            n = labels.size(0)
            # Assumes criterion returns the batch mean (CrossEntropyLoss
            # default). Re-weighting by batch size keeps a short final batch
            # from skewing the epoch average.
            total_loss += loss.item() * n
            total += n
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device=None):
    """Train ``model`` with Adam + cross-entropy, validating after each epoch.

    Args:
        model: classifier producing logits of shape (batch, num_classes).
        train_loader / val_loader: loaders yielding ``(inputs, labels)`` batches.
        num_epochs: number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: torch device string; ``None`` picks "cuda" if available,
            otherwise "cpu".

    Returns:
        The trained model (moved to ``device``).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")

        # Training phase
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        # Validation phase
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    return model

4. 混淆矩阵可视化

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def evaluate_model(model, val_loader, device=None):
    """Run ``model`` over ``val_loader`` and collect true and predicted labels.

    Args:
        model: classifier producing logits of shape (batch, num_classes).
            It is moved to ``device`` (``nn.Module.to`` works in place).
        val_loader: loader yielding ``(inputs, labels)`` batches.
        device: torch device string; ``None`` picks "cuda" if available,
            otherwise "cpu".

    Returns:
        ``(y_true, y_pred)``: 1-D numpy arrays aligned per sample, in loader
        order. Both are empty int64 arrays if the loader yields nothing.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Make sure the weights live on the same device as the inputs below.
    model = model.to(device)
    model.eval()
    true_chunks, pred_chunks = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            outputs = model(inputs.to(device))
            pred_chunks.append(outputs.argmax(dim=1).cpu().numpy())
            # Labels are only compared on the host, so never send them to the GPU.
            true_chunks.append(labels.cpu().numpy())

    if not pred_chunks:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(true_chunks), np.concatenate(pred_chunks)

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw and show a confusion matrix with one row/column per class name.

    Args:
        y_true: true integer labels (indices into ``class_names``).
        y_pred: predicted integer labels.
        class_names: display names, where ``class_names[i]`` names label ``i``.

    Returns:
        The ``ConfusionMatrixDisplay`` that was plotted.
    """
    # Pin the label set explicitly. Without it, sklearn only uses labels that
    # occur in y_true or y_pred, so a class missing from both would shrink the
    # matrix and it would no longer match class_names (plotting then fails).
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, values_format="d")
    plt.xticks(rotation=45)
    plt.tight_layout()  # keep rotated tick labels inside the figure
    plt.show()
    return disp

5. 运行完整训练流程

if __name__ == "__main__":
    train_dir = "dataset/train"
    val_dir = "dataset/val"
    batch_size = 32
    num_epochs = 10
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the data
    train_loader, val_loader = get_dataloaders(train_dir, val_dir, batch_size)

    # Take the class list from the dataset itself. The model's output size and
    # the confusion-matrix labels then always match the label indices the
    # dataset actually assigned, rather than a separate directory listing.
    class_names = train_loader.dataset.classes

    # Initialise the model.
    # NOTE(review): ResNetSE is not defined in this file; define or import it
    # before running. Presumably it takes ``num_classes`` and accepts
    # (batch, 3, n_mfcc, frames) inputs -- confirm.
    model = ResNetSE(num_classes=len(class_names))

    # Train the model
    trained_model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, device=device)

    # Compute predictions for the confusion matrix
    y_true, y_pred = evaluate_model(trained_model, val_loader, device=device)

    # Plot the confusion matrix
    plot_confusion_matrix(y_true, y_pred, class_names)

6. 总结

数据加载

  • 通过 torchaudio 提取 MFCC 特征,并适配 ResNet 输入格式。

ResNet-SE 训练

  • 训练过程包含 Adam 优化器 + 交叉熵损失,支持 GPU 训练。

混淆矩阵可视化

  • 通过 sklearn 计算混淆矩阵,并绘制分类效果图。

改进方向

🚀 模型优化

  • 使用 ResNet-34/50 替代 ResNet-18 提升表达能力。
  • 结合 SpecAugment 增强数据,提高鲁棒性。

推理加速

  • 采用 TorchScript / ONNX 进行模型导出,提高部署效率。

💡 数据增强

  • 额外使用时域和频域掩码增强(如 torchaudio.transforms.TimeMasking 和 torchaudio.transforms.FrequencyMasking)。

这样,你就能完整训练 ResNet-SE + MFCC 进行音频分类,并分析模型性能了!💪🚀

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大霸王龙

+V来点难题

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值