1. 依赖库安装
如果你还没安装相关库,请先执行:
pip install torch torchaudio torchvision scikit-learn matplotlib tqdm
2. 数据加载
这里假设你有一个音频分类数据集，每个类别对应一个子目录（目录名即类别名），其文件结构如下：
dataset/
├── train/
│   ├── class_0/
│   │   ├── audio_0.wav
│   │   └── audio_1.wav
│   └── class_1/
│       ├── audio_0.wav
│       └── audio_1.wav
└── val/
    ├── class_0/
    └── class_1/
实现数据加载器:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Audio dataset class
class AudioDataset(Dataset):
    """Audio classification dataset yielding 3-channel MFCC feature maps.

    Expects the layout ``root_dir/<class_name>/<audio_file>``: every
    sub-directory of ``root_dir`` is one class, and its position in the
    sorted class list is the integer label.

    Args:
        root_dir: Directory of one split (e.g. ``"dataset/train"``).
        sample_rate: Target sample rate; files at other rates are resampled.
        n_mfcc: Number of MFCC coefficients per frame.
        classes: Optional explicit list of class names. Pass the training
            split's ``classes`` when building another split so both share
            the same label mapping. Defaults to the sorted sub-directory
            names of ``root_dir``.
    """

    def __init__(self, root_dir, sample_rate=16000, n_mfcc=40, classes=None):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        if classes is None:
            # Only directories are classes. Stray files (e.g. ".DS_Store")
            # would otherwise become bogus classes, and os.listdir() on them
            # below would raise.
            classes = sorted(
                name
                for name in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, name))
            )
        self.classes = list(classes)  # directory names used as class names
        self.file_paths = []
        self.labels = []
        for label, class_name in enumerate(self.classes):
            class_dir = os.path.join(root_dir, class_name)
            if not os.path.isdir(class_dir):
                # An explicitly supplied class may have no folder in this split.
                continue
            # Sorted for a deterministic sample order. Hidden files and
            # sub-directories are skipped because they are not audio samples.
            for file_name in sorted(os.listdir(class_dir)):
                path = os.path.join(class_dir, file_name)
                if file_name.startswith(".") or not os.path.isfile(path):
                    continue
                self.file_paths.append(path)
                self.labels.append(label)
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=self.n_mfcc,
            melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 64},
        )
        # Resample transforms cached by source rate instead of being rebuilt
        # for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(mfcc, label)`` with ``mfcc`` shaped ``(3, n_mfcc, time)``."""
        file_path = self.file_paths[idx]
        label = self.labels[idx]
        waveform, sr = torchaudio.load(file_path)  # (channels, samples)
        # Down-mix to mono. With more than one channel, squeeze(0) below would
        # be a no-op and repeat(3, 1, 1) would fail on the 4-D tensor.
        waveform = waveform.mean(dim=0, keepdim=True)
        if sr != self.sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sr, new_freq=self.sample_rate
                )
                self._resamplers[sr] = resampler
            waveform = resampler(waveform)
        mfcc = self.mfcc_transform(waveform).squeeze(0)  # (n_mfcc, time)
        # Replicate the single MFCC plane 3 times to fit a 3-channel ResNet input.
        mfcc = mfcc.unsqueeze(0).repeat(3, 1, 1)  # (3, n_mfcc, time)
        return mfcc, label
# Create the data loaders
def _pad_collate(batch):
    """Collate ``(mfcc, label)`` pairs whose time axes may differ in length.

    The default collate stacks tensors directly and fails as soon as two
    clips in a batch have different durations. Each feature map is
    zero-padded on the right of its last (time) axis to the longest one in
    the batch before stacking.
    """
    features, labels = zip(*batch)
    max_len = max(feat.shape[-1] for feat in features)
    padded = [
        torch.nn.functional.pad(feat, (0, max_len - feat.shape[-1]))
        for feat in features
    ]
    return torch.stack(padded), torch.tensor(labels, dtype=torch.long)


def get_dataloaders(train_dir, val_dir, batch_size=32, num_workers=2):
    """Build the training and validation ``DataLoader``s.

    Args:
        train_dir: Root directory of the training split.
        val_dir: Root directory of the validation split.
        batch_size: Samples per batch for both loaders.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles; both
        pad variable-length MFCC inputs within a batch.

    Raises:
        ValueError: If the two splits have different class lists. The same
            integer label would then refer to different classes.
    """
    train_dataset = AudioDataset(train_dir)
    val_dataset = AudioDataset(val_dir)
    if list(train_dataset.classes) != list(val_dataset.classes):
        raise ValueError(
            f"train/val class folders differ: "
            f"{train_dataset.classes} vs {val_dataset.classes}"
        )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=_pad_collate,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_pad_collate,
    )
    return train_loader, val_loader
3. 训练和验证
import torch.optim as optim
from tqdm import tqdm
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader``; update weights when ``optimizer`` is given.

    Returns:
        ``(mean_loss, accuracy)`` averaged over samples. Both are NaN when
        ``loader`` yields no samples, instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, correct, total = 0.0, 0, 0
    # Gradients are only tracked during the training pass.
    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            # Assumes criterion returns the batch mean (the CrossEntropyLoss
            # default used in train_model). Weighting by batch size stops a
            # short final batch from skewing the epoch average.
            total_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
    if total == 0:
        return float("nan"), float("nan")
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001, device="cuda"):
    """Train ``model`` with Adam and cross-entropy, validating after each epoch.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        train_loader: Loader of ``(inputs, labels)`` training batches.
        val_loader: Loader of ``(inputs, labels)`` validation batches.
        num_epochs: Number of full passes over ``train_loader``.
        lr: Adam learning rate.
        device: Device the model and batches are moved to.

    Returns:
        The trained model, already moved to ``device``.
    """
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}]")
        # Training phase
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        # Validation phase
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
    return model
4. 混淆矩阵可视化
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
def evaluate_model(model, val_loader, device="cuda"):
    """Collect ground-truth labels and argmax predictions over ``val_loader``.

    Args:
        model: Classifier returning logits of shape ``(batch, num_classes)``.
        val_loader: Loader of ``(inputs, labels)`` batches.
        device: Device used for inference.

    Returns:
        ``(y_true, y_pred)`` as 1-D NumPy arrays in loader order.
    """
    # Put the weights on the same device as the inputs. The caller may pass
    # a model that never went through train_model or sits on another device.
    model = model.to(device)
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader):
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    return np.array(all_labels), np.array(all_preds)
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot and show the confusion matrix for integer class labels.

    Args:
        y_true: Ground-truth labels in ``0 .. len(class_names) - 1``.
        y_pred: Predicted labels in the same range.
        class_names: Display name for each label index.

    Returns:
        The ``ConfusionMatrixDisplay``, so callers can customise or save it.
    """
    # Pin the label set to every class index. Otherwise sklearn uses only
    # the labels present in y_true/y_pred. A class absent from both then
    # shrinks the matrix, and plotting with len(class_names) tick labels fails.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(class_names)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues, values_format="d")
    plt.xticks(rotation=45)
    plt.tight_layout()  # keep the rotated tick labels inside the figure
    plt.show()
    return disp
5. 运行完整训练流程（注意：`ResNetSE` 模型类需事先定义或导入）
if __name__ == "__main__":
    train_dir = "dataset/train"
    val_dir = "dataset/val"
    batch_size = 32
    num_epochs = 10
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Load the data
    train_loader, val_loader = get_dataloaders(train_dir, val_dir, batch_size)
    # Take class names from the dataset so the model width, the integer labels
    # and the plot ticks all agree. Raw os.listdir() would also count stray
    # files such as ".DS_Store".
    class_names = train_loader.dataset.classes
    # Build the model (ResNetSE must be defined or imported elsewhere)
    model = ResNetSE(num_classes=len(class_names))
    # Train the model
    trained_model = train_model(model, train_loader, val_loader, num_epochs=num_epochs, device=device)
    # Compute predictions for the confusion matrix
    y_true, y_pred = evaluate_model(trained_model, val_loader, device=device)
    # Plot the confusion matrix
    plot_confusion_matrix(y_true, y_pred, class_names)
6. 总结
✅ 数据加载
- 通过 `torchaudio` 提取 MFCC 特征，并复制为 3 通道以适配 ResNet 输入格式。
✅ ResNet-SE 训练
- 训练过程包含 Adam 优化器 + 交叉熵损失,支持 GPU 训练。
✅ 混淆矩阵可视化
- 通过 `sklearn` 计算混淆矩阵，并绘制分类效果图。
改进方向
🚀 模型优化
- 使用 ResNet-34/50 替代 ResNet-18 提升表达能力。
- 结合 SpecAugment 增强数据,提高鲁棒性。
⚡ 推理加速
- 采用 TorchScript / ONNX 进行模型导出,提高部署效率。
💡 数据增强
- 额外使用时域和频域增强（如 `torchaudio.transforms.TimeMasking`、`torchaudio.transforms.FrequencyMasking`）。
这样,你就能完整训练 ResNet-SE + MFCC 进行音频分类,并分析模型性能了!💪🚀