使用 Kolmogorov-Arnold 网络进行时间序列异常检测

导 读

🚨 当数据开始「说谎」:这些场景正在发生

  • 💸 金融市场:高频交易中0.1秒的异常波动,可能吞噬千万级资金

  • 🏭 工业产线:温度传感器0.5℃的微妙偏移,或是设备爆炸的前奏

  • 🏥 医疗监测:心电图上一个看似无害的波形畸变,可能预示致命性心律失常

但致命问题在于:这些异常往往不是简单的「离群点」,而是深藏在复杂时间模式中的**「变色龙」**——它们只在特定上下文(设备启动阶段、节假日流量高峰)现形,让传统检测方法难以检测!

🔥 传统方法为何沦为「工业瞎子」?三大死穴全解剖

检测维度

ARIMA

孤立森林

法则

非线性突变

完全失效

随机误报

连续误判

上下文感知

零感知能力

局部失效

全局盲区

多变量关联

单变量局限

维度灾难

孤立检测

💡 血泪教训:某风电集团因齿轮箱温度-振动-电流的耦合异常未被及时捕获,导致21台机组连环故障,直接损失超2.3亿!

⚡ KANs:一个数学定理掀起的工业革命

1957年埋下的种子:Kolmogorov-Arnold定理证明——任何复杂函数都可拆解为简单函数的组合,这一数学圣杯在66年后终于绽放工业光芒

三大破局利器

  1. 动态函数乐高:每条神经网络连接独立学习激活函数,像拼乐高般动态重组检测逻辑

  2. 时空多尺度透镜

    • 浅层网络捕捉毫秒级毛刺(如传感器噪声)

    • 深层网络解析设备生命周期曲线(如轴承磨损趋势)

    • 可解释决策树:可视化函数组合路径,让「黑箱决策」变为可追溯的故障诊断报告

代码已开源,有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、技术背景

柯尔莫哥洛夫-阿诺德表示定理指出,任何多元连续函数都可以表示为单变量连续函数与加法的有限和与组合。KAN 可以表示为,对于任何连续函数,存在 q 和 ψpq 的连续函数 ϕ,使得:.

该定理为柯尔莫哥洛夫-阿诺德网络提供了理论基础,该网络通过组合和求和单变量函数来近似复杂的多变量函数。这使得 KAN 中的激活函数可以在网络边执行,从而使激活函数“可学习”并提高其性能。

为什么使用 KAN 进行时间序列异常检测?

以下是 KAN 适用于时间序列异常检测的几个原因:

1. 通用函数近似:KAN 可以近似任何连续函数,使其能够准确地模拟正常时间序列数据中的底层模式。

2.对异常的敏感性:异常是指与正常模式的偏差。KAN 能够精确模拟正常行为,因此对此类偏差非常敏感。

3. 分层特征学习:通过堆叠层,KAN 可以捕获局部和全局模式,这对于检测不同类型的异常(点、上下文和集体)至关重要。

02、数据准备与构建

本教程将展示如何对合成数据运行基于 KAN 的时间序列异常检测。

我们首先将连续时间序列分割成窗口,每个窗口捕获一个特定的数据点窗口,其中包含正常模式和潜在异常。

窗口化是一种将连续的时间序列数据分割成更小、更易于管理的块(称为窗口或段)的技术。

每个窗口捕获时间序列的特定部分,从而允许模型分析该部分内的模式和依赖关系

我们将使用重叠窗口,即窗口间隔重叠。这样可以更准确地检测异常,并有助于学习 KAN 跨越多个窗口的时间依赖性。

然后根据窗口中是否存在任何异常行为将这些窗口分为不同的类别。

类别定义:

  • 正常窗口(类 0) :包含任何异常的窗口。

  • 异常窗口(第 1 类) :包含至少一个异常的窗口。

我们确保使用这些类标记窗口,以在我们的项目(训练、验证、测试)的每个数据集中保持异常窗口数量的平衡。

为了平衡每组数据的异常窗口数量,我们使用 SMOTE(合成少数类过采样技术)。该技术生成少数类(异常)的合成样本来平衡数据集,确保模型在训练期间收到两个类的足够示例。

为了有效地从这些数据中学习,模型处理每个窗口,通过多层基于傅里叶的变换提取复杂的时间特征,从而增强其识别预期模式的细微偏差的能力。

通过对这些结构化片段进行训练,KAN 模型学会区分典型序列和异常序列,从而能够准确识别新的、未见过的数据中的异常。

我们将使用这组库:

  • PyTorch:用于构建和训练神经网络。
  • NumPy:用于数值计算。
  • Scikit-learn:用于评估指标和数据预处理。
  • Matplotlib:用于数据可视化。
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    precision_recall_curve,
    roc_curve,
    auc,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

以下是我们将针对该模型使用的参数。

关键参数:

  • window_size:时间序列上的滑动窗口的大小。

  • step_size:移动滑动窗口的步长。

  • anomaly_fraction:包含异常的数据窗口比例。

  • dropout:丢弃率,以防止过度拟合。

  • hidden_size:KAN 中隐藏层的大小。

  • grid_size:KAN层中使用的傅里叶频率的数量。

  • n_layers:模型中堆叠的 KAN 层数。

  • epochs:训练周期数。

  • early_stopping:训练期间提前停止的耐心参数。

  • lr:优化器的学习率。

  • batch_size:训练时每批的样本数量。

确保运行.seed()函数来验证此模型中使用的数字的随机性

class Args:
    path = "./data/"
    dropout = 0.3
    hidden_size = 128
    grid_size = 50
    n_layers = 2
    epochs = 200
    early_stopping = 30  # Increased patience
    seed = 42
    lr = 1e-3  # Increased learning rate
    window_size = 20
    step_size = 10
    batch_size = 32
    anomaly_fraction = 0.1

args = Args()
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

生成合成数据

我们的异常时间序列数据将按以下方式构建:

  • 正弦波生成:我们使用创建正弦波np.sin(x)

  • 异常注入:随机选择窗口中心并根据类型注入异常。

  • 标记:创建一个标签数组来标记异常

我们将使用三种策略在正弦波时间序列中引入异常。

  1. 点异常:向单个数据点添加较大的随机噪声。

  2. 情境异常:特定 正弦值被随机因素放大。

  3. 集体异常:通过添加噪声来改变连续数据点的窗口 ,以模拟更大的群体级干扰。

将异常插入正弦波中随机选择的位置,并创建附带的标签数组,标记异常的位置。

生成具有异常的正弦波后,我们使用StandardScalerscikit-learn对时间序列进行规范化和缩放。

def generate_sine_wave_with_anomalies(
    length=5000, anomaly_fraction=0.1, window_size=20, step_size=10
):
    x = np.linspace(0, 100 * np.pi, length)
    y = np.sin(x)
    labels = np.zeros(length)
    window_centers = list(range(window_size // 2, length - window_size // 2, step_size))
    num_anomalies = int(len(window_centers) * anomaly_fraction)
    anomaly_centers = np.random.choice(window_centers, num_anomalies, replace=False)
    for center in anomaly_centers:
        anomaly_type = np.random.choice(['point', 'contextual', 'collective'])
        if anomaly_type == 'point':
            y[center] += np.random.normal(0, 10)
            labels[center] = 1
        elif anomaly_type == 'contextual':
            y[center] = y[center] * np.random.uniform(1.5, 2.0)
            labels[center] = 1
        elif anomaly_type == 'collective':
            start = max(0, center - window_size // 2)
            end = min(length, center + window_size // 2)
            y[start:end] += np.random.normal(0, 5, size=end - start)
            labels[start:end] = 1  # Mark the entire window as anomalous
    return y, labels


time_series, labels = generate_sine_wave_with_anomalies(
    length=5000,
    anomaly_fraction=args.anomaly_fraction,
    window_size=args.window_size,
    step_size=args.step_size,
)
scaler = StandardScaler()
time_series = scaler.fit_transform(time_series.reshape(-1, 1)).flatten()

设置数据集

现在我们已经设计了TimeSeriesAnomalyDataset,我们将实例化它并验证时间序列数据集的结构。

重要细节:

  • num_pos表示包含异常的窗口总数

  • num_neg表示无异常的窗口总数

  • 我们确保检查是否有异常的窗户

class TimeSeriesAnomalyDataset(torch.utils.data.Dataset):
    def __init__(
        self, time_series, labels, window_size=20, step_size=10, transform=None
    ):
        self.time_series = time_series
        self.labels = labels
        self.window_size = window_size
        self.step_size = step_size
        self.transform = transform
        self.sample_indices = list(
            range(0, len(time_series) - window_size + 1, step_size)
        )

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        if idx >= len(self.sample_indices) or idx < 0:
            raise IndexError(
                f"Index {idx} out of range for sample_indices of length {len(self.sample_indices)}"
            )
        i = self.sample_indices[idx]
        window = self.time_series[i : i + self.window_size]
        window_labels = self.labels[i : i + self.window_size]

        # Input features: window values
        x = torch.tensor(window, dtype=torch.float).unsqueeze(-1)  # Shape: [window_size, 1]

        # Label: 1 if any point in the window is an anomaly, else 0
        y = torch.tensor(1.0 if window_labels.any() else 0.0, dtype=torch.float)

        return x, y

    def indices(self):
        return self.sample_indices

分割数据集

我们将将时间序列数据分为训练、验证和测试集,使用默认比例 60% 用于训练、20% 用于验证、20% 用于测试。

def stratified_split(
    dataset, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2, seed=42
):
    labels = [y.item() for _, y in dataset]

    train_val_indices, test_indices = train_test_split(
        np.arange(len(labels)),
        test_size=test_ratio,
        stratify=labels,
        random_state=seed,
    )

    val_relative_ratio = val_ratio / (train_ratio + val_ratio)

    train_indices, val_indices = train_test_split(
        train_val_indices,
        test_size=val_relative_ratio,
        stratify=[labels[i] for i in train_val_indices],
        random_state=seed,
    )

    return train_indices, val_indices, test_indices

train_indices, val_indices, test_indices = stratified_split(dataset, seed=args.seed)

print(
    f"Train samples: {len(train_indices)}, Val samples: {len(val_indices)}, Test samples: {len(test_indices)}"
)

验证异常

我们计算每组数据的异常数。

def count_anomalies(dataset_subset, name):
    labels = [y.item() for _, y in dataset_subset]
    num_anomalies = int(sum(labels))
    num_normals = len(labels) - num_anomalies
    print(f"{name} - Anomalies: {num_anomalies}, Normals: {num_normals}")

train_dataset = torch.utils.data.Subset(dataset, train_indices)
val_dataset = torch.utils.data.Subset(dataset, val_indices)
test_dataset = torch.utils.data.Subset(dataset, test_indices)

count_anomalies(train_dataset, "Train")
count_anomalies(val_dataset, "Validation")
count_anomalies(test_dataset, "Test")

平衡数据集

我们使用SMOTE(合成少数过采样技术)来平衡数据集。

SMOTE(合成少数类过采样技术)是一种用于解决数据集中类别不平衡问题的技术。当一个类的实例数量明显多于另一个类时,模型可能会偏向多数类。

SMOTE 的工作原理

  • 合成样本生成:SMOTE 通过在现有少数类实例之间进行插值来生成少数类的合成样本。

  • K-最近邻:对于每个少数类实例,SMOTE 选择它的 k 个最近邻之一,并沿着连接它们的线创建一个合成点。

  • 平衡数据集:通过添加合成样本,SMOTE 平衡类别分布,使模型能够有效地从两个类别中学习。

然后,我们创建一个ResampledDataset与 Pytorch 的 DataLoader 配合使用的类。此后,我们在 SMOTE 之后验证新的类分布。

X_train = [x.numpy().flatten() for x, _ in train_dataset]
y_train = [int(y.item()) for _, y in train_dataset]


from imblearn.over_sampling import SMOTE
smote = SMOTE(random_state=args.seed)
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

class ResampledDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = [torch.tensor(x, dtype=torch.float).view(-1, 1) for x in X]
        self.y = [torch.tensor(label, dtype=torch.float) for label in y]
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

balanced_train_dataset = ResampledDataset(X_resampled, y_resampled)

train_loader = torch.utils.data.DataLoader(
    balanced_train_dataset, batch_size=args.batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=args.batch_size, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=False
)

total_anomalies = y_resampled.count(1)
total_normals = y_resampled.count(0)
print(f"Balanced Training Set - Anomalies: {total_anomalies}, Normals: {total_normals}")

# Adjust the pos_weight for BCEWithLogitsLoss
pos_weight = total_normals / total_anomalies
print(f"Using pos_weight: {pos_weight:.4f}")

03、时序异常检测代码

在构建 KAN 之前,我们将定义NaiveFourierKANLayer,它使用傅里叶特征(正弦和余弦变换,即该模型中的“激活函数”)来转换输入数据,增强模型捕捉复杂时间模式的能力。

初始化时,该层设置可学习的傅里叶系数和可选的偏差项。

在前向传递过程中,它首先生成一系列频率,并对这些频率的输入数据应用正弦和余弦变换,从而有效地创建一组丰富的傅里叶特征。

然后使用学习到的傅里叶系数将这些变换后的特征线性组合起来,产生具有所需维数的输出。

NaiveFourierKANLayer分解:

  1. 初始化

  • gridsize:要使用的傅里叶频率的数量。

  • fouriercoeffs:可学习形状的参数[2 * gridsize, inputdim, outdim]。因子 2 表示正弦和余弦分量。

  • 标准化:用比例因子初始化系数以保持稳定的梯度。

2.前向传递

  • 输入形状x具有形状[batch_size, window_size, inputdim]

  • 频率范围(k:我们创建一个频率从 1 到 的张量gridsize。你可以将其理解为 k=[1,2,…,gridsize]

计算角度

θ = x × k × π如下式计算:

angles = x_expanded * k * np.pi

确保每个输入维度都乘以每个频率。

计算傅里叶特征

sin_features = sin(θ)计算如下:

sin_features = torch.sin(angles)

cos_features=cos(θ)计算如下:

cos_features = torch.cos(angles)

连接特征:我们沿频率维度连接正弦和余弦特征,得到形状为的特征张量[batch_size, window_size, inputdim, 2 * gridsize]

重塑矩阵乘法:展平批次和窗口尺寸以准备进行矩阵乘法:[batch_size * window_size, inputdim, 2 * gridsize]

矩阵乘法:使用爱因斯坦求和 ( torch.einsum) 执行矩阵乘法y= features⋅ fouriercoeffs

重塑背部

  • 重塑y[batch_size, window_size, outdim]

偏置添加:如果addbiasTrue,则添加偏差项。

输出张量的形状为[batch_size, window_size, outdim],表示每个窗口中的每个时间步长的变换特征,并丰富了基于频率的信息。

class NaiveFourierKANLayer(nn.Module):
    def __init__(self, inputdim, outdim, gridsize=50, addbias=True):
        super(NaiveFourierKANLayer, self).__init__()
        self.gridsize = gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        self.fouriercoeffs = nn.Parameter(
            torch.randn(2 * gridsize, inputdim, outdim)
            / (np.sqrt(inputdim) * np.sqrt(gridsize))
        )
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(outdim))

    def forward(self, x):
        # x shape: [batch_size, window_size, inputdim]
        batch_size, window_size, inputdim = x.size()
        k = torch.arange(1, self.gridsize + 1, device=x.device).float()
        k = k.view(1, 1, 1, self.gridsize)
        x_expanded = x.unsqueeze(-1)  # [batch_size, window_size, inputdim, 1]
        angles = x_expanded * k * np.pi  # [batch_size, window_size, inputdim, gridsize]
        sin_features = torch.sin(angles)
        cos_features = torch.cos(angles)
        features = torch.cat([sin_features, cos_features], dim=-1)  # Concatenate on gridsize dimension
        features = features.view(batch_size * window_size, inputdim, -1)  # Flatten for matmul
        coeffs = self.fouriercoeffs  # [2 * gridsize, inputdim, outdim]
        y = torch.einsum('bik,kio->bo', features, coeffs)
        y = y.view(batch_size, window_size, self.outdim)
        if self.addbias:
            y += self.bias
        return y

现在我们将定义 KAN 架构。

  • 输入层:将输入映射到更高维的隐藏空间。

  • KAN 层:堆叠多个NaiveFourierKANLayer实例以捕获复杂的模式。

  • 全局平均池化:聚合时间窗口内的信息。

  • 输出层:产生最终的异常分数。

KAN分解:

  1. 初始化输入层(lin_in:从in_feat到 的线性变换hidden_feat

  • 批量标准化(bn_in:应用于输入层之后。

  • Dropout:正则化以防止过度拟合,促进对看不见的数据的泛化。

  • 隐藏层layers:实例列表NaiveFourierKANLayer

  • bns:相应的批量标准化层。

  • 输出层(lin_outhidden_feat :从到的线性变换out_feat(本例中为标量输出)。将聚合特征映射到标量异常分数,量化窗口包含异常的可能性。

2.前向传递

输入转换

  • 将输入传递x通过输入线性层。

  • 应用批量标准化和Leaky ReLU激活。

  • 應用 dropout。

隐藏层(对于每一层)

  • 将输出传递出去NaiveFourierKANLayer

  • 应用批量标准化和激活。

  • 應用 dropout。

全局平均池化

  • 对维度进行平均window_size以汇总时间信息。

输出层

  • 将池化特征传递到输出线性层。

  • 压缩输出以删除不必要的尺寸。

输出:

  • x:形状为的张量[batch_size],其中每个元素都是一个标量,表示批次中相应窗口的异常分数。

  • 分数越高,窗口包含异常的可能性就越大。

class KAN(nn.Module):
    def __init__(
        self,
        in_feat,
        hidden_feat,
        out_feat,
        grid_feat,
        num_layers,
        use_bias=True,
        dropout=0.3,
    ):
        super(KAN, self).__init__()
        self.num_layers = num_layers
        self.lin_in = nn.Linear(in_feat, hidden_feat, bias=use_bias)
        self.bn_in = nn.BatchNorm1d(hidden_feat)
        self.dropout = nn.Dropout(p=dropout)
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                NaiveFourierKANLayer(
                    hidden_feat, hidden_feat, grid_feat, addbias=use_bias
                )
            )
            self.bns.append(nn.BatchNorm1d(hidden_feat))
        self.lin_out = nn.Linear(hidden_feat, out_feat, bias=use_bias)

    def forward(self, x):
        # x shape: [batch_size, window_size, 1]
        batch_size, window_size, _ = x.size()
        x = self.lin_in(x)  # [batch_size, window_size, hidden_feat]
        x = self.bn_in(x.view(-1, x.size(-1))).view(batch_size, window_size, -1)
        x = F.leaky_relu(x, negative_slope=0.1)
        x = self.dropout(x)
        for layer, bn in zip(self.layers, self.bns):
            x = layer(x)
            x = bn(x.view(-1, x.size(-1))).view(batch_size, window_size, -1)
            x = F.leaky_relu(x, negative_slope=0.1)
            x = self.dropout(x)
        # Global average pooling over the window dimension
        x = x.mean(dim=1)  # [batch_size, hidden_feat]
        x = self.lin_out(x).squeeze()  # [batch_size]
        return x

定义损失函数:焦点损失

对于我们的项目,Focal Loss 是最适合计算检测异常时遗漏的损失函数。

Focal Loss是一种损失函数,旨在通过将训练重点放在难以分类的示例上来解决类别不平衡问题。

异常检测本质上处理不平衡的数据集,其中异常(正类)比正常实例(负类)明显少见。

这种不平衡对二元交叉熵 (BCE) 等传统损失函数构成了挑战,因为这些损失函数可能会被多数类别所淹没,从而导致模型偏向于预测多数类别而忽略少数(异常)类别。

对于二元分类,焦点损失定义为:

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * ((1 - pt) ** self.gamma) * BCE_loss
        if self.reduction == 'mean':
            return F_loss.mean()
        else:
            return F_loss.sum()

criterion = FocalLoss(alpha=0.25, gamma=2)

我们将使用Adam优化器和ReduceLROnPlateau学习率调度程序,当指定指标停止改进时降低学习率。

我们还将定义该evaluate_metrics函数。

我们将跟踪的指标包括:

  • precision:真实阳性预测与阳性预测总数的比例,衡量阳性预测的准确性。

  • recall:真实阳性预测与实际阳性实例总数的比例,衡量模型捕获所有实际异常的能力。

  • f1:精确度和召回率的调和平均值(平衡精确度和召回率,提供一个同时考虑假阳性和假阴性的单一指标)

  • roc_auc_val:接收者操作特征 (ROC) 曲线下的面积。ROC AUC 越高,表示在区分异常和正常情况方面的整体表现越好。

最后,我们将定义一个函数,该函数用于将预测概率转换为二进制类标签(对于正常,对于异常)find_optimal_threshold的阈值,从而最大化 F1 分数。

model = KAN(
    in_feat=1,
    hidden_feat=args.hidden_size,
    out_feat=1,
    grid_feat=args.grid_size,
    num_layers=args.n_layers,
    use_bias=True,
    dropout=args.dropout,
).to(args.device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

def evaluate_metrics(true_labels, pred_labels, pred_probs):
    precision = precision_score(true_labels, pred_labels, zero_division=0)
    recall = recall_score(true_labels, pred_labels, zero_division=0)
    f1 = f1_score(true_labels, pred_labels, zero_division=0)
    roc_auc_val = roc_auc_score(true_labels, pred_probs)
    return precision, recall, f1, roc_auc_val

def find_optimal_threshold(probs, labels):
    precision_vals, recall_vals, thresholds = precision_recall_curve(labels, probs)
    f1_scores = 2 * (precision_vals * recall_vals) / (precision_vals + recall_vals + 1e-8)
    optimal_idx = np.argmax(f1_scores)
    if optimal_idx < len(thresholds):
        optimal_threshold = thresholds[optimal_idx]
    else:
        optimal_threshold = 0.5  # Default threshold
    optimal_f1 = f1_scores[optimal_idx]
    return optimal_threshold, optimal_f1

完成模型训练设置后,我们开始训练模型。我们有一个训练循环,用于训练模型并监控训练损失和准确率,还有一个验证循环,用于使用验证集监控模型的性能。请注意,如果模型没有显示出改善的迹象,我们会采用提前停止的方法。

# Training and validation loop with early stopping
best_val_f1 = 0
patience = args.early_stopping
patience_counter = 0
optimal_threshold = 0.5  # Initialize with default threshold

for epoch in range(args.epochs):
    # Training Phase
    model.train()
    total_loss = 0
    total_acc = 0
    total_preds_pos = 0  # Monitor number of positive predictions
    for x_batch, y_batch in train_loader:
        x_batch = x_batch.to(args.device)
        y_batch = y_batch.to(args.device)
        optimizer.zero_grad()
        out = model(x_batch)  # Output shape: [batch_size]
        loss = criterion(out, y_batch)
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item() * x_batch.size(0)
        probs = torch.sigmoid(out)
        preds = (probs > 0.5).float()
        acc = (preds == y_batch).float().mean().item()
        total_acc += acc * x_batch.size(0)
        total_preds_pos += preds.sum().item()
    avg_loss = total_loss / len(balanced_train_dataset)
    avg_acc = total_acc / len(balanced_train_dataset)

    print(f"Epoch {epoch+1}, Training Positive Predictions: {total_preds_pos}")

    # Validation Phase
    model.eval()
    val_loss = 0
    val_acc = 0
    all_true = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            x_batch = x_batch.to(args.device)
            y_batch = y_batch.to(args.device)
            out = model(x_batch)
            loss = criterion(out, y_batch)
            val_loss += loss.item() * x_batch.size(0)
            probs = torch.sigmoid(out)
            preds = (probs > 0.5).float()
            acc = (preds == y_batch).float().mean().item()
            val_acc += acc * x_batch.size(0)
            all_true.extend(y_batch.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
    avg_val_loss = val_loss / len(val_dataset)
    avg_val_acc = val_acc / len(val_dataset)
    precision, recall, f1, roc_auc_val = evaluate_metrics(all_true, all_preds, all_probs)

    # Find Optimal Threshold
    current_threshold, current_f1 = find_optimal_threshold(all_probs, all_true)

    print(
        f"Epoch: {epoch+1:04d}, "
        f"Train Loss: {avg_loss:.4f}, Train Acc: {avg_acc:.4f}, "
        f"Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}, "
        f"Precision: {precision:.4f}, Recall: {recall:.4f}, "
        f"F1: {f1:.4f}, ROC AUC: {roc_auc_val:.4f}, "
        f"Optimal Threshold: {current_threshold:.4f}, Val F1: {current_f1:.4f}"
    )

    # Step the scheduler
    scheduler.step(avg_val_loss)

    # Early Stopping
    if f1 > best_val_f1:
        best_val_f1 = f1
        patience_counter = 0
        optimal_threshold = current_threshold  # Update optimal threshold
        # Save the best model
        torch.save(model.state_dict(), "best_kan_model.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered.")
            break

为了测试,我们加载在验证集上表现最佳的模型并在测试集上对其进行评估。

model.load_state_dict(torch.load("best_kan_model.pth"))

model.eval()
test_loss = 0
test_acc = 0
all_true_test = []
all_preds_test = []
all_probs_test = []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        x_batch = x_batch.to(args.device)
        y_batch = y_batch.to(args.device)
        out = model(x_batch)
        loss = criterion(out, y_batch)
        test_loss += loss.item() * x_batch.size(0)
        probs = torch.sigmoid(out)
        preds = (probs > optimal_threshold).float()
        acc = (preds == y_batch).float().mean().item()
        test_acc += acc * x_batch.size(0)
        all_true_test.extend(y_batch.cpu().numpy())
        all_preds_test.extend(preds.cpu().numpy())
        all_probs_test.extend(probs.cpu().numpy())
avg_test_loss = test_loss / len(test_dataset)
avg_test_acc = test_acc / len(test_dataset)
precision, recall, f1, roc_auc_val = evaluate_metrics(
    all_true_test, all_preds_test, all_probs_test
)

print(
    f"\nTest Loss: {avg_test_loss:.4f}, Test Acc: {avg_test_acc:.4f}, "
    f"Precision: {precision:.4f}, Recall: {recall:.4f}, "
    f"F1: {f1:.4f}, ROC AUC: {roc_auc_val:.4f}"
)

04、可视化

最后,我们编写函数来可视化异常检测测试。plot_anomalies创建一个时间序列可视化,其中红色图显示真实异常,黄色X显示模型预测异常的位置。aggregate_predictions汇总时间序列窗口上的预测以与原始时间序列对齐。plot_metrics绘制ROC和精确召回率曲线以可视化模型性能。

def plot_anomalies(time_series, labels, preds, start=0, end=1000):
    plt.figure(figsize=(15, 5))
    plt.plot(time_series[start:end], label="Time Series")
    plt.scatter(
        np.arange(start, end)[labels[start:end] == 1],
        time_series[start:end][labels[start:end] == 1],
        color="red",
        label="True Anomalies",
    )
    plt.scatter(
        np.arange(start, end)[preds[start:end] == 1],
        time_series[start:end][preds[start:end] == 1],
        color="orange",
        marker="x",
        label="Predicted Anomalies",
    )
    plt.legend()
    plt.title("Anomaly Detection")
    plt.xlabel("Time Step")
    plt.ylabel("Normalized Value")
    plt.show()
def aggregate_predictions(indices, preds, window_size, total_length):
    aggregated = np.zeros(total_length, dtype=float)
    counts = np.zeros(total_length, dtype=float)
    for idx, pred in zip(indices, preds):
        start = idx
        end = idx + window_size
        if end > total_length:
            end = total_length
        aggregated[start:end] += pred
        counts[start:end] += 1
    counts[counts == 0] = 1
    averaged = aggregated / counts
    return (averaged > 0.5).astype(int)


plot_metrics(all_true_test, all_probs_test)
test_sample_indices = [dataset.sample_indices[i] for i in test_indices]
aggregated_preds = aggregate_predictions(
    test_sample_indices, all_preds_test, args.window_size, len(time_series)
)

def plot_metrics(true_labels, pred_probs):
    # ROC Curve
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    roc_auc_val = auc(fpr, tpr)

    # Precision-Recall Curve
    precision_vals, recall_vals, _ = precision_recall_curve(true_labels, pred_probs)
    pr_auc_val = auc(recall_vals, precision_vals)

    plt.figure(figsize=(12, 5))

    # ROC Curve
    plt.subplot(1, 2, 1)
    plt.plot(fpr, tpr, label=f"ROC Curve (AUC = {roc_auc_val:.2f})")
    plt.plot([0, 1], [0, 1], "k--", label="Random Guess")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend()

    # Precision-Recall Curve
    plt.subplot(1, 2, 2)
    plt.plot(recall_vals, precision_vals, label=f"PR Curve (AUC = {pr_auc_val:.2f})")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall (PR) Curve")
    plt.legend()

    plt.tight_layout()
    plt.show()

# Plot anomalies on the test set
test_start = min(test_sample_indices)
test_end = max(test_sample_indices) + args.window_size
plot_anomalies(time_series, labels, aggregated_preds, start=test_start, end=test_end)

05、结果

运行该模型的最终结果显示,检测时间序列中的异常的准确率为 91%。

精度为 1.0,所有预测异常均为真异常。

由于召回率为 0.57,该模型确实遗漏了大量异常。

总体而言,对于这个例子,KAN 显示出能够检测时间序列中的真实异常的良好迹象,但需要更多改进来提高其召回率。

对于图 1,通过将 ROC 曲线(表示真实阳性率的图表)与随机猜测进行比较,我们观察到曲线下面积 (AUC) 为 0.88,表明模型性能强劲。图 2 显示了模型的召回率和精确度之间的权衡,表明大多数数据的精确度都很高。

对于 ECG5000 数据集(本教程中未显示),该模型显示准确率为 82%,精确率为 72%,召回率为 93%,展示了在保持合理错误率的同时检测异常的强大能力。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小Z的科研日常

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值