Building a Vision Transformer with TensorFlow 2

model (vit_model.py)

"""
refer to:
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""
import tensorflow as tf
from tensorflow.keras import Model, layers, initializers
import numpy as np


class PatchEmbed(layers.Layer):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        super(PatchEmbed, self).__init__()
        self.embed_dim = embed_dim
        self.img_size = (img_size, img_size)
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = layers.Conv2D(filters=embed_dim, kernel_size=patch_size,
                                  strides=patch_size, padding='SAME',
                                  kernel_initializer=initializers.LecunNormal(),
                                  bias_initializer=initializers.Zeros())

    def call(self, inputs, **kwargs):
        B, H, W, C = inputs.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(inputs)
        # [B, H, W, C] -> [B, H*W, C]
        x = tf.reshape(x, [B, self.num_patches, self.embed_dim])
        return x


class ConcatClassTokenAddPosEmbed(layers.Layer):
    def __init__(self, embed_dim=768, num_patches=196, name=None):
        super(ConcatClassTokenAddPosEmbed, self).__init__(name=name)
        self.embed_dim = embed_dim
        self.num_patches = num_patches

    def build(self, input_shape):
        self.cls_token = self.add_weight(name="cls",
                                         shape=[1, 1, self.embed_dim],
                                         initializer=initializers.Zeros(),
                                         trainable=True,
                                         dtype=tf.float32)
        self.pos_embed = self.add_weight(name="pos_embed",
                                         shape=[1, self.num_patches + 1, self.embed_dim],
                                         initializer=initializers.RandomNormal(stddev=0.02),
                                         trainable=True,
                                         dtype=tf.float32)

    def call(self, inputs, **kwargs):
        batch_size, _, _ = inputs.shape

        # [1, 1, 768] -> [B, 1, 768]
        cls_token = tf.broadcast_to(self.cls_token, shape=[batch_size, 1, self.embed_dim])
        x = tf.concat([cls_token, inputs], axis=1)  # [B, 197, 768]
        x = x + self.pos_embed

        return x


class Attention(layers.Layer):
    k_ini = initializers.GlorotUniform()
    b_ini = initializers.Zeros()

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.,
                 name=None):
        super(Attention, self).__init__(name=name)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias, name="qkv",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.attn_drop = layers.Dropout(attn_drop_ratio)
        self.proj = layers.Dense(dim, name="out",
                                 kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.proj_drop = layers.Dropout(proj_drop_ratio)

    def call(self, inputs, training=None):
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = inputs.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        qkv = self.qkv(inputs)
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        qkv = tf.reshape(qkv, [B, N, 3, self.num_heads, C // self.num_heads])
        # transpose: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = tf.matmul(a=q, b=k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        # multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        x = tf.matmul(attn, v)
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        x = tf.transpose(x, [0, 2, 1, 3])
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = tf.reshape(x, [B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x


class MLP(layers.Layer):
    """
    MLP as used in Vision Transformer, MLP-Mixer and related networks
    """

    k_ini = initializers.GlorotUniform()
    b_ini = initializers.RandomNormal(stddev=1e-6)

    def __init__(self, in_features, mlp_ratio=4.0, drop=0., name=None):
        super(MLP, self).__init__(name=name)
        self.fc1 = layers.Dense(int(in_features * mlp_ratio), name="Dense_0",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.act = layers.Activation("gelu")
        self.fc2 = layers.Dense(in_features, name="Dense_1",
                                kernel_initializer=self.k_ini, bias_initializer=self.b_ini)
        self.drop = layers.Dropout(drop)

    def call(self, inputs, training=None):
        x = self.fc1(inputs)
        x = self.act(x)
        x = self.drop(x, training=training)
        x = self.fc2(x)
        x = self.drop(x, training=training)
        return x


class Block(layers.Layer):
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_ratio=0.,
                 attn_drop_ratio=0.,
                 drop_path_ratio=0.,
                 name=None):
        super(Block, self).__init__(name=name)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_0")
        self.attn = Attention(dim, num_heads=num_heads,
                              qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio,
                              name="MultiHeadAttention")
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = layers.Dropout(rate=drop_path_ratio, noise_shape=(None, 1, 1)) if drop_path_ratio > 0. \
            else layers.Activation("linear")
        self.norm2 = layers.LayerNormalization(epsilon=1e-6, name="LayerNorm_1")
        self.mlp = MLP(dim, drop=drop_ratio, name="MlpBlock")

    def call(self, inputs, training=None):
        x = inputs + self.drop_path(self.attn(self.norm1(inputs)), training=training)
        x = x + self.drop_path(self.mlp(self.norm2(x)), training=training)
        return x


class VisionTransformer(Model):
    def __init__(self, img_size=224, patch_size=16, embed_dim=768,
                 depth=12, num_heads=12, qkv_bias=True, qk_scale=None,
                 drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0.,
                 representation_size=None, num_classes=1000, name="ViT-B/16"):
        super(VisionTransformer, self).__init__(name=name)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.depth = depth
        self.qkv_bias = qkv_bias

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token_pos_embed = ConcatClassTokenAddPosEmbed(embed_dim=embed_dim,
                                                               num_patches=num_patches,
                                                               name="cls_pos")

        self.pos_drop = layers.Dropout(drop_ratio)

        dpr = np.linspace(0., drop_path_ratio, depth)  # stochastic depth decay rule
        self.blocks = [Block(dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
                             qk_scale=qk_scale, drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio,
                             drop_path_ratio=dpr[i], name="encoderblock_{}".format(i))
                       for i in range(depth)]

        self.norm = layers.LayerNormalization(epsilon=1e-6, name="encoder_norm")

        if representation_size:
            self.has_logits = True
            self.pre_logits = layers.Dense(representation_size, activation="tanh", name="pre_logits")
        else:
            self.has_logits = False
            self.pre_logits = layers.Activation("linear")

        self.head = layers.Dense(num_classes, name="head", kernel_initializer=initializers.Zeros())

    def call(self, inputs, training=None):
        # [B, H, W, C] -> [B, num_patches, embed_dim]
        x = self.patch_embed(inputs)  # [B, 196, 768]
        x = self.cls_token_pos_embed(x)  # [B, 197, 768]
        x = self.pos_drop(x, training=training)

        for block in self.blocks:
            x = block(x, training=training)

        x = self.norm(x)
        x = self.pre_logits(x[:, 0])
        x = self.head(x)

        return x


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes,
                              name="ViT-B_16")
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes,
                              name="ViT-B_32")
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes,
                              name="ViT-L_16")
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes,
                              name="ViT-L_32")
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes,
                              name="ViT-H_14")
    return model
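
Before moving on to training, a quick smoke test can confirm that the factory functions above produce the expected output shape. This is a minimal sketch; it assumes the code above is saved as `vit_model.py`, which matches the import used in the training script below:

```python
# Minimal sanity check for the model defined above (saved as vit_model.py).
import tensorflow as tf
from vit_model import vit_base_patch16_224_in21k

model = vit_base_patch16_224_in21k(num_classes=10, has_logits=False)
model.build((1, 224, 224, 3))   # the reshapes above rely on a static batch dimension
dummy = tf.random.normal([1, 224, 224, 3])
logits = model(dummy, training=False)
print(logits.shape)             # expected: (1, 10)
```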

main (training script)

import os
import re
import sys
import math
import datetime
import cv2 as cv
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from random import shuffle
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import generate_ds

assert tuple(int(v) for v in tf.version.VERSION.split(".")[:2]) >= (2, 4), \
    "TensorFlow version must be greater than or equal to 2.4.0"


def main():
    data_root = "./data/flower_photos"  # get data root path

    if not os.path.exists("./save_weights"):
        os.makedirs("./save_weights")

    batch_size = 32
    epochs = 1000
    num_classes = 10
    freeze_layers = True
    initial_lr = 0.001
    weight_decay = 1e-4

    log_dir = "./logs/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    train_writer = tf.summary.create_file_writer(os.path.join(log_dir, "train"))
    val_writer = tf.summary.create_file_writer(os.path.join(log_dir, "val"))

    # original data generator with augmentation (replaced below by a custom pipeline)
    # train_ds, val_ds = generate_ds(data_root, batch_size=batch_size, val_rate=0.2)
    print('----------------------------- preparing the dataset -----------------------------')
    name_dict = {"BF": 0, "BK": 1, "BL": 2, "BR": 3, "CF": 4, "CL": 5, "CV": 6, "CXK": 7, "S": 8, "XF": 9}

    data_root_path = "C:/my_all_data_download/ZCB/color_part_data_processing/"
    test_file_path = "C:/my_all_data_download/ZCB/TXT_doc/test.txt"  # test-set list file
    trainer_file_path = "C:/my_all_data_download/ZCB/TXT_doc/trainer.txt"  # training-set list file

    name_data_list = {}  # maps each class name to the list of its image paths

    trainer_list = []
    test_list = []

    # store each image's full path in the dictionary, keyed by class name
    def save_train_test_file(path, name):
        if name not in name_data_list:
            img_list = []
            img_list.append(path)
            name_data_list[name] = img_list
        else:
            name_data_list[name].append(path)

    # walk the dataset directory and collect image paths (split into training / test sets below)
    dirs = os.listdir(data_root_path)
    for d in dirs:
        full_path = data_root_path + d
        if os.path.isdir(full_path):
            imgs = os.listdir(full_path)  # list all images in this subdirectory
            for img in imgs:
                save_train_test_file(full_path + "/" + img, d)

    # truncate the test / training list files; their contents are written below
    with open(test_file_path, "w") as f:  # clear the test list file
        pass
    with open(trainer_file_path, "w") as f:  # clear the training list file
        pass

    # iterate over the dictionary and split the data: every 10th image goes to the test set
    for name, img_list in name_data_list.items():
        i = 0
        num = len(img_list)
        print(f"{name}: {num} images")
        for img in img_list:
            if i % 10 == 0:
                test_list.append(f"{img}\t{name_dict[name]}\n")
            else:
                trainer_list.append(f"{img}\t{name_dict[name]}\n")
            i += 1
    with open(trainer_file_path, "w") as f:
        shuffle(trainer_list)
        f.writelines(trainer_list)

    with open(test_file_path, "w") as f:
        f.writelines(test_list)

    print("---------------------------------------------------之前的代码主要是生成.txt文件便于找到图片和对应的标签-------------------------------------------------")

    def generateds(train_list):
        x, y_ = [], []  # x holds the image data, y_ the labels
        with open(train_list, 'r') as f:
            # read all lines, stripping leading/trailing whitespace
            lines = [line.strip() for line in f]
            for line in lines:
                img_path, lab = line.strip().split("\t")
                img = cv.imread(img_path)  # read the image (OpenCV loads in BGR order)
                img = cv.resize(img, (224, 224))  # resize to the ViT input size
                # img = np.array(img.convert('L'))  # optional: convert to an 8-bit grayscale array (unused)
                img = img / 255  # normalize pixel values to [0, 1]
                x.append(img)  # append the normalized image
                y_.append(lab)

        x = np.array(x)
        y_ = np.array(y_)
        y_ = y_.astype(np.int64)
        return x, y_

    x_train, y_train = generateds(trainer_file_path)
    x_test, y_test = generateds(test_file_path)
    x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)
    y_train = tf.convert_to_tensor(y_train, dtype=tf.int32)
    x_test = tf.convert_to_tensor(x_test, dtype=tf.float32)
    y_test = tf.convert_to_tensor(y_test, dtype=tf.int32)
    train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))  # build the training dataset object
    train_ds = train_dataset.batch(32)  # batch the training set with batch size 32
    test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
    val_ds = test_dataset.batch(32)
    print('----------------------------- dataset ready -----------------------------')
    # create model
    model = create_model(num_classes=num_classes, has_logits=False)
    model.build((32, 224, 224, 3))

    # download the pretrained weights that I converted in advance
    # link: https://pan.baidu.com/s/1ro-6bebc8zroYfupn-7jVQ  password: s9d9
    # load weights
    # pre_weights_path = './ViT-B_16.h5'
    # assert os.path.exists(pre_weights_path), "cannot find {}".format(pre_weights_path)
    # model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)

    # freeze bottom layers
    if freeze_layers:
        for layer in model.layers:
            if "pre_logits" not in layer.name and "head" not in layer.name:
                layer.trainable = False
            else:
                print("training {}".format(layer.name))

    model.summary()

    # custom learning rate curve
    def scheduler(now_epoch):
        end_lr_rate = 0.01  # end_lr = initial_lr * end_lr_rate
        rate = ((1 + math.cos(now_epoch * math.pi / epochs)) / 2) * (1 - end_lr_rate) + end_lr_rate  # cosine
        new_lr = rate * initial_lr

        # writing lr into tensorboard
        with train_writer.as_default():
            tf.summary.scalar('learning rate', data=new_lr, step=now_epoch)

        return new_lr

    # using keras low level api for training
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    optimizer = tf.keras.optimizers.SGD(learning_rate=initial_lr, momentum=0.9)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

    @tf.function
    def train_step(train_images, train_labels):
        with tf.GradientTape() as tape:
            output = model(train_images, training=True)
            # cross entropy loss
            ce_loss = loss_object(train_labels, output)

            # l2 loss
            matcher = re.compile(".*(bias|gamma|beta).*")
            l2loss = weight_decay * tf.add_n([
                tf.nn.l2_loss(v)
                for v in model.trainable_variables
                if not matcher.match(v.name)
            ])

            loss = ce_loss + l2loss

        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(ce_loss)
        train_accuracy(train_labels, output)

    @tf.function
    def val_step(val_images, val_labels):
        output = model(val_images, training=False)
        loss = loss_object(val_labels, output)

        val_loss(loss)
        val_accuracy(val_labels, output)

    best_val_acc = 0.
    for epoch in range(epochs):
        train_loss.reset_states()  # clear history info
        train_accuracy.reset_states()  # clear history info
        val_loss.reset_states()  # clear history info
        val_accuracy.reset_states()  # clear history info

        # train
        train_bar = tqdm(train_ds, file=sys.stdout)
        for images, labels in train_bar:
            train_step(images, labels)

            # print train process
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                                 epochs,
                                                                                 train_loss.result(),
                                                                                 train_accuracy.result())

        # update learning rate
        optimizer.learning_rate = scheduler(epoch)

        # validate
        val_bar = tqdm(val_ds, file=sys.stdout)
        for images, labels in val_bar:
            val_step(images, labels)

            # print val process
            val_bar.desc = "valid epoch[{}/{}] loss:{:.3f}, acc:{:.3f}".format(epoch + 1,
                                                                               epochs,
                                                                               val_loss.result(),
                                                                               val_accuracy.result())
        # writing training loss and acc
        with train_writer.as_default():
            tf.summary.scalar("loss", train_loss.result(), epoch)
            tf.summary.scalar("accuracy", train_accuracy.result(), epoch)

        # writing validation loss and acc
        with val_writer.as_default():
            tf.summary.scalar("loss", val_loss.result(), epoch)
            tf.summary.scalar("accuracy", val_accuracy.result(), epoch)

        # only save best weights
        if val_accuracy.result() > best_val_acc:
            best_val_acc = val_accuracy.result()
            save_name = "./save_weights/model.ckpt"
            model.save_weights(save_name, save_format="tf")


if __name__ == '__main__':
    main()
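
For completeness, here is a minimal single-image inference sketch using the checkpoint saved by `main()`. The image path is hypothetical, and the preprocessing mirrors `generateds()` above:

```python
# Hedged sketch: single-image inference with the checkpoint saved by main()
# (paths and class order follow the training script above; adjust for your setup).
import cv2 as cv
import numpy as np
import tensorflow as tf
from vit_model import vit_base_patch16_224_in21k as create_model

classes = ["BF", "BK", "BL", "BR", "CF", "CL", "CV", "CXK", "S", "XF"]  # same order as name_dict

model = create_model(num_classes=len(classes), has_logits=False)
model.build((1, 224, 224, 3))
model.load_weights("./save_weights/model.ckpt")

img = cv.imread("test.jpg")              # hypothetical image path
img = cv.resize(img, (224, 224)) / 255   # same preprocessing as generateds()
logits = model(img[np.newaxis].astype("float32"), training=False)
probs = tf.nn.softmax(logits, axis=-1).numpy()[0]
print(classes[int(probs.argmax())], float(probs.max()))
```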

The accuracy on the test set is 75%.

### Combining CNN and Vision Transformer in MATLAB

Approaches that fuse convolutional neural networks (CNNs) with Vision Transformers (ViT) have already made notable progress in image processing[^1]. However, the mainstream ViT frameworks are built primarily on PyTorch or TensorFlow in a Python environment, and support in MATLAB is comparatively limited.

Even so, the official MathWorks documentation describes how to integrate custom deep learning layers and how to import externally pretrained models[^2]. Researchers who want to explore CNN-ViT architectures in MATLAB can consider the following two approaches:

#### Approach 1: Import a pretrained model through an ONNX converter

One feasible route is to export an existing hybrid model from another platform and migrate it into MATLAB. Concretely, build the desired CNN-ViT structure in Python, save it as an ONNX file, and then load it into the MATLAB workspace with the `importONNXLayers` function for further tuning and optimization.

```matlab
% Import the ONNX model into MATLAB
layers = importONNXLayers('cnn_vit.onnx');
analyzeNetwork(layers);
```

This approach lets users reuse existing resources while still benefiting from MATLAB's extensive toolboxes.

#### Approach 2: Build a customized deep network by hand

A more flexible alternative is to assemble the whole network topology step by step inside MATLAB. This involves more coding, but it gives the developer greater freedom to adjust the configuration in detail. The simplified snippet below illustrates the idea:

```matlab
function dlnet = createCnnVitNet()
    % define the input size
    inputSize = [224 224 3];

    % build the CNN stem
    layers_cnn = [
        imageInputLayer(inputSize,'Normalization','none')
        convolution2dLayer(7,64,'Stride',2,'Padding','same')
        batchNormalizationLayer
        reluLayer
        maxPooling2dLayer(3,'Stride',2)];

    % split the feature map into a patch sequence as the ViT input
    patchSize = [16 16];
    numPatchesPerDim = floor(inputSize(1:2)./patchSize);

    % placeholder: add position embeddings and multi-head attention to form the ViT blocks
    % transformerLayers = [ ... ];

    % connect the two parts into the final architecture (layer names are placeholders)
    lgraph = layerGraph(layers_cnn);
    lgraph = connectLayers(lgraph, 'lastConvOut', 'firstTransIn');
    dlnet = dlnetwork(lgraph);
end
```

Note that the example above only sketches the conceptual design of the framework; in a real application, each component still needs a concrete implementation tailored to the specific task.
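
To complement Approach 1, here is a minimal, untested sketch of the Python-side export. It assumes the `tf2onnx` package is installed and reuses the ViT factory from the model section above; the output file name `cnn_vit.onnx` matches the MATLAB snippet:

```python
# Hedged sketch: export the TensorFlow ViT defined earlier to ONNX so that MATLAB
# can load it with importONNXLayers('cnn_vit.onnx'). Assumes `pip install tf2onnx`.
import tensorflow as tf
import tf2onnx
from vit_model import vit_base_patch16_224_in21k

model = vit_base_patch16_224_in21k(num_classes=10, has_logits=False)
model.build((1, 224, 224, 3))

# The model code above reshapes using the static batch dimension, so a fixed
# batch size of 1 is declared in the input signature.
spec = (tf.TensorSpec((1, 224, 224, 3), tf.float32, name="input"),)
tf2onnx.convert.from_keras(model, input_signature=spec,
                           opset=13, output_path="cnn_vit.onnx")
```

Whether every op converts cleanly depends on the tf2onnx and opset versions, so the exported graph should still be verified with `analyzeNetwork` on the MATLAB side.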