AI论文阅读笔记 | TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network

目录

一、文献基本信息

二、基本内容 

1、课题研究背景 

2、现有方法的不足

3、创新点 

三、具体网络结构 

1、TTS-GAN

 2、时间序列处理​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​

四、实验

1、数据集

2、实验结果

一、文献基本信息

题目:TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network

会议:International Conference on Artificial Intelligence in Medicine (AIME 2022)

原文https://arxiv.org/pdf/2202.02691v2.pdf

源码https://github.com/imics-lab/tts-gan

二、基本内容 

1、课题研究背景 

        医学机器学习应用中最常见的数据类型之一是以时间序列形式出现的信号测量。 然而这样的数据集通常很小,使得深度神经网络架构的训练无效。因此需要探寻有效的数据扩充方法来对数据集进行增广,提升模型效果。其中GAN(生成对抗网络)是一个有效的工具。

2、现有方法的不足

        基于RNN的GAN不能有效地对具有不规则时间关系的长序列建模,且存在梯度消失问题。

3、创新点 

  1. 创建了一个纯粹基于Transformer编码器的GAN来生成时间序列数据(TTS-GAN)。
  2. 提出几种方法,更有效地在时间序列数据上训练基于transformer的GAN模型。

三、具体网络结构 

1、TTS-GAN

        TTS-GAN模型结构如图1所示,包含两个主要组件,一个生成器和一个鉴别器,都是基于Transformer的编码器架构构建的。其中编码器是由两个复合块组成的。第一块由多头自注意模块构成,第二块由具有GELU激活函数的前馈MLP构成。在两个块之前应用规范化层,在每个块之后添加dropout层,两个块都使用残差连接。

生成器 :

        生成器首先接收一个1D向量,其中N个均匀分布的随机数值在(0,1)范围内,即N_{i}\sim (0,1)。N表示合成信号的潜在维数,是一个可调的超参数。然后将向量映射到具有相同实际信号长度和嵌入维数为M的序列,其中M也是一个可以改变的超参数;接下来,将序列划分为多个patch,并在每个patch中添加一个位置编码值。然后将这些pstch输入到Transformer的编码器块;之后将编码器的输出通过Conv2D层以降低合成数据的维数,Conv2D层卷积核大小为(1,1),不会改变合成数据的宽度和高度。

鉴别器:

        鉴别器架构类似ViT模型(一种二分类器)。在TTS-GAN中,将任何输入序列视为高度为1的图像,输入的时间步长是图像宽度。因此,要在时间序列输入上添加位置编码,只需要将宽度均匀地分成多个片段,并保持每个片段的高度不变。

 2、时间序列处理​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​

        将时间序列数据视为是高度为1的图像。时间步长是图像的宽度W,一个时间序列序列可以有一个单通道或多个通道,这些通道可以被看作是一个图像的通道数(RGB) c,因此输入序列可以用大小矩阵(Batch Size,C,1,W),然后选择一个patch大小为N,将序列划分为W/N个patch。然后,我们在每个patch的末尾添加一个软位置编码值,该位置值在模型训练过程中学习。因此,鉴别器编码块的输入将具有数据形状(Batch Size,C,(W=N) + 1)。过程如图2所示。

四、实验

1、数据集

  • 模拟正弦波

  • UniMiB人类活动识别(HAR)数据集

  • PTB诊断ECG数据库

2、实验结果 

        真实数据和TTS-GAN生成的数据可视化结果: 

        使用PCA(主成分分析法)和t-SNE(用于高维数据降维到2维或者3维),进行可视化 :

        将TTS-GAN生成的序列和RNN-GAN生成的序列与同类真实序列的相似性进行了比较:        (avg_cos_sim,越大越好; avg_jen_dis,越小越好) 

 

03-15
### TimeGAN 的概述 TimeGAN 是一种基于生成对抗网络 (GAN) 的时间序列合成模型,旨在通过学习真实数据分布来生成高质量的时间序列数据[^1]。该方法的核心在于结合了监督学习和无监督学习的思想,在生成器和判别器之间建立了一种动态平衡。 以下是关于 TimeGAN 的实现及其 GitHub 教程的一些关键点: --- ### TimeGAN 的核心组件 TimeGAN 主要由以下几个部分组成: 1. **嵌入网络(Embedding Network)**: 将原始时间序列映射到隐空间。 2. **恢复网络(Recovery Network)**: 从隐空间重建回原始时间序列。 3. **生成器(Generator)**: 学习生成新的时间序列样本。 4. **判别器(Discriminator)**: 区分真实时间和生成时间序列之间的差异。 5. **监督损失(Supervised Loss)**: 确保生成器能够模仿真实的序列模式。 这些模块共同作用以优化生成的时间序列的质量。 --- ### TimeGAN 的 GitHub 实现教程 目前有许多开源项目提供了 TimeGAN 的具体实现代码示例。以下是一个典型的 Python 实现框架: #### 安装依赖库 首先需要安装必要的依赖项,通常包括 TensorFlow 或 PyTorch。例如: ```bash pip install tensorflow numpy matplotlib ``` #### 时间序列生成的代码示例 下面展示了一个简单的 TimeGAN 实现流程: ```python import numpy as np from timegan import TimeGAN # 假设已导入 TimeGAN 类 # 数据准备 def load_data(): """加载时间序列数据""" data = np.random.randn(100, 24, 1) # 示例:100 条长度为 24 的单变量时间序列 return data data = load_data() # 初始化 TimeGAN 模型 model = TimeGAN(module='gru', hidden_dim=24, num_layers=3) # 训练模型 model.train(data, train_steps=10000) # 合成新时间序列 synthetic_data = model.generate(n_samples=10) print(synthetic_data.shape) # 输出形状应为 (n_samples, seq_len, dim) ``` 上述代码展示了如何利用 `TimeGAN` 库训练并生成新的时间序列数据。 --- ### 可视化生成结果 为了验证生成的数据质量,可以绘制真实数据与合成数据的对比图: ```python import matplotlib.pyplot as plt real_sample = data[0].flatten() synthetic_sample = synthetic_data[0].flatten() plt.figure(figsize=(8, 4)) plt.plot(real_sample, label="Real Data", color="blue") plt.plot(synthetic_sample, label="Synthetic Data", color="orange", linestyle="--") plt.legend() plt.title("Comparison of Real vs Synthetic Time Series") plt.show() ``` 此可视化过程有助于直观评估生成效果。 --- ### 进一步探索的方向 除了基本的功能外,开发者还可以尝试扩展 TimeGAN 的应用场景,比如将其应用于异常检测、预测建模等领域[^2]。此外,也可以研究其他类似的 GAN 架构,如 C-RNN-GAN 或 SeqGAN,以便找到最适合特定任务的方法。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值