"""---------*- coding: utf-8 -*------------------
File Name: a
Author: kingsCat
date: 2021/11/26 17:13
Description:
----------------------------------------------"""
import paddle
import paddle.nn as nn
import numpy as np
# np.set_printoptions(threshold=np.inf)
from PIL import Image
import cv2
import torch
import matplotlib.pyplot as plt
'''
img=Image.open("./p1.jpg").crop((100,100,102,102))
img1=np.array(img)
# img2=paddle.to_tensor(img)
img3=paddle.to_tensor(img1)
#print(img3)
img4=torch.as_tensor(img1)
#print(img4)
#print(img)
'''
class Attention(nn.Layer):
def __init__(self,
embed_dim,
num_heads,
qkv_bias=False,
qk_scale=None,
attention_dropout=0.):
super().__init__()
self.num_heads=num_heads
        self.embed_dim=embed_dim  # dimension of a single token
        self.head_dim=int(embed_dim/num_heads)  # use multiple heads while keeping the total embedding size unchanged
self.all_head_dim=self.head_dim*num_heads
self.qkv=nn.Linear(embed_dim,
self.all_head_dim*3,
                           bias_attr=False if qkv_bias is False else None)  # one linear layer produces q, k and v in a single pass
        # bias_attr=None lets Paddle create the bias with its default (zero) initializer
self.scale=self.head_dim**-0.5 if qk_scale is None else qk_scale
        self.softmax=nn.Softmax(-1)  # softmax over the last axis
        self.project=nn.Linear(self.all_head_dim,embed_dim)  # output projection back to embed_dim
    def transpose_multi_head(self,x):
        # x: [B, N, all_head_dim]
        # split the last dim into (num_heads, head_dim) and move the head axis forward,
        # so that each head attends independently over the patch dimension
        # #print("\n0:x.shape",x.shape)
        new_shape=x.shape[:-1]+[self.num_heads,self.head_dim]
        x=x.reshape(new_shape)
        # x: [B, N, num_heads, head_dim]
        x=x.transpose([0,2,1,3])
        # #print("\nx.shape:",x.shape)
        # x: [B, num_heads, num_patches, head_dim]
        return x
def forward(self,x):
#[8,16,96]
B,N,_=x.shape#
# #print("x.shape",x.shape,B,N)
        qkv=self.qkv(x)  # one linear projection; split into q, k, v along the last axis below
#print("qkv:",qkv.shape)
qkv=qkv.chunk(3,-1)
# #print("qkv list:",qkv)
#[B,N,all_head_dim]*3
        q,k,v=map(self.transpose_multi_head,qkv)  # apply the reshape to each of the three chunks and unpack into q, k, v
        # q,k,v: [B, num_heads, num_patches, head_dim]; with the head axis up front, each head
        # performs its own [num_patches, head_dim] x [head_dim, num_patches] matrix multiply
#print("q,k,v:",q.shape)
        attn=paddle.matmul(q,k,transpose_y=True)  # q @ k^T (transpose_y transposes k's last two axes)
#print("q*k:",attn.shape)
attn=self.scale*attn
attn=self.softmax(attn)
        # attention_dropout would be applied here (omitted in this minimal version)
        # attn: [B, num_heads, N, N]
out=paddle.matmul(attn,v)
#print("attn*v:",out.shape)
out=out.transpose([0,2,1,3])
#print("transpose out.shape:",out.shape)
out=out.reshape([B,N,-1])
        out=self.project(out)  # [B, N, embed_dim]
# #print("final out:",out.shape)
return out
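# A minimal shape sanity check (my addition, kept as an inactive snippet like the tests at the
# bottom of this file). It assumes embed_dim=96 and num_heads=4 and only verifies that
# Attention preserves the [B, N, embed_dim] shape.
'''
x=paddle.randn([8,16,96])                      # [B, N, embed_dim]
attn_layer=Attention(embed_dim=96,num_heads=4)
y=attn_layer(x)
print(y.shape)                                 # expected: [8, 16, 96]
'''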
class PatchEmbedding(nn.Layer):  # e.g. [1,1,28,28] -> [1,17,1]: 16 patches plus 1 class token
def __init__(self,img_size,patch_size,in_channels,embed_dim,dropout=0.):
super().__init__()
n_patches=(img_size//patch_size)**2
self.patch_embed=nn.Conv2D(in_channels=in_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias_attr=False)
self.dropout=nn.Dropout(dropout)
# class token
self.class_token=paddle.create_parameter(shape=[1,1,embed_dim],
dtype='float32',
default_initializer=nn.initializer.Constant(0.))
#position embedding
self.position_embedding=paddle.create_parameter(shape=[1,n_patches+1,embed_dim],
dtype='float32',
default_initializer=nn.initializer.TruncatedNormal(std=.02))
def forward(self,x):
        class_tokens=self.class_token.expand([x.shape[0],-1,-1])  # broadcast one class token per image so it can be concatenated with the patches
        # [N, C, H, W] -> [N, embed_dim, h', w']: the conv's output channels serve as embed_dim
x=self.patch_embed(x)#[1,1,28,28]->[1,1,4,4]
# #print(x)
x=x.flatten(2)#[n,c,HW][1,1,16]
x=x.transpose([0,2,1])#[n,HW,c][1,16,1]
x=paddle.concat([class_tokens,x],axis=1)
        x=x+self.position_embedding
        x=self.dropout(x)  # a no-op with the default dropout rate of 0.
        return x
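# A shape-trace sketch (my addition) mirroring the commented 28x28 test at the bottom of the
# file: patch_size=7 on a 28x28 input gives 16 patches, plus one class token -> 17 tokens.
'''
sample=paddle.randn([1,1,28,28])
patch_embed=PatchEmbedding(img_size=28,patch_size=7,in_channels=1,embed_dim=1)
tokens=patch_embed(sample)
print(tokens.shape)                            # expected: [1, 17, 1]
'''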
class MLP(nn.Layer):
def __init__(self,embed_dim,mlp_ratio=4.0,dropout=0.):
super().__init__()
self.fc1=nn.Linear(embed_dim,int(embed_dim*mlp_ratio))
self.fc2=nn.Linear(int(embed_dim*mlp_ratio),embed_dim)
self.act=nn.GELU()
self.dropout=nn.Dropout(dropout)
def forward(self,x):
x=self.fc1(x)
# #print(type(x))
x=self.act(x)
x=self.dropout(x)
x=self.fc2(x)
x=self.dropout(x)
return x
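# A quick sketch (my addition): the MLP widens embed_dim by mlp_ratio internally (96 -> 384
# here) but returns the input shape unchanged, so it can sit inside a residual branch.
'''
x=paddle.randn([8,16,96])
mlp=MLP(embed_dim=96)
print(mlp(x).shape)                            # expected: [8, 16, 96]
'''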
class Encoder(nn.Layer):
def __init__(self,embed_dim):
super().__init__()
self.attn=Attention(embed_dim=embed_dim,num_heads=4)#Todo
self.attn_norm=nn.LayerNorm(embed_dim)
self.mlp=MLP(embed_dim)
self.mlp_norm=nn.LayerNorm(embed_dim)
def forward(self, x):
h=x
x=self.attn_norm(x)
x=self.attn(x)
x=h+x
h=x
x=self.mlp_norm(x)
x=self.mlp(x)
x=h+x
return x
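# Each Encoder block is a pre-norm residual pair: LayerNorm -> Attention -> add, then
# LayerNorm -> MLP -> add. A shape check (my addition):
'''
x=paddle.randn([8,16,96])
enc=Encoder(embed_dim=96)
print(enc(x).shape)                            # expected: [8, 16, 96]
'''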
class Vit(nn.Layer):
def __init__(self):
super().__init__()
self.patch_embed=PatchEmbedding(224,7,3,16)
layer_list=[Encoder(16) for i in range(5)]
self.encoders=nn.LayerList(layer_list)
        self.head=nn.Linear(16,10)  # 10 is the number of classes
        self.avgpool=nn.AdaptiveAvgPool1D(1)  # with many output tokens and no class token used for classification, average them all into one vector for the head
def forward(self, x):
x=self.patch_embed(x)
for encoder in self.encoders:
x=encoder(x)
        # a final LayerNorm would normally go here (omitted in this minimal version)
        # [n, h*w, c] -> [n, c, h*w] so the 1-D average pool runs over the token axis
x=x.transpose([0,2,1])
x=self.avgpool(x)#n,c,1
x=x.flatten(1)#[n,c]
x=self.head(x)
return x
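# End-to-end trace (my addition, same idea as the commented test near the end of the file):
# PatchEmbedding(224,7,3,16) gives (224//7)**2 = 1024 patches + 1 class token = 1025 tokens,
# which the average pool collapses to one 16-dim vector before the 10-way head.
'''
t=paddle.randn([4,3,224,224])
model=Vit()
out=model(t)
print(out.shape)                               # expected: [4, 10]
'''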
def printArray(a):
for i in range(a.shape[0]):
for j in range(a.shape[1]):
print(a[i,j],end=" ")
print()
def main():
vit=Vit()
# #print(vit)
paddle.summary(vit,(4,3,224,224))
if __name__=="__main__":
main()
'''img=Image.open("./p.jpg")
img_a=np.array(img)
sample=paddle.to_tensor(img_a,dtype="float32")
# #printArray(img_a)
sample=sample.reshape([1,1,28,28])
patch_embed=PatchEmbedding(28,7,1,1)
out=patch_embed(sample)
# #print(out)
mlp=MLP(1)
mlp_out=mlp(out)
#print(out)
#print(mlp_out)
# #print(img_e(img_t))
'''
'''t=paddle.randn([4,3,224,224])
model=Vit()
out=model(t)
#print(out.shape)'''
'''img=Image.open("./p.jpg")
img_a=np.array(img)
img_t=torch.Tensor(img_a)
#print(img_t.shape)'''