
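This section covers the code implementation of the PatchEmbedding part of the VITGAN generative adversarial network.
Contents (links will be added once the articles are published):
- Network architecture overview
- Mapping Network implementation
- PositionalEmbedding implementation
- MLP implementation
- MSA (multi-head self-attention) implementation
- SLN (self-modulated LayerNorm) implementation
- CoordinatesPositionalEmbedding implementation
- ModulatedLinear implementation
- Siren implementation
- Generator implementation
- PatchEmbedding implementation
- ISN implementation
- Discriminator implementation
- VITGAN implementation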
PatchEmbedding overview

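PatchEmbedding is the input stage of the discriminator. It splits the image into patches and widens each patch at its edges by an overlap of width $o$, so that each patch also covers part of its neighbors and can carry some positional information.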
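To make the boundary rule concrete, here is a minimal pure-Python sketch of the one-axis logic used by `create_indexes` in the code below (the 2D index is just the product of the row and column bounds). The function name `overlapping_bounds` is made up for illustration; it assumes at least two patches per side.

```python
def overlapping_bounds(image_size, patch_size, o):
    """Hypothetical helper: (start, end) of every patch along one axis."""
    grid = image_size // patch_size
    bounds = []
    for i in range(grid):
        start, end = i * patch_size, (i + 1) * patch_size
        if i == 0:                 # first patch: whole 2*o extension goes inwards
            end += 2 * o
        elif i == grid - 1:        # last patch: whole 2*o extension goes inwards
            start -= 2 * o
        else:                      # interior patch: o pixels on each side
            start -= o
            end += o
        bounds.append((start, end))
    return bounds


bounds = overlapping_bounds(224, 16, 3)
print(bounds[:3])   # [(0, 22), (13, 35), (29, 51)]
print(bounds[-1])   # (202, 224)
```

Every patch is $P + 2o = 22$ pixels wide and stays inside the image, with no padding needed. The price is that the first and last patch overlap their neighbor by $3o = 9$ pixels, while interior neighbors overlap by $2o = 6$ pixels.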
Code implementation
```python
import tensorflow as tf


class PatchEmbedding(tf.keras.layers.Layer):
    """
    2D Image to Patch Embedding
    Cuts the image into grid_size x grid_size overlapping patches of side
    patch_size + 2 * overlapping, flattens each patch and projects it to emb_dim.
    """
    def __init__(
        self,
        image_size=224,
        patch_size=16,
        overlapping=3,  # overlap width o (pixels)
        emb_dim=768,
        discriminator=True,  # stored only; not used inside this layer
    ):
        super().__init__()
        assert image_size % patch_size == 0
        self.image_size = image_size
        self.patch_size = patch_size
        self.overlapping = overlapping
        self.emb_dim = emb_dim
        self.discriminator = discriminator
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size ** 2
        # Bias-free linear projection of each flattened patch
        self.proj = tf.keras.layers.Dense(emb_dim, use_bias=False)
        self.create_indexes()

    def create_indexes(self):
        '''Precompute the slice indexes of every patch.'''
        self.all_indexes = []
        for y in range(self.grid_size):
            for x in range(self.grid_size):
                now_y_start = y * self.patch_size
                now_y_end = (y + 1) * self.patch_size
                now_x_start = x * self.patch_size
                now_x_end = (x + 1) * self.patch_size
                # Add the overlap; at the image border the whole 2*o goes to the inner side
                if y == 0:
                    now_y_end += 2 * self.overlapping
                elif y == self.grid_size - 1:
                    now_y_start -= 2 * self.overlapping
                else:
                    now_y_start -= self.overlapping
                    now_y_end += self.overlapping
                if x == 0:
                    now_x_end += 2 * self.overlapping
                elif x == self.grid_size - 1:
                    now_x_start -= 2 * self.overlapping
                else:
                    now_x_start -= self.overlapping
                    now_x_end += self.overlapping
                self.all_indexes.append(
                    (now_y_start, now_y_end, now_x_start, now_x_end)
                )
        # print('all_indexes:', self.all_indexes)

    def extract_patches(self, x):
        '''Return flattened patches: [B, num_patches, (P+2o)*(P+2o)*C].'''
        batch = tf.shape(x)[0]
        patch_list = []
        for now_y_start, now_y_end, now_x_start, now_x_end in self.all_indexes:
            patch_x = x[:, now_y_start:now_y_end, now_x_start:now_x_end, :]
            patch_list.append(tf.reshape(patch_x, [batch, 1, -1]))
        return tf.concat(patch_list, axis=1)

    def call(self, x):
        x = self.extract_patches(x)
        # [B, num_patches, emb_dim]
        return self.proj(x)


if __name__ == "__main__":
    layer = PatchEmbedding(
        image_size=224,
        patch_size=16,
        overlapping=3,
        emb_dim=768,
    )
    # x = tf.random.uniform([2, 224, 224, 3], dtype=tf.float32)
    x = tf.io.read_file('./test.jpg')
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.expand_dims(x, axis=0)
    x = tf.image.crop_and_resize(x, [[0., 0., 1., 1.]], [0], [224, 224])
    o1 = layer(x)
    tf.print('o1:', tf.shape(o1))  # [1 196 768]
    # Visualize the raw (unprojected) patches: 14 x 14 patches of 22 x 22
    # pixels (16 + 2*3) tiled back into one 308 x 308 image
    p = layer.extract_patches(x)
    p = tf.reshape(p, [1, 14, 14, 22, 22, 3])
    p = tf.transpose(p, perm=[0, 1, 3, 2, 4, 5])
    p = tf.reshape(p, [1, 308, 308, 3])
    p = tf.image.encode_jpeg(tf.cast(p[0], dtype=tf.uint8))
    tf.io.write_file('./test2.jpg', p)
```
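The Python loop in the layer issues one slice per patch (196 for the default settings). As a cross-check of what the layer computes before the `Dense` projection, the following NumPy sketch builds the same overlapping patches with a single fancy-indexing gather. It assumes square images with at least two patches per side; `overlapping_patches` and the random `kernel` (a stand-in for the learned `Dense` weights) are illustrative and not part of the original code.

```python
import numpy as np


def overlapping_patches(images, patch_size=16, o=3):
    """Hypothetical NumPy equivalent of the patch extraction in PatchEmbedding.

    images: [B, H, W, C] with H == W and H // patch_size >= 2.
    Returns [B, grid * grid, (P + 2o) * (P + 2o) * C], patches in row-major grid order.
    """
    b, h, w, c = images.shape
    assert h == w and h % patch_size == 0
    grid = h // patch_size
    assert grid >= 2
    k = patch_size + 2 * o                        # side length of every patch
    # Start offset of each patch along one axis; border patches are shifted inwards
    starts = np.arange(grid) * patch_size - o
    starts[0] = 0
    starts[-1] = h - k
    idx = starts[:, None] + np.arange(k)          # [grid, k] pixel indices per patch
    # Adjacent advanced indices broadcast to [grid, k, grid, k] -> [B, gy, py, gx, px, C]
    patches = images[:, idx[:, :, None, None], idx[None, None, :, :], :]
    patches = patches.transpose(0, 1, 3, 2, 4, 5)  # [B, gy, gx, py, px, C]
    return patches.reshape(b, grid * grid, k * k * c)


rng = np.random.default_rng(0)
x = rng.random((2, 224, 224, 3), dtype=np.float32)
p = overlapping_patches(x)
print(p.shape)    # (2, 196, 1452)

# Stand-in for the bias-free Dense(768) kernel, applied to the last axis
kernel = rng.standard_normal((1452, 768)).astype(np.float32)
emb = p @ kernel
print(emb.shape)  # (2, 196, 768)
```

Because the `Dense` layer has `use_bias=False` and acts on the last axis, the projection is just this matrix product with a learned `[1452, 768]` kernel, where $1452 = 22 \times 22 \times 3$.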
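This article described the implementation of the PatchEmbedding module in the VITGAN generative adversarial network, including how to build the PatchEmbedding layer in TensorFlow. The layer splits the image into patches and adds overlapping regions to retain positional information.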
