
反向传播算法(Backpropagation Algorithm)
反向传播(Backpropagation,简称 BP)是一种高效的梯度下降算法,用于优化神经网络的权重和偏置。它通过计算损失函数相对于模型参数的梯度,从而调整参数以最小化损失函数。
1. 基本思想
- 核心目标是最小化损失函数
,其中:
:真实值。
:模型预测值。
- BP 算法通过链式法则逐层传播误差,从输出层到输入层,逐步计算梯度并更新参数。
2. 反向传播的步骤
(1) 前向传播
- 输入数据通过网络计算输出(预测值
)。
- 计算每层的激活值
和激活后的输出
。
- 在最后一层计算损失函数
。
(2) 计算损失
- 使用损失函数(如均方误差、交叉熵)来衡量预测值和真实值之间的差距。
(3) 反向传播
- 计算输出层误差:
- 对于输出层:
其中,
表示输出层的误差。
- 对于输出层:
- 传播到隐藏层:
- 对于每一层
:
- 对于每一层
- 计算梯度:
- 权重梯度:
- 偏置梯度:
- 权重梯度:
(4) 参数更新
- 使用梯度下降法更新参数:
其中,为学习率。
3. 公式解读
- 前向传播公式:
- 误差传播公式:
- 梯度公式:
4. Python 实现
以下是反向传播算法的基本实现:
import numpy as np
# 激活函数及其导数(以Sigmoid为例)
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def sigmoid_derivative(x):
return sigmoid(x) * (1 - sigmoid(x))
# 初始化参数
np.random.seed(42)
input_size, hidden_size, output_size = 3, 4, 1
W1 = np.random.randn(hidden_size, input_size)
b1 = np.random.randn(hidden_size, 1)
W2 = np.random.randn(output_size, hidden_size)
b2 = np.random.randn(output_size, 1)
# 数据示例
X = np.random.randn(input_size, 1) # 输入
y = np.array([[1]]) # 真实值
# 前向传播
z1 = np.dot(W1, X) + b1
a1 = sigmoid(z1)
z2 = np.dot(W2, a1) + b2
a2 = sigmoid(z2) # 预测值
# 损失函数:均方误差
loss = 0.5 * (y - a2) ** 2
# 反向传播
# 输出层误差
delta2 = (a2 - y) * sigmoid_derivative(z2)
dW2 = np.dot(delta2, a1.T)
db2 = delta2
# 隐藏层误差
delta1 = np.dot(W2.T, delta2) * sigmoid_derivative(z1)
dW1 = np.dot(delta1, X.T)
db1 = delta1
# 参数更新
learning_rate = 0.1
W2 -= learning_rate * dW2
b2 -= learning_rate * db2
W1 -= learning_rate * dW1
b1 -= learning_rate * db1
print("Updated Weights and Biases:")
print("W1:", W1)
print("b1:", b1)
print("W2:", W2)
print("b2:", b2)
输出结果
Updated Weights and Biases:
W1: [[ 0.49721568 -0.13841431 0.65085341]
[ 1.52294983 -0.23412944 -0.23464195]
[ 1.57955459 0.76733251 -0.46731765]
[ 0.54337753 -0.4636622 -0.46057106]]
b1: [[ 0.23974092]
[-1.9129258 ]
[-1.7264316 ]
[-0.5659083 ]]
W2: [[-1.00965471 0.31553744 -0.90592943 -1.40730743]]
b2: [[1.47591079]]
5. 总结
- 反向传播的作用:通过计算损失函数对参数的梯度,指导权重和偏置更新,从而最小化误差。
- 优点:高效、易于实现,适合深度神经网络。
- 局限性:可能出现梯度消失或梯度爆炸问题,特别是在深层网络中。
- 改进:结合优化算法(如 Adam、RMSProp)和正则化技术,增强 BP 算法的性能。
&spm=1001.2101.3001.5002&articleId=144451158&d=1&t=3&u=15e010a3afd243e4b1ec19e97088ee10)
1011

被折叠的 条评论
为什么被折叠?



