【机器学习】均方误差(MSE:Mean Squared Error)

均方误差(Mean Squared Error, MSE)是衡量预测值与真实值之间差异的一种方法。在统计学和机器学习中,MSE 是一种常见的损失函数,用于评估模型的预测准确性。

均方误差的定义

假设有一组真实值 y_1, y_2, \ldots, y_n​ 和模型预测的对应值 \hat{y}_1, \hat{y}_2, \ldots, \hat{y}_n​。均方误差的定义如下:

\text{MSE} = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2

其中:

  • y_i 是第 i 个真实值。
  • \hat{y}_i 是第 i 个预测值。
  • n 是数据点的总数。

公式解析

  • 误差:每个预测值与真实值的差异称为误差,记为 y_i - \hat{y}_i
  • 平方:每个误差的平方 (y_i - \hat{y}_i)^2 消除了正负误差的抵消作用,保证误差总量为正。
  • 均值:将所有平方误差求和并取平均,以得到整体误差的平均值,这样可以反映出模型的整体预测误差。

特点

  • 非负性:均方误差总是非负的,因为平方项总是非负。
  • 敏感度:MSE 对于离群值(极大或极小误差)非常敏感,因为平方会放大较大误差的影响。

均方误差的应用

  • 回归分析:在回归问题中,MSE 被用来衡量模型预测值与实际观测值之间的差异,常用于模型的训练和验证。
  • 机器学习模型评估:MSE 是评估回归模型的一种常用指标,比如线性回归、决策树回归、神经网络等。

示例

假设有 3 个真实值 y = [3, -0.5, 2] 和模型的预测值 \hat{y} = [2.5, 0.0, 2],则均方误差为:

\text{MSE} = \frac{1}{3} \left( (3 - 2.5)^2 + (-0.5 - 0.0)^2 + (2 - 2)^2 \right) = \frac{1}{3} (0.25 + 0.25 + 0) = 0.1667

均方误差的优缺点

  • 优点:简单且广泛使用,适合衡量模型误差。
  • 缺点:对异常值非常敏感,可能不适合含有离群值的数据集。

Python 实现代码

以下代码用于计算 MSE 值:

import numpy as np

def mse(y_true, y_pred):
    return np.mean((y_pred - y_true) ** 2)

# 示例
y_true = np.array([3, -0.5, 2, 7])
y_pred = np.array([2.5, 0.0, 2, 8])

result = mse(y_true, y_pred)
print("MSE:", result)

MSE 图解示例

下面的图展示了真实值和预测值之间的差异及其平方误差,用于更直观地理解 MSE。

# Re-import necessary libraries due to session context reset
import numpy as np
import matplotlib.pyplot as plt

# Generate some sample data points for MSE visualization
np.random.seed(0)
x = np.linspace(0, 10, 10)
y_true = 2 * x + 1                      # True relationship (e.g., ground truth values)
y_pred = y_true + np.random.normal(0, 3, 10) # Predicted values with random noise

# Calculate MSE
mse_value = np.mean((y_true - y_pred) ** 2)

# Plotting true vs predicted values with error lines
plt.figure(figsize=(10, 6))
plt.plot(x, y_true, label="True Values", color="blue", marker='o')
plt.plot(x, y_pred, label="Predicted Values", color="red", marker='x')

# Add residual lines for MSE (error lines)
for i in range(len(x)):
    plt.plot([x[i], x[i]], [y_true[i], y_pred[i]], color='gray', linestyle='dotted')

# Adding text and labels
plt.xlabel("x")
plt.ylabel("y")
plt.title(f"Illustration of Mean Squared Error (MSE)\nMSE = {mse_value:.2f}")
plt.legend()
plt.grid(True)
plt.show()

上图展示了均方误差(MSE)的计算过程:

  • 蓝色圆点连线 表示真实值 y
  • 红色叉点连线 表示预测值 \hat{y}
  • 灰色虚线 表示每个预测值和真实值之间的误差,即残差(差异)。

这些残差的平方的平均值即为 MSE。本图示帮助理解预测值与真实值之间的差异以及如何计算它们的平方误差。


为什么要使用误差的平方而不直接使用误差的绝对值

使用误差的平方而不直接使用误差的绝对值主要有以下几个原因:

1. 数学性质

  • 可导性:均方误差(MSE)是一个连续且可导的函数,这使得我们在优化算法(如梯度下降法)中能够轻松计算导数和进行更新。而绝对误差(Mean Absolute Error, MAE)在误差为零时不可导,这在某些优化算法中可能会造成困难。

2. 对离群值的敏感性

  • 放大离群值影响:平方误差对较大的误差(离群值)非常敏感,因为它们的平方会显著增加总误差的值。这使得模型能够更好地识别并调整较大的预测错误。在某些应用中,尤其是对大误差特别关注的场景,使用平方误差可以帮助改善模型性能。

3. 简化计算

  • 解析解和算法效率:使用平方误差可以使许多计算过程变得更简单。例如,在最小二乘法中,通过对平方误差进行最小化可以得到解析解,这在处理线性回归等问题时非常有用。

4. 标准正态分布假设

  • 假设分布:在许多统计建模和机器学习的背景下,假设误差是正态分布的是常见的。使用平方误差的损失函数与这种正态分布假设一致,适合于基于最大似然估计的参数估计。

5. 平滑性

  • 函数平滑:平方函数是平滑的,优化过程中的小变化不会导致函数值发生剧烈变化,这使得收敛过程更加稳定和可靠。

6. 对称性

  • 误差符号的处理:平方误差可以消除正负误差的影响,而绝对误差只能给出误差的大小,不能处理多种情况的平衡。

示例对比

假设有三个真实值 y = [2, 3, 5] 和对应的预测值 \hat{y} = [2.5, 2, 4.5],我们来计算每个误差的绝对值和平方。

  1. 计算误差

    • y_1 - \hat{y}_1 = 2 - 2.5 = -0.5
    • y_2 - \hat{y}_2 = 3 - 2 = 1
    • y_3 - \hat{y}_3 = 5 - 4.5 = 0.5
  2. 绝对误差

    • |y_1 - \hat{y}_1| = | -0.5 | = 0.5
    • |y_2 - \hat{y}_2| = | 1 | = 1
    • |y_3 - \hat{y}_3| = | 0.5 | = 0.5

    总绝对误差

    \text{MAE} = \frac{1}{3} \left( 0.5 + 1 + 0.5 \right) = \frac{2}{3} \approx 0.67
  3. 平方误差

    • (y_1 - \hat{y}_1)^2 = (-0.5)^2 = 0.25
    • (y_2 - \hat{y}_2)^2 = (1)^2 = 1
    • (y_3 - \hat{y}_3)^2 = (0.5)^2 = 0.25

    总平方误差

    \text{MSE} = \frac{1}{3} \left( 0.25 + 1 + 0.25 \right) = \frac{1.5}{3} = 0.5

关键点总结

  • 绝对误差是对误差的绝对值的简单累加,不考虑误差的方向(正负)。
  • 平方误差则是对每个误差进行平方处理,从而放大了较大误差的影响,有助于强调模型在较大误差上的表现。

结论

虽然绝对误差在某些情况下也非常有用,尤其是在关注中位数和稳健性时,但均方误差在优化、模型训练和统计推断中有其独特的优势。因此,选择使用平方误差还是绝对误差通常取决于具体问题的需求和模型的特性。

以下是对提供的三个函数逐行代码含义的解释: ### `visual` 函数 ```python def visual(true, preds=None, name='./pic/test.pdf'): """ Results visualization """ plt.figure() plt.plot(true, label='GroundTruth', linewidth=2) if preds is not None: plt.plot(preds, label='Prediction', linewidth=2) plt.legend() plt.savefig(name, bbox_inches='tight') ``` - `def visual(true, preds=None, name='./pic/test.pdf'):`:定义一个名为 `visual` 的函数,该函数接受三个参数。`true` 是必需的,通常代表真实值;`preds` 是可选参数,默认为 `None`,通常代表预测值;`name` 是可选参数,默认为 `'./pic/test.pdf'`,表示保存可视化图片的文件名。 - `""" Results visualization """`:这是函数的文档字符串,用于描述函数的功能,这里表明该函数用于结果可视化。 - `plt.figure()`:使用 `matplotlib` 库创建一个新的图形窗口。 - `plt.plot(true, label='GroundTruth', linewidth=2)`:在图形窗口中绘制 `true` 数据的折线图,标签为 `'GroundTruth'`,线条宽度为 2。 - `if preds is not None:`:检查 `preds` 是否不为 `None`。 - `plt.plot(preds, label='Prediction', linewidth=2)`:如果 `preds` 不为 `None`,则在图形窗口中绘制 `preds` 数据的折线图,标签为 `'Prediction'`,线条宽度为 2。 - `plt.legend()`:在图形窗口中显示图例,用于区分不同的线条。 - `plt.savefig(name, bbox_inches='tight')`:将当前图形保存为指定文件名 `name` 的文件,`bbox_inches='tight'` 表示去除图形周围的空白区域。 ### `adjustment` 函数 ```python def adjustment(gt, pred): anomaly_state = False for i in range(len(gt)): if gt[i] == 1 and pred[i] == 1 and not anomaly_state: anomaly_state = True for j in range(i, 0, -1): if gt[j] == 0: break else: if pred[j] == 0: pred[j] = 1 for j in range(i, len(gt)): if gt[j] == 0: break else: if pred[j] == 0: pred[j] = 1 elif gt[i] == 0: anomaly_state = False if anomaly_state: pred[i] = 1 return gt, pred ``` - `def adjustment(gt, pred):`:定义一个名为 `adjustment` 的函数,接受两个参数 `gt` 和 `pred`,通常分别代表真实标签和预测标签。 - `anomaly_state = False`:初始化一个布尔变量 `anomaly_state` 为 `False`,用于标记是否处于异常状态。 - `for i in range(len(gt)):`:遍历 `gt` 列表的每个元素。 - `if gt[i] == 1 and pred[i] == 1 and not anomaly_state:`:如果当前位置的真实标签和预测标签都为 1,且当前不处于异常状态。 - `anomaly_state = True`:将 `anomaly_state` 标记为 `True`,表示进入异常状态。 - `for j in range(i, 0, -1):`:从当前位置 `i` 向前遍历。 - `if gt[j] == 0:`:如果遇到真实标签为 0 的位置。 - `break`:跳出当前循环。 - `else:`:如果真实标签不为 0。 - `if pred[j] == 0:`:如果预测标签为 0。 - `pred[j] = 1`:将预测标签修改为 1。 - `for j in range(i, len(gt)):`:从当前位置 `i` 向后遍历。 - 后续的 `if` 和 `else` 语句逻辑与向前遍历类似,用于处理向后的情况。 - `elif gt[i] == 0:`:如果当前位置的真实标签为 0。 - `anomaly_state = False`:将 `anomaly_state` 标记为 `False`,表示退出异常状态。 - `if anomaly_state:`:如果当前处于异常状态。 - `pred[i] = 1`:将当前位置的预测标签修改为 1。 - `return gt, pred`:返回修改后的真实标签和预测标签。 ### `cal_accuracy` 函数 ```python def cal_accuracy(y_pred, y_true): return np.mean(y_pred == y_true) ``` - `def cal_accuracy(y_pred, y_true):`:定义一个名为 `cal_accuracy` 的函数,接受两个参数 `y_pred` 和 `y_true`,分别代表预测标签和真实标签。 - `return np.mean(y_pred == y_true)`:计算 `y_pred` 和 `y_true` 对应元素相等的比例,即准确率。`y_pred == y_true` 会返回一个布尔数组,其中元素为 `True` 表示对应位置的预测标签和真实标签相等,为 `False` 则表示不相等。`np.mean()` 函数会将布尔数组中的 `True` 视为 1,`False` 视为 0,然后计算平均值,得到准确率[^3]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值