1. Matplotlib 简介
1.1 为什么在 AI/ML/DL 中使用 Matplotlib?
Matplotlib 是 Python 中最基础且广泛使用的数据可视化库,在 AI/ML/DL 领域有以下关键用途:
- 模型评估:可视化混淆矩阵、ROC 曲线、损失曲线等,帮助分析模型性能。
- 数据探索:绘制数据分布、特征相关性、降维结果(如 t-SNE),辅助数据预处理。
- 结果展示:生成图像网格、特征重要性图,用于论文、报告或演示。
- 灵活性:支持高度自定义,适合复杂的 ML/DL 实验需求。
- 生态兼容:与 NumPy、Pandas、Scikit-learn、PyTorch、TensorFlow 等无缝集成。
相比其他工具(如 Seaborn 或 Plotly),Matplotlib 提供了底层控制,适合需要精细调整的场景,尤其在学术研究和实验分析中。
1.2 安装与基本设置
安装 Matplotlib:
pip install matplotlib
验证安装:
import matplotlib
print(matplotlib.__version__) # 示例输出:3.9.2
导入核心模块:
import matplotlib.pyplot as plt
import numpy as np
约定:
plt是标准别名,简化代码。
2. Matplotlib 基础
2.1 核心概念
- Figure:整个图形窗口,类似画布。
- Axes:图表中的绘图区域,一个 Figure 可包含多个 Axes(子图)。
- Axis:坐标轴(x 轴或 y 轴),包含刻度和标签。
关系:Figure → 包含多个 Axes → 每个 Axes 包含 Axis。
2.2 绘制简单的折线图(损失曲线)
在 ML/DL 中,损失曲线是监控模型训练过程的标配。以下是一个示例:
# 模拟训练和验证损失
epochs = np.arange(1, 11)
train_loss = np.array([0.9, 0.7, 0.5, 0.4, 0.3, 0.25, 0.2, 0.18, 0.15, 0.12])
val_loss = np.array([1.0, 0.8, 0.6, 0.55, 0.5, 0.48, 0.45, 0.43, 0.42, 0.41])
plt.plot(epochs, train_loss, 'b-', label='Training Loss') # 蓝色实线
plt.plot(epochs, val_loss, 'r--', label='Validation Loss') # 红色虚线
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
2.3 自定义样式
为增强可读性,可以添加标题、标签、网格等:
plt.plot(epochs, train_loss, 'b-', label='Training Loss')
plt.plot(epochs, val_loss, 'r--', label='Validation Loss')
plt.title('Loss Curve', fontsize=14)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend(fontsize=10)
plt.show()
3. AI/ML/DL 常用可视化类型
以下是 ML/DL 领域中最常用的 Matplotlib 可视化类型,包含代码示例。
3.1 损失曲线(Loss Curve)
如上所述,损失曲线用于监控模型训练。以下是一个更真实的例子,结合 PyTorch 的训练过程:
import torch
import torch.nn as nn
import torch.optim as optim
# 模拟简单模型和数据
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
X = torch.randn(100, 10)
y = torch.randn(100, 1)
losses = []
# 训练
for epoch in range(50):
optimizer.zero_grad()
outputs = model(X)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
losses.append(loss.item())
# 绘制损失曲线
plt.plot(range(1, 51), losses, 'b-', label='Training Loss')
plt.title('Training Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
3.2 混淆矩阵(Confusion Matrix)
混淆矩阵是分类任务中评估模型性能的核心工具。以下示例使用 Scikit-learn 和 Matplotlib 绘制:
from sklearn.metrics import confusion_matrix
import seaborn as sns # 用于美化热力图

最低0.47元/天 解锁文章
527

被折叠的 条评论
为什么被折叠?



