机器学习进阶之 时域/时间卷积网络 TCN 概念+由来+原理+代码实现

TCN(时间卷积网络)是一种处理时间序列数据的网络结构,用于时序预测、概率预测等任务。它源于CNN和RNN,解决了传统网络的梯度消失和并行性问题,采用膨胀因果卷积、残差连接等技术。本文详细介绍了TCN的原理,并提供了PaddlePaddle实现的代码示例。
部署运行你感兴趣的模型镜像

TCN 从“阿巴阿巴”到“巴拉巴拉”

  • TCN的概念(干嘛来的!能解决什么问题)
  • TCN的父母(由来)
  • TCN的原理介绍
  • 上代码!

1、TCN(时域卷积网络、时间卷积网络)是干嘛的,能干嘛

  • 主要应用方向:

时序预测、概率预测、时间预测、交通预测

2、TCN的由来

ps:在了解TCN之前需要先对CNN和RNN有一定的了解。

  • 处理问题:

是一种能够处理时间序列数据的网络结构,在特定条件下,效果优于传统的神经网络(RNN、CNN等)。

3、TCN的原理介绍

TCN 的网络结构

请添加图片描述

一、TCN的网络结构主要由上图构成。本文分为左边和右边两部分,首先是左边

Dilated Causal Conv ---> WeightNorm--->ReLU--->Dropout--->Dilated Causal Conv ---> WeightNorm--->ReLU--->Dropout

很明显这个可以分为

(Dilated Causal Conv ---> WeightNorm--->ReLU--->Dropout)*2

ok,下面我们对这四个逐个进行讲解,如有了解可以选择跳读

1、Dilated Gausal Conv

中文名:膨胀因果卷积

膨胀因果卷积可以分为膨胀因果卷积三部分。

卷积是指 CNN中的卷积,是指卷积核在数据上进行的一种滑动运算操作;

膨胀是指 允许卷积时的输入存在间隔采样,其和卷积神经网络中的stride有相似之处,但也有很明显的区别

图片说明:

请添加图片描述

因果是指 第i层中t时刻的数据,只依赖与(i-1)层t时刻及其以前的值的影响。因果卷积可以在训练的时候摒弃掉对未来数据的读取,是一种严格的时间约束模型。

图片说明:

请添加图片描述

​ (ps:没有加入膨胀卷积)

2、WeightNorm

权重归一化

对权重值进行归一化,如果有想仔细研究归一化过程&归一化公式的,可以点击链接进行学习

点击

优点:

1、时间开销小,运算速度快!

2、引入更少的噪声

3、WeightNorm是通过重写深度网络的权重来进行加速的,没有引入对minibatch的依赖

3、ReLU()

激活函数的一种

优点:

1、可以使网络的训练速度更快

2、增加网络的非线性,提高模型的表达能力

3、防止梯度消失,

4、使网络具有稀疏性等

公式:

在这里插入图片描述

概述图:

ReLU 函数

4、Dropout()

Dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。

优点:防止过拟合,提高模型的运算速度

二、最后是右边-残差连接

右边是一个1*1的卷积块儿,不仅可以使网络拥有跨层传递信息的功能,而且可以保证输入输出的一致性。

三、TCN的优点:

1、并行性

2、可以很大程度上避免梯度消失和梯度爆炸

3、感受野更大,学习到的信息更多

4、从零coding

import os
import sys
import paddle
import paddle.nn as nn
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
import paddle.nn.functional as F
from paddle.nn.utils import weight_norm
from sklearn.preprocessing import MinMaxScaler
from pandas.plotting import register_matplotlib_converters
from sourceCode import TimeSeriesNetwork
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../..")))

class Chomp1d(nn.Layer):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size]


class TemporalBlock(nn.Layer):
    def __init__(self,
                 n_inputs,
                 n_outputs,
                 kernel_size,
                 stride,
                 dilation,
                 padding,
                 dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = weight_norm(
            nn.Conv1D(
                n_inputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation))
        # Chomp1d is used to make sure the network is causal.
        # We pad by (k-1)*d on the two sides of the input for convolution,
        # and then use Chomp1d to remove the (k-1)*d output elements on the right.
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = weight_norm(
            nn.Conv1D(
                n_outputs,
                n_outputs,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation))
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1,
                                 self.dropout1, self.conv2, self.chomp2,
                                 self.relu2, self.dropout2)
        self.downsample = nn.Conv1D(n_inputs, n_outputs,
                                    1) if n_inputs != n_outputs else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        self.conv1.weight.set_value(
            paddle.tensor.normal(0.0, 0.01, self.conv1.weight.shape))
        self.conv2.weight.set_value(
            paddle.tensor.normal(0.0, 0.01, self.conv2.weight.shape))
        if self.downsample is not None:
            self.downsample.weight.set_value(
                paddle.tensor.normal(0.0, 0.01, self.downsample.weight.shape))

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x) # 让输入等于输出
        return self.relu(out + res)


class TCNEncoder(nn.Layer):
    def __init__(self, input_size, num_channels, kernel_size=2, dropout=0.2):
        # input_size : 输入的预期特征数
        # num_channels: 通道数
        # kernel_size: 卷积核大小
        super(TCNEncoder, self).__init__()
        self._input_size = input_size
        self._output_dim = num_channels[-1]

        layers = nn.LayerList()
        num_levels = len(num_channels)
        # print('print num_channels: ', num_channels)
        # print('print num_levels: ',num_levels)
        # exit(0)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = input_size if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            layers.append(
                TemporalBlock(
                    in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    dilation=dilation_size,
                    padding=(kernel_size - 1) * dilation_size,
                    dropout=dropout))

        self.network = nn.Sequential(*layers)

    def get_input_dim(self):
        return self._input_size



    def get_output_dim(self):
        return self._output_dim

    def forward(self, inputs):
        inputs_t = inputs.transpose([0, 2, 1])
        output = self.network(inputs_t).transpose([2, 0, 1])[-1]
        return output


class TimeSeriesNetwork(nn.Layer):

    def __init__(self, input_size, next_k=1, num_channels=[256]):
        super(TimeSeriesNetwork, self).__init__()

        self.last_num_channel = num_channels[-1]

        self.tcn = TCNEncoder(
            input_size=input_size,
            num_channels=num_channels,
            kernel_size=3,
            dropout=0.2
        )

        self.linear = nn.Linear(in_features=self.last_num_channel, out_features=next_k)

    def forward(self, x):
        tcn_out = self.tcn(x)
        y_pred = self.linear(tcn_out)
        return y_pred
'''
我努力把自己塑造成悲剧里面的男主角,
把一切过错推到你的身上,
让你成为万恶的巫婆,
丧心病狂
可是我就是一个正常的人,
有悲有喜,
有错有对,
走到今天这个地步,
我们都有责任,
直到现在我还没有觉得我失去了你
你告诉我,我失去你了么?
'''
def config_mtp():
    sns.set(style='whitegrid', palette='muted', font_scale=1.2)
    HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#93D30C", "#8F00FF"]
    sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
    rcParams['figure.figsize'] = 14, 10
    register_matplotlib_converters()

def read_data():
    df_all = pd.read_csv('./data/time_series_covid19_confirmed_global.csv')
    # print(df_all.head())

    # 我们将对全世界的病例数进行预测,因此我们不需要关心具体国家的经纬度等信息,只需关注具体日期下的全球病例数即可。

    df = df_all.iloc[:, 4:]
    daily_cases = df.sum(axis=0)
    daily_cases.index = pd.to_datetime(daily_cases.index)
    # print(daily_cases.head())

    plt.figure(figsize=(5, 5))
    plt.plot(daily_cases)
    plt.title("Cumulative daily cases")
    # plt.show()

    # 为了提高样本时间序列的平稳性,继续取一阶差分
    daily_cases = daily_cases.diff().fillna(daily_cases[0]).astype(np.int64)
    # print(daily_cases.head())

    plt.figure(figsize=(5, 5))
    plt.plot(daily_cases)
    plt.title("Daily cases")
    plt.xticks(rotation=60)
    plt.show()
    return daily_cases

def create_sequences(data, seq_length):
    xs = []
    ys = []
    for i in range(len(data) - seq_length + 1):
        x = data[i:i + seq_length - 1]
        y = data[i + seq_length - 1]
        xs.append(x)
        ys.append(y)
    return np.array(xs), np.array(ys)

def preprocess_data(daily_cases):
    TEST_DATA_SIZE,SEQ_LEN = 30,10
    TEST_DATA_SIZE = int(TEST_DATA_SIZE/100*len(daily_cases))
    # TEST_DATA_SIZE=30,最后30个数据当成测试集,进行预测
    train_data = daily_cases[:-TEST_DATA_SIZE]
    test_data = daily_cases[-TEST_DATA_SIZE:]
    print("The number of the samples in train set is : %i" % train_data.shape[0])
    print(train_data.shape, test_data.shape)

    # 为了提升模型收敛速度与性能,我们使用scikit-learn进行数据归一化。
    scaler = MinMaxScaler()
    train_data = scaler.fit_transform(np.expand_dims(train_data, axis=1)).astype('float32')
    test_data = scaler.transform(np.expand_dims(test_data, axis=1)).astype('float32')

    # 搭建时间序列
    # 可以用前10天的病例数预测当天的病例数,为了让测试集中的所有数据都能参与预测,我们将向测试集补充少量数据,这部分数据只会作为模型的输入。
    x_train, y_train = create_sequences(train_data, SEQ_LEN)
    test_data = np.concatenate((train_data[-SEQ_LEN + 1:], test_data), axis=0)
    x_test, y_test = create_sequences(test_data, SEQ_LEN)

    # 尝试输出
    '''
    print("The shape of x_train is: %s"%str(x_train.shape))
    print("The shape of y_train is: %s"%str(y_train.shape))
    print("The shape of x_test is: %s"%str(x_test.shape))
    print("The shape of y_test is: %s"%str(y_test.shape))
    '''
    return x_train,y_train,x_test,y_test,scaler

# 数据集处理完毕,将数据集封装到CovidDataset,以便模型训练、预测时调用。
class CovidDataset(paddle.io.Dataset):
    def __init__(self, feature, label):
        self.feature = feature
        self.label = label
        super(CovidDataset, self).__init__()

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        return [self.feature[index], self.label[index]]

def parameter():
    LR = 1e-2

    model = paddle.Model(network)

    optimizer = paddle.optimizer.Adam(
        learning_rate=LR, parameters=model.parameters())

    loss = paddle.nn.MSELoss(reduction='sum')
    model.prepare(optimizer, loss)



config_mtp()
data = read_data()
x_train,y_train,x_test,y_test,scaler = preprocess_data(data)
train_dataset = CovidDataset(x_train, y_train)
test_dataset = CovidDataset(x_test, y_test)
network = TimeSeriesNetwork(input_size=1)

# 参数配置
LR = 1e-2

model = paddle.Model(network)

optimizer = paddle.optimizer.Adam(learning_rate=LR, parameters=model.parameters()) # 优化器

loss = paddle.nn.MSELoss(reduction='sum')
model.prepare(optimizer, loss) # Configures the model before runing,运行前配置模型

# 训练
USE_GPU = False
TRAIN_EPOCH = 100
LOG_FREQ = 20
SAVE_DIR = os.path.join(os.getcwd(),"save_dir")
SAVE_FREQ = 20

if USE_GPU:
    paddle.set_device("gpu")
else:
    paddle.set_device("cpu")

model.fit(train_dataset,
    batch_size=32,
    drop_last=True,
    epochs=TRAIN_EPOCH,
    log_freq=LOG_FREQ,
    save_dir=SAVE_DIR,
    save_freq=SAVE_FREQ,
    verbose=1 # The verbosity mode, should be 0, 1, or 2.   0 = silent, 1 = progress bar, 2 = one line per epoch. Default: 2.
    )




# 预测
preds = model.predict(
        test_data=test_dataset
        )

# 数据后处理,将归一化的数据转化为原数据,画出真实值对应的曲线和预测值对应的曲线。
true_cases = scaler.inverse_transform(
    np.expand_dims(y_test.flatten(), axis=0)
).flatten()

predicted_cases = scaler.inverse_transform(
  np.expand_dims(np.array(preds).flatten(), axis=0)
).flatten()
print(true_cases.shape, predicted_cases.shape)
# print (type(data))
# print(data[1:3])
# print (len(data), len(data))
# print(data.index[:len(data)])
mse_loss = paddle.nn.MSELoss(reduction='mean')
print(paddle.sqrt(mse_loss(paddle.to_tensor(true_cases), paddle.to_tensor(predicted_cases))))

print(true_cases, predicted_cases)

如果需要数据欢迎下方评论,同时也可以私信获取。

千万不要忘了点赞、评论、收藏,对我真的很重要偶~

您可能感兴趣的与本文相关的镜像

PaddlePaddle-v3.3

PaddlePaddle-v3.3

PaddlePaddle

PaddlePaddle是由百度自主研发的深度学习平台,自 2016 年开源以来已广泛应用于工业界。作为一个全面的深度学习生态系统,它提供了核心框架、模型库、开发工具包等完整解决方案。目前已服务超过 2185 万开发者,67 万企业,产生了 110 万个模型

<think>我们正在处理一个关于Android窗口管理器和屏幕旋转的问题。用户提到了一个特定的问题:"Android WindowManagerShell Transition requested Finish fixed rotation transform issue"。这似乎涉及屏幕旋转时窗口转换的问题。 根据用户提供的引用(尽管是游戏开发中的修复),我们可以类比: 引用[1]提到了属性修改器没有正确复制到客户端的问题,这可能导致客户端显示不正确的值。类似地,在Android窗口旋转中,可能涉及状态同步问题。 引用[2]提到了支持连续的欧拉角变化,避免动画翻转。在屏幕旋转中,也可能涉及旋转动画的平滑过渡。 因此,我们推测用户的问题可能是:在屏幕旋转时,窗口转换(Transition)过程中出现了问题,特别是在请求完成固定旋转变换(fixed rotation transform)时,可能出现了显示错误或过渡不流畅。 在Android系统中,屏幕旋转涉及窗口的重新配置(如尺寸变化、方向变化)。WindowManager(窗口管理器)负责管理窗口的布局和转换。WindowManagerShell(WM Shell)是WindowManager的扩展,用于处理多窗口模式(如分屏、自由窗口)和过渡动画。 "fixed rotation transform" 可能指的是在屏幕旋转过程中,系统会暂时固定窗口的旋转状态,以确保旋转过程中的流畅性。当旋转完成时,系统需要正确结束这个固定旋转状态,并应用最终的变换。 问题可能出现在这个结束固定旋转变换的过程中,导致窗口显示不正确(如布局错误、黑屏、闪烁)或过渡动画异常。 **解决方案思路:** 1. **检查旋转转换流程**:确保在屏幕旋转时,系统正确执行了以下步骤: - 保存当前窗口状态。 - 应用临时的旋转变换(fixed rotation transform)以保持窗口在旋转过程中的显示。 - 在旋转完成后,移除临时变换并应用最终的窗口配置。 2. **状态同步问题**:类似于引用[1]中的问题,可能是在转换过程中,窗口的状态(如旋转角度、位置、尺寸)没有正确同步到渲染线程或客户端,导致显示错误。 3. **动画平滑过渡**:参考引用[2],确保旋转角度变化是连续的,避免在180度附近发生翻转(即避免从179度直接跳到-180度,导致动画翻转)。在Android中,可以使用四元数或优化旋转插值来实现平滑过渡。 **具体步骤:** 1. **确认问题发生的场景**:在哪些情况下会出现这个问题?例如,特定应用、特定旋转方向(0°、90°、180°、270°)或快速旋转。 2. **检查WindowManagerService和WM Shell的日志**:通过`adb logcat`查看相关日志(过滤`WindowManager`、`WindowManagerShell`等标签),寻找错误或异常。 3. **分析转换代码**:在Android源码中,与屏幕旋转转换相关的关键类可能包括: - `WindowManagerService`:管理窗口状态和策略。 - `DisplayRotation`:处理显示旋转。 - `WindowContainer`及其子类:管理窗口容器。 - `Transition`相关类(如`WindowManagerShell`中的`Transitions`):处理窗口转换动画。 特别关注`finishFixedRotationTransform`或类似名称的方法,查看其实现逻辑。 4. **修复建议**: - **状态同步**:确保在结束固定旋转变换时,窗口的最终状态(包括旋转角度、位置、尺寸)被正确更新并通知到所有相关组件(如SurfaceFlinger)。 - **动画处理**:在旋转转换动画中,使用连续的旋转角度插值(例如,使用四元数插值或角度插值时避免180度翻转),确保动画平滑。 - **错误处理**:在转换过程中增加健壮性检查,例如检查窗口是否已经被销毁,避免空指针异常。 5. **测试**:编写测试用例模拟快速旋转、不同方向旋转等场景,验证修复效果。 **代码示例(概念性):** 假设在`WindowManagerShell`中有一个处理结束固定旋转变换的方法,我们可以这样修改: ```java // 原始代码可能存在未正确移除变换或状态未更新的问题 void finishFixedRotationTransform(WindowContainer wc) { // 移除临时旋转变换 wc.removeFixedRotationTransform(); // 确保更新窗口状态并重新布局 wc.updateOrientation(); wc.layoutIfNeeded(); // 同步状态到客户端(避免引用[1]中的问题) wc.syncStateToClient(); } ``` 同时,在动画处理中,避免角度翻转: ```java // 在计算旋转动画插值时,使用连续的角度插值 float interpolateRotation(float start, float end, float fraction) { // 处理角度跨越180度边界的情况 float diff = end - start; if (Math.abs(diff) > 180) { end = (diff > 0) ? end - 360 : end + 360; } return start + (end - start) * fraction; } ``` **相关问题:** 1. 如何调试Android窗口管理器相关的显示问题? 2. 屏幕旋转时,Android系统如何管理窗口的生命周期和状态? 3. 如何优化屏幕旋转的动画性能?</think>### Android WindowManagerShell Transition 请求完成固定旋转变换问题分析 在 Android 系统中,`WindowManagerShell` 是窗口管理系统的关键组件,负责处理屏幕旋转时的窗口变换。当遇到 "Transition requested Finish fixed rotation transform" 相关问题时,通常涉及屏幕旋转过程中窗口状态转换的异常处理。以下是系统化的解决方案: #### 核心问题原因 1. **旋转状态同步失败** 屏幕旋转时,`WindowManagerService` 需同步更新所有窗口的旋转状态。若 `finishFixedRotationTransform()` 方法被过早调用,会导致: - 窗口变换未完成时旋转状态被清除 - 客户端收到错误的位置/尺寸信息(类似引用[1]中的属性复制问题) - 出现黑屏、画面撕裂或布局错位 2. **动画连续性中断** 旋转动画需要连续的角度变换(如引用[2]所述)。当系统强制限制在 $[-180°, 180°]$ 范围时,会导致: - 角度跳变(例如从 $179°$ 直接跳到 $-180°$) - 动画翻转(flip)现象 - 过渡效果不连贯 3. **窗口生命周期冲突** 当 Activity 在旋转过程中重建时: ```mermaid graph LR A[屏幕旋转事件] --> B[WindowManager 冻结窗口] B --> C[Activity 销毁重建] C --> D[新窗口创建] D --> E[应用旋转变换] E -->|冲突| F[finishFixedRotation 被提前调用] ``` #### 解决方案 1. **修复旋转状态同步** 修改 `WindowManagerService` 中的状态同步逻辑: ```java // 确保在完成旋转前验证所有窗口状态 void finishFixedRotationTransform(WindowToken token) { if (!token.isRotationFinished()) { // 添加状态检查逻辑 for (WindowState w : token.getWindows()) { if (w.isInFixedRotation()) { w.finishSeamlessRotation(); } } // 延迟清除旋转标志 mHandler.postDelayed(() -> { token.clearFixedRotationTransform(); }, 300); // 确保动画完成 } } ``` 2. **优化旋转动画连续性** 应用连续角度变换模型(参考引用[2]): $$ \theta_{\text{new}} = \begin{cases} \theta_{\text{current}} + 360^\circ & \text{if } \Delta\theta > 180^\circ \\ \theta_{\text{current}} - 360^\circ & \text{if } \Delta\theta < -180^\circ \\ \theta_{\text{target}} & \text{otherwise} \end{cases} $$ - 避免角度跳变 - 使用四元数插值代替欧拉角 3. **生命周期冲突处理** 在 `ActivityRecord.java` 中添加防护逻辑: ```java void onConfigurationChanged(Configuration newConfig) { if (isFixedRotationActive() && !isRotationComplete()) { // 延迟配置变更直到旋转完成 mDeferredConfig = newConfig; return; } super.onConfigurationChanged(newConfig); } ``` #### 验证方法 1. **日志分析** 过滤关键日志标签: ```bash adb logcat -s WindowManager:V WindowManagerShell:V Rotation:V ``` 检查以下关键事件序列: - `applyRotation: start` - `finishFixedRotation requested` - `transform complete` 2. **测试用例** ```java // 模拟快速旋转冲突 for (int i = 0; i < 5; i++) { device.setRotation(ROTATION_90); device.setRotation(ROTATION_0); // 验证窗口无残留旋转状态 assertFalse(windowState.isInFixedRotation()); } ``` #### 临时规避措施 若需紧急修复: ```xml <!-- AndroidManifest.xml --> <activity android:name=".MainActivity" android:configChanges="orientation|screenSize" android:screenOrientation="fullUser"/> ``` > **注意**:此方案禁用系统旋转处理,可能影响多窗口兼容性 ### 相关问题 1. 如何诊断 Android 屏幕旋转过程中的窗口层级状态? 2. WindowManagerService 如何处理多窗口模式下的旋转同步? 3. 在自定义窗口动画时如何避免与系统旋转变换冲突? 4. Android 13 的窗口边衬区(insets)计算在旋转时有哪些变化?
评论 289
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值