什么是Batch Normalization?

部署运行你感兴趣的模型镜像

一、概念

        Batch Normalization是在2015年提出的数据归一化方法,主要用在深度神经网络中激活层之前。它的主要作用是加快模型训练时的收敛速度,使模型训练过程更加稳定,避免梯度爆炸或消失,并起到一定的正则化作用,有时甚至可以替代Dropout。

        BN可以应用于全连接层和卷积层,在非线性映射(激活函数)之前对数据进行规范化,使得结果的输出信号的各个维度均值为0,方差为1。这有助于网络的训练,特别是在梯度消失或爆炸的情况下

二、原理

        BN的核心思想是让每一层的输入保持一个稳定的分布,这样模型在训练时可以减少对输入分布变化的依赖,从而加速收敛并提升稳定性。具体来说,BN包含以下几个步骤:

1、计算小批量数据的均值和方差

         在每一层的输入特征图上,BN会在当前batch的数据上计算其均值和方差。

2、数据归一化

        BN对每一个样本的输出进行归一化处理,通过减去均值后再除以标准差,使得归一化后的输出数据具有零均值和单位方差的标准正态分布

3、缩放和平移

        直接归一化会限制模型的学习能力,因为归一化后的输出被严格限制在均值为0和方差为1的分布中。为了恢复模型的表达能力,BN引入了两个可学习的参数:缩放参数γ和偏移参数β,将归一化后的数据进行线性变换:

y_{i} = \gamma \frac{x_{i}- \mu_{B}}{\sqrt{\sigma^{2}_{B}+ \epsilon}} + \beta

        其中,\mu_{B}是均值;\sigma ^{2}_{B}是方差;\epsilon是一个极小值,用于防止分母为0;缩放参数γ和偏移参数β是可训练参数,参与整个网络的反向传播。

4、示例

        这里我们简单调用torch中的nn.BatchNorm1d来实现Batch Normalization。在torch中,训练模型时缩放参数γ和偏移参数β是自动更新的,不需要我们额外操作。

import torch
import torch.nn as nn

# 假设我们有一个输入张量x和一个batch_size
x = torch.randn(6, 10)  # 例如,10维的特征,6是批次大小
print(x)

# 实现Batch Normalization
batch_norm = nn.BatchNorm1d(10)  # 10是特征的维度
x_bn = batch_norm(x)
print(x_bn)

三、python应用

        这里,我们简单创建一个MLP,并对比BN前后的数据变化。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 设置随机种子以确保结果可复现
torch.manual_seed(0)

# 创建一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(100, 50)  # 一个线性层
        self.bn = nn.BatchNorm1d(50)  # Batch Normalization层

    def forward(self, x):
        x = self.linear(x)
        x = self.bn(x)
        return x

# 创建模型实例
model = SimpleModel()

# 生成模拟数据:100个样本,每个样本100个特征
x = torch.randn(100, 100, requires_grad=True)

# 前向传播,计算BN前的数据
x_linear = model.linear(x)
x_linear = x_linear.detach()

# 计算BN前的数据均值和方差
mean_before = x_linear.mean(dim=0)
var_before = x_linear.var(dim=0)

# 应用BN
x_bn = model(x)
x_bn = x_bn.detach()

# 计算BN后的数据均值和方差
mean_after = x_bn.mean(dim=0)
var_after = x_bn.var(dim=0)

# 绘制BN前后数据的分布
fig, ax = plt.subplots(2, 2, figsize=(12, 8))

# 绘制BN前的数据分布
ax[0, 0].hist(x_linear.detach().numpy().flatten(), bins=30, color='blue', alpha=0.7)
ax[0, 0].set_title('Before BN: Data Distribution')

# 绘制BN后的数据分布
ax[0, 1].hist(x_bn.detach().numpy().flatten(), bins=30, color='green', alpha=0.7)
ax[0, 1].set_title('After BN: Data Distribution')

# 绘制BN前的数据均值和方差
ax[1, 0].bar(range(50), var_before, color='blue', alpha=0.7)
ax[1, 0].set_title('Before BN: Variance per Feature')
ax[1, 0].set_xticks(range(0, 50, 5))

# 绘制BN后的数据均值和方差
ax[1, 1].bar(range(50), var_after, color='green', alpha=0.7)
ax[1, 1].set_title('After BN: Variance per Feature')
ax[1, 1].set_xticks(range(0, 50, 5))

plt.tight_layout()
plt.show()

# 打印BN前后的数据均值和方差
print(f"Mean before BN: {mean_before}")
print(f"Mean after BN: {mean_after}")
print(f"Variance before BN: {var_before}")
print(f"Variance after BN: {var_after}")

四、总结

        BN在模型架构搭建的过程中很常用,但也并不是必要的,例如我们在数据预处理的过程中就对数据进行了归一化的话,BN就显得不那么重要了。当然,加上BN在某些场景中也能够使得模型适应不同批次数据的分布变化,且一定程度可以减少内部协变量偏移。具体是否增加BN可以看加入之后是否提升了模型性能。

您可能感兴趣的与本文相关的镜像

Python3.8

Python3.8

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

你遇到的问题是: > **条形图只有 "NORMAL" 类别,其他故障类型没有显示** 这说明在你的预测结果中,**所有目标域样本都被模型预测为 `Normal`(类别 0)**,导致 `sns.countplot(x=final_pred, ...)` 只显示一个柱子。 --- ### ❌ 问题分析:为什么全都是 NORMAL? 以下是常见原因及排查方法: | 原因 | 检查方式 | 解决方案 | |------|--------|---------| | ✅ 1. 模型未学到有效特征 | 查看 t-SNE 是否源/目标严重不重叠 | 加强 MMD 对齐权重或调整网络结构 | | ✅ 2. MMD 权重 λ 太小或太大 | 当前代码用的是 `loss = cls_loss + 0.3 * mmd_loss` | 尝试调整 λ = 0.1 ~ 1.0 | | ✅ 3. 学习率过高导致训练不稳定 | Adam(lr=1e-3) 可能过大 | 改为 `lr=5e-4` 或使用更小 batch | | ✅ 4. 数据标准化错误 | 目标域用了 fit_transform 而非 transform | ✅ 已正确使用 `scaler.transform(X_target)` ✔️ | | ✅ 5. 标签分布偏移严重 | 源域正常样本过多 → 模型偏向多数类 | 使用加权损失函数或过采样 | | ✅ 6. 网络容量不足或过拟合 | Bottleneck 层太窄(32维)或 dropout 过大 | 减少 dropout、增加宽度 | | ✅ 7. 训练 epoch 不足 | 当前只训了 100 轮 | 增加到 200~300 轮观察收敛情况 | --- ### 🔍 第一步:先检查预测结果分布 插入这段代码来确认是不是真的“全是 Normal”: ```python print("📊 预测标签分布:") unique, counts = np.unique(final_pred, return_counts=True) for label, count in zip(unique, counts): print(f" {class_names[label]}: {count} 个") ``` 如果输出如下: ``` 📊 预测标签分布: Normal: 100 个 ``` → 说明模型存在严重的**类别偏见**。 --- ### ✅ 解决方案建议(按优先级排序) #### ✅ 1. **调大 MMD 损失权重 λ** 当前设置:`λ = 0.3`,可能不足以对齐域差异。 👉 修改为: ```python loss = cls_loss + 1.0 * mmd_loss # 提高迁移强度 ``` #### ✅ 2. **降低学习率防止震荡** ```python optimizer = optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4) # 原来是 1e-3 ``` #### ✅ 3. **增加训练轮数至 200** 修改 `train_ddc(..., epochs=200)` 并在循环里打印更多日志: ```python if (epoch + 1) % 50 == 0 or epoch == 0: print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, CLS: {cls_loss:.4f}, MMD: {mmd_loss:.4f}") ``` #### ✅ 4. **可视化训练过程中的分类准确率(源域)** 加入简单评估: ```python # 在每个 epoch 后计算源域准确率 pred_labels = torch.argmax(pred_s, dim=1) acc = (pred_labels == ys_batch).float().mean().item() print(f"[Epoch {epoch+1}] Source Acc: {acc:.2f}") ``` #### ✅ 5. **检查输入数据维度是否一致** 确保 `X_source_scaled.shape[1] == X_target_scaled.shape[1] == 41` 添加断言: ```python assert X_source_scaled.shape[1] == X_target_scaled.shape[1], "特征维度不匹配!" ``` #### ✅ 6. **尝试 CORAL 初步对齐后再送入 DDC** 在训练前先做一次线性对齐: ```python from scipy.linalg import sqrtm def coral_align(source, target): d = source.shape[1] C_src = np.cov(source.T) + 1e-6 * np.eye(d) C_tar = np.cov(target.T) + 1e-6 * np.eye(d) A_coral = np.dot(sqrtm(np.linalg.inv(C_tar)), sqrtm(C_src)) return np.dot(source, A_coral), target # 对齐后再标准化 X_source_aligned, X_target_aligned = coral_align(X_source, X_target) scaler = StandardScaler() X_source_scaled = scaler.fit_transform(X_source_aligned) X_target_scaled = scaler.transform(X_target_aligned) ``` --- ### ✅ 快速验证:手动构造混合标签测试绘图 临时改一下 `final_pred` 测试图表能否正常显示多类: ```python # 🧪 测试用:强制制造多样性(仅用于调试绘图) # final_pred[:20] = 1 # Outer # final_pred[20:40] = 2 # Inner # final_pred[40:60] = 3 # Ball # final_pred[60:] = 0 # Normal ``` 运行后看柱状图是否出现四个类别。如果是 → 说明问题是出在**模型预测单一化**,而非绘图代码。 --- ### ✅ 改进后的完整 `plot_confusion_or_distribution` 示例 ```python # 绘制预测分布(改进版) plt.figure(figsize=(8, 5)) counts = pd.Series(final_pred).map({0: 'Normal', 1: 'Outer Race', 2: 'Inner Race', 3: 'Ball'}).value_counts() sns.barplot(x=counts.index, y=counts.values, palette='Set2') for i, v in enumerate(counts.values): plt.text(i, v + 0.5, str(v), ha='center', fontweight='bold') plt.title('Predicted Fault Distribution on Target Domain') plt.ylabel('Number of Samples') plt.xlabel('Fault Type') plt.grid(True, axis='y', linestyle='--', alpha=0.5) plt.tight_layout() plt.savefig('ddc_predicted_distribution_fixed.png', dpi=150) plt.show() ``` --- ### ✅ 总结:你现在应该做的事 | 步骤 | 操作 | |------|------| | 1️⃣ | 打印 `np.unique(final_pred, return_counts=True)` 查看真实分布 | | 2️⃣ | 将 MMD 权重从 `0.3` 提高到 `1.0` | | 3️⃣ | 把学习率从 `1e-3` 降到 `5e-4` | | 4️⃣ | 把 epoch 数从 100 增加到 200 | | 5️⃣ | 添加源域分类准确率监控 | | 6️⃣ | (可选)先用 CORAL 对齐再训练 DDC | --- 修复后,你应该能看到四种类别的合理分布! ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值