手撕CLIP -- Day4 -- text_encoder.py

手撕CLIP – Day4 – text_encoder.py

Contrastive Language-Image Pre-Training (CLIP) 模型原理

CLIP 网络结构图

在这里插入图片描述

CLIP 网络结构

text_encoder(文本编码器)代码

Part1 引入库函数

# 该模块主要是为了实现text编码,但是因为只有十类,所以值需要用nn.emdding进行初始化就行

'''
# Part1 引入相关的库函数
'''
import torch
from torch import nn

Part2 初始化一个文本编码器的类

'''
# Part2 设计下文本编码器的类
'''
class TextEncoder(nn.Module):
    def __init__(self,voca_size,emd_size=16,f_size=64,final_emd_size=8):
        super().__init__()

        # 首先需要初始化嵌入的类别和维度
        self.emd=nn.Embedding(num_embeddings=voca_size,embedding_dim=emd_size)
        # 对嵌入的维度进行初始化
        self.linear1=nn.Linear(emd_size,f_size)
        self.linear2=nn.Linear(f_size,emd_size)
        self.linear3=nn.Linear(emd_size,final_emd_size)

        self.ln=nn.LayerNorm(final_emd_size)

    def forward(self,batch_label):
        batch_label_emd=self.emd(batch_label)
        batch_label_emd=self.linear1(batch_label_emd)
        batch_label_emd=self.linear2(batch_label_emd)
        batch_label_emd=self.linear3(batch_label_emd)
        return self.ln(batch_label_emd)

Part3 测试

'''
# 测试
'''
if __name__=='__main__':
    text_encoder=TextEncoder(voca_size=10,emd_size=16,f_size=64,final_emd_size=8)
    x=torch.tensor([1,2,3,4,5,6,7,8,9,0])
    y=text_encoder(x)
    print(y.shape)

参考

视频讲解:【多模态】复现OpenAI的CLIP模型_哔哩哔哩_bilibili

模型原理讲解:手撕CLIP – Day1 – 基础原理-优快云博客

github资料:YanxinTong/CLIP_Pytorch: 利用 Pytorch 手撕 CLIP 模型

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值