该章节介绍VITGAN对抗生成网络中,PositionalEmbedding 部分的代码实现。
目录(文章发布后会补上链接):
- 网络结构简介
- Mapping NetWork 实现
- PositionalEmbedding 实现
- MLP 实现
- MSA多头注意力 实现
- SLN自调制 实现
- CoordinatesPositionalEmbedding 实现
- ModulatedLinear 实现
- Siren 实现
- Generator生成器 实现
- PatchEmbedding 实现
- ISN 实现
- Discriminator鉴别器 实现
- VITGAN 实现
PositionalEmbedding 简介
PositionalEmbedding 就是图中1-N的位置编码,根据论文中描述,位置编号1-N(N为图片块数量),经过全连接层映射,再用sin函数约束值的范围[-1,1]。
代码实现
import tensorflow as tf
class PositionalEmbedding(tf.Module):
"""
输入位置编码
"""
def __init__(
self,
sequence_length,
emb_dim,
name=None,
):
super().__init__(name=name)
self.emb_dim = emb_dim
self.sequence_length = sequence_length
self.pos_emb = tf.keras.layers.Dense(emb_dim, use_bias=False)
self.pos_input = tf.linspace(-1, 1, sequence_length)[tf.newaxis, :, tf.newaxis]
def __call__(self):
x = self.pos_emb(self.pos_input)
x = tf.math.sin(x)
return x
if __name__ == "__main__":
layer = PositionalEmbedding(
sequence_length=196,
emb_dim=768
)
o1 = layer()
tf.print('o1:', tf.shape(o1))