WITRAN_2DPSGMU_Encoder 类中,门机制

WITRAN_2DPSGMU_Encoder 类中的门机制详解

WITRAN_2DPSGMU_Encoder 类中,门机制是核心部分,类似于 LSTM 或 GRU 的门控机制,用于控制隐藏状态的更新和输出。以下是对门机制的详细解析。


1. 门机制的作用

门机制的主要作用是:

  1. 控制信息流动

    • 决定当前时间步的输入信息和上一时间步的隐藏状态如何结合。
    • 通过更新门、输入门和输出门,选择性地保留或丢弃信息。
  2. 更新隐藏状态

    • 根据门控信号,更新行隐藏状态(hidden_slice_row)和列隐藏状态(hidden_slice_col)。
  3. 生成输出

    • 将更新后的隐藏状态拼接,作为当前时间步的输出。

2. 门机制的组成

门机制由以下几部分组成:

(1) 门控信号的生成
gate = self.linear(torch.cat([hidden_slice_row, hidden_slice_col, a[:, slice, :]], dim=-1), W, B, batch_size, slice, Water2sea_slice_num)
  • 输入

    • hidden_slice_row:行隐藏状态,形状为 [128, 32]
    • hidden_slice_col:列隐藏状态,形状为 [128, 32]
    • a[:, slice, :]:当前时间步的输入,形状为 [128, 11]
    • 拼接后,输入形状为 [128, 75]
  • 线性变换

    • 权重矩阵 W 的形状为 [192, 75]
    • 偏置向量 B 的形状为 [192]
    • 输出 gate 的形状为 [128, 192]

(2) 分割门控信号
sigmod_gate, tanh_gate = torch.split(gate, 4 * self.hidden_size, dim=-1)
  • 分割结果
    • sigmod_gate:前 128 列,用于生成更新门和输出门。
    • tanh_gate:后 64 列,用于生成输入门。

(3) 激活函数处理
  • Sigmoid 激活

    sigmod_gate = torch.sigmoid(sigmod_gate)
    
    • sigmod_gate 的值映射到 [0, 1] 区间。
    • 用于生成更新门和输出门的开关信号。
  • Tanh 激活

    tanh_gate = torch.tanh(tanh_gate)
    
    • tanh_gate 的值映射到 [-1, 1] 区间。
    • 用于生成输入门的候选值。

(4) 分割门控信号为具体的门
update_gate_row, output_gate_row, update_gate_col, output_gate_col = sigmod_gate.chunk(4, dim=-1)
input_gate_row, input_gate_col = tanh_gate.chunk(2, dim=-1)
  • 更新门

    • update_gate_row:更新行隐藏状态的门控信号,形状为 [128, 32]
    • update_gate_col:更新列隐藏状态的门控信号,形状为 [128, 32]
  • 输出门

    • output_gate_row:输出行隐藏状态的门控信号,形状为 [128, 32]
    • output_gate_col:输出列隐藏状态的门控信号,形状为 [128, 32]
  • 输入门

    • input_gate_row:用于更新行隐藏状态的输入门信号,形状为 [128, 32]
    • input_gate_col:用于更新列隐藏状态的输入门信号,形状为 [128, 32]

3. 隐藏状态的更新

(1) 更新行隐藏状态
hidden_slice_row = torch.tanh((1-update_gate_row)*hidden_slice_row + update_gate_row*input_gate_row) * output_gate_row
  • 计算过程

    1. (1 - update_gate_row) * hidden_slice_row
      • 保留上一时间步的行隐藏状态。
    2. update_gate_row * input_gate_row
      • 引入当前时间步的输入信息。
    3. 相加后通过 torch.tanh 激活:
      • 将结果映射到 [-1, 1] 区间。
    4. 乘以 output_gate_row
      • 控制隐藏状态的输出强度。
  • 结果

    • 更新后的行隐藏状态 hidden_slice_row,形状为 [128, 32]

(2) 更新列隐藏状态
hidden_slice_col = torch.tanh((1-update_gate_col)*hidden_slice_col + update_gate_col*input_gate_col) * output_gate_col
  • 计算过程

    • 与更新行隐藏状态的逻辑相同,但作用于列隐藏状态。
  • 结果

    • 更新后的列隐藏状态 hidden_slice_col,形状为 [128, 32]

4. 输出的生成

output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1)
  • 拼接隐藏状态

    • 将行隐藏状态 hidden_slice_row 和列隐藏状态 hidden_slice_col 在最后一个维度上拼接。
    • 输出形状为 [128, 64]
  • 保存输出

    output_all_slice_list.append(output_slice)
    
    • 将当前时间步的输出保存到 output_all_slice_list 中。

5. 门机制的核心逻辑总结

  1. 门控信号的生成

    • 通过线性变换生成门控信号 gate,并分割为更新门、输出门和输入门。
  2. 隐藏状态的更新

    • 使用更新门和输入门结合上一时间步的隐藏状态,生成新的隐藏状态。
    • 使用输出门控制隐藏状态的输出强度。
  3. 输出的生成

    • 将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。

6. 图示化表示

输入:
    hidden_slice_row [128, 32]
    hidden_slice_col [128, 32]
    当前时间步输入 a[:, slice, :] [128, 11]
    拼接后输入 [128, 75]
        ↓
线性变换:
    gate = self.linear(...) → [128, 192]
        ↓
分割 gate:
    sigmod_gate [128, 128] → 更新门、输出门
    tanh_gate [128, 64] → 输入门
        ↓
激活函数:
    Sigmoid → sigmod_gate
    Tanh → tanh_gate
        ↓
分割门控信号:
    更新门:update_gate_row, update_gate_col
    输出门:output_gate_row, output_gate_col
    输入门:input_gate_row, input_gate_col
        ↓
更新隐藏状态:
    hidden_slice_row → 更新行隐藏状态
    hidden_slice_col → 更新列隐藏状态
        ↓
生成输出:
    output_slice = torch.cat([hidden_slice_row, hidden_slice_col], dim=-1) → [128, 64]

7. 总结

  • 更新门:控制上一时间步的隐藏状态保留多少。
  • 输入门:控制当前时间步的输入信息引入多少。
  • 输出门:控制隐藏状态的输出强度。
  • 最终输出:将行隐藏状态和列隐藏状态拼接,作为当前时间步的输出。

这种门机制类似于 LSTM,但针对二维时间序列数据进行了扩展,分别更新行和列的隐藏状态,从而捕获时间序列的复杂依赖关系。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值