[CVPR 2024] FinePOSE 源码阅读(一):核心模型架构与跨模态交互

Qwen-Image-Edit-2509

Qwen-Image-Edit-2509

图片编辑
Qwen

Qwen-Image-Edit-2509 是阿里巴巴通义千问团队于2025年9月发布的最新图像编辑AI模型,主要支持多图编辑,包括“人物+人物”、“人物+商品”等组合玩法

论文标题: FinePOSE: Fine-grained Prompt-driven 3D Human Pose Estimation via Diffusion Models

代码仓库https://github.com/PKU-ICST-MIPL/FinePOSE_CVPR2024

前言

在 3D 人体姿态估计领域,如何有效地利用文本语义信息一直是一个难点。CVPR 2024 的 FinePOSE 提出了一种基于扩散模型的细粒度提示驱动框架,通过引入 CLIP 和多层次的 Prompt 机制,显著提升了姿态估计的精度。

本文将结合论文架构图,逐行解析 FinePOSE 的核心代码实现(mixste_finepose.py),理解这个模型是如何一步步将文本、时间和空间信息融合在一起的。

一、 Transformer 构建块

FinePOSE 里无论是处理空间的 STE 还是处理时间的 TTE,它们底层调用的都是以下三个标准类。

1. MLP

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., changedim=False, currentdim=0, depth=0):
        """
        初始化 MLP 模块。
        
        Args:
            in_features (int): 输入特征的维度。
            hidden_features (int, optional): 隐藏层的维度。如果为 None,则默认等于输入维度。通常 hidden_features 会比 in_features 大 (比如 4 倍)。
            out_features (int, optional): 输出特征的维度。如果为 None,则默认等于输入维度。
            act_layer (nn.Module, optional): 激活函数层,默认为 GELU (Gaussian Error Linear Units)。
            drop (float, optional): Dropout 丢弃率,用于防止过拟合。
            changedim, currentdim, depth: (这些参数在当前代码逻辑中似乎未被直接使用,可能是为了兼容某些特殊的动态调整接口保留的冗余参数)
        """
        super().__init__() # 初始化父类 nn.Module
        
        # 处理默认参数:如果没指定输出/隐藏层维度,就设为和输入一样
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # 第一层全连接:将特征从 in_features 映射到 hidden_features (通常是升维)
        self.fc1 = nn.Linear(in_features, hidden_features)
        
        # 激活函数:引入非线性因素,让网络能拟合复杂函数
        self.act = act_layer()
        
        # 第二层全连接:将特征从 hidden_features 映射回 out_features (通常是降维回原尺寸)
        self.fc2 = nn.Linear(hidden_features, out_features)
        
        # Dropout 层:随机丢弃神经元,防止过拟合
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """
        前向传播逻辑
        输入 x 的形状通常为: [Batch, ..., in_features]
        """
        # 1. 线性变换 (升维)
        x = self.fc1(x)
        
        # 2. 非线性激活 (GELU)
        x = self.act(x)
        
        # 3. 第一次 Dropout
        x = self.drop(x)
        
        # 4. 线性变换 (降维/恢复维度)
        x = self.fc2(x)
        
        # 5. 第二次 Dropout
        x = self.drop(x)
        
        # 返回处理后的特征
        return x

2. Attention

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., comb=False, vis=False):
        """
        Args:
            dim (int): 输入特征维度。
            num_heads (int): 多头注意力的头数。
            qkv_bias (bool): 生成 Q,K,V 的线性层是否使用偏置。
            qk_scale (float): 手动设置缩放因子,如果为 None 则使用默认的 1/sqrt(head_dim)。
            attn_drop (float): 注意力矩阵的 Dropout 率。
            proj_drop (float): 输出投影层的 Dropout 率。
            comb (bool): 一个特殊的开关,用于切换注意力计算的维度顺序 (可能是为了适配特定的时空注意力变体)。
            vis (bool): 是否开启可视化模式 (代码中似乎仅作为标志位)。
        """
        super().__init__()
        self.num_heads = num_heads
        
        # 计算每个头的维度
        head_dim = dim // num_heads
        
        # 缩放因子:防止点积结果过大导致 Softmax 梯度消失
        self.scale = qk_scale or head_dim ** -0.5

        # 定义生成 Q, K, V 的全连接层
        # 输出维度是 dim * 3,因为它要同时生成 Q, K, V 三份数据
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        
        # 输出投影层:把多头的结果融合回原来的维度
        self.proj = nn.Linear(dim, dim) 

        self.proj_drop = nn.Dropout(proj_drop)
        self.comb = comb # 组合模式开关
        self.vis = vis   # 可视化开关

    def forward(self, x, vis=False):
        """
        前向传播
        输入 x: [Batch, N (序列长度), C (特征维度)]
        """
        B, N, C = x.shape
        
        # 1. 生成 Q, K, V
        # self.qkv(x) -> [B, N, 3*C]
        # reshape -> [B, N, 3, num_heads, head_dim]
        # permute -> [3, B, num_heads, N, head_dim]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        
        # 分离出 Q, K, V
        # q, k, v 的形状都是 [B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # 2. 计算注意力分数 (Attention Score)
        if self.comb == True:
            # 特殊模式:可能是在进行通道注意力 (Channel Attention) 或某种转置计算
            # q.transpose(-2, -1) -> [B, num_heads, head_dim, N]
            # attn = [B, num_heads, head_dim, N] @ [B, num_heads, N, head_dim] -> [B, num_heads, head_dim, head_dim]
            # 这种计算方式得到的注意力图大小是 head_dim * head_dim,关注的是特征通道间的关系
            attn = (q.transpose(-2, -1) @ k) * self.scale
        elif self.comb == False:
            # 标准模式:空间/时间注意力 (Spatial/Temporal Attention)
            # k.transpose(-2, -1) -> [B, num_heads, head_dim, N]
            # attn = [B, num_heads, N, head_dim] @ [B, num_heads, head_dim, N] -> [B, num_heads, N, N]
            # 这种计算方式得到的注意力图大小是 N * N,关注的是序列元素间的关系 (例如帧与帧)
            attn = (q @ k.transpose(-2, -1)) * self.scale
        
        # 3. Softmax 归一化
        # 将分数转换为概率分布 (0~1 之间,和为 1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        
        # 4. 加权求和 (Weighted Sum)
        if self.comb == True:
            # 特殊模式的后续处理
            x = (attn @ v.transpose(-2, -1)).transpose(-2, -1)
            x = rearrange(x, 'B H N C -> B N (H C)')
        elif self.comb == False:
            # 标准模式:利用注意力概率对 V 进行加权
            # [B, num_heads, N, N] @ [B, num_heads, N, head_dim] -> [B, num_heads, N, head_dim]
            x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 还原形状为 [B, N, C]
        
        # 5. 输出投影
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

3. Block 

这是 STE 和 TTE 共用的核心模块

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., attention=Attention, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, comb=False, changedim=False, currentdim=0, depth=0, vis=False):
        """
        Args:
            dim (int): 输入/输出特征维度。
            num_heads (int): 注意力头数。
            mlp_ratio (float): MLP 隐藏层的放大倍数 (通常为 4)。
            attention (class): 使用的 Attention 类 (默认为上面定义的 Attention)。
            drop_path (float): 随机深度 (Stochastic Depth) 的概率,用于训练深层网络。
            changedim (bool): 是否改变特征维度 (这是该模型的一个特殊设计,可能用于构建 U-Net 形状的 Transformer)。
            currentdim, depth: 用于控制 changedim 逻辑的参数。
        """
        super().__init__()

        self.changedim = changedim
        self.currentdim = currentdim
        self.depth = depth
        if self.changedim:
            assert self.depth > 0 # 如果开启维度变化,必须指定总深度

        # 第一个 LayerNorm (用于 Attention 之前)
        self.norm1 = norm_layer(dim)
        
        # Attention 模块
        self.attn = attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, comb=comb, vis=vis)
        
        # DropPath (随机深度):在训练时随机丢弃整个残差分支,有助于深层网络的训练
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        
        # 第二个 LayerNorm (用于 MLP 之前)
        self.norm2 = norm_layer(dim)
        
        # MLP 模块
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        
        # --- 特殊的维度调整逻辑 (类似 U-Net 的下采样/上采样) ---
        if self.changedim and self.currentdim < self.depth//2:
            # 如果在网络的前半部分,进行降维 (Reduction)
            # 使用 1x1 卷积将维度减半: dim -> dim // 2
            self.reduction = nn.Conv1d(dim, dim//2, kernel_size=1)
        elif self.changedim and depth > self.currentdim > self.depth//2:
            # 如果在网络的后半部分,进行升维 (Improve/Expansion)
            # 使用 1x1 卷积将维度翻倍: dim -> dim * 2
            self.improve = nn.Conv1d(dim, dim*2, kernel_size=1)
        self.vis = vis

    def forward(self, x, vis=False):
        """
        前向传播
        输入 x: [Batch, Sequence, Dim]
        """
        # 1. Attention 分支 (Pre-Norm 结构)
        # x = x + DropPath(Attention(LayerNorm(x)))
        x = x + self.drop_path(self.attn(self.norm1(x), vis=vis))
        
        # 2. MLP 分支
        # x = x + DropPath(MLP(LayerNorm(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        
        # 3. 维度调整 (可选)
        # 这里的 rearrange 操作是为了适配 Conv1d 的输入格式 (Batch, Channel, Length)
        if self.changedim and self.currentdim < self.depth//2:
            x = rearrange(x, 'b t c -> b c t') # 交换维度,把特征放到第 1 维
            x = self.reduction(x)              # 降维
            x = rearrange(x, 'b c t -> b t c') # 换回来
        elif self.changedim and self.depth > self.currentdim > self.depth//2:
            x = rearrange(x, 'b t c -> b c t')
            x = self.improve(x)                # 升维
            x = rearrange(x, 'b c t -> b t c')
            
        return x

二、 核心组件

1. FPP模块核心

可学习向量(ctx_subject、ctx_verb、ctx_speed、ctx_head、ctx_body、ctx_arm、ctx_leg)在模型初始化部分进行定义。

为了从加了噪声的3D姿态重建出纯净的3D姿态,FinePOSE使用常规的2D关键点X、t和细粒度、部位感知的提示嵌入P来指导去噪过程。我们利用FPP模块来学习这个P。FPP模块在提示嵌入空间中编码了三种与姿态相关的信息,包括其动作类别、粗粒度和细粒度的人体部位(如“人、头、身体、手臂、腿”)以及运动学信息“速度”。

P的形状是K,L,D,K是文本提示的数量、L是每个提示中的token数量、D是token嵌入的维度。K=7,且text分别表示{人, [动作类别], 速度, 头, 身体, 手臂, 腿},每个文本提示的前四个token与可学习向量拼在一起作为最终嵌入。可学习向量通过u=0标准差为0.02的高斯分布初始化。

def encode_text(self, text, pre_text_tensor):
        """
        处理文本提示,实现 FPP 模块的核心功能
        将原始文本与可学习向量结合,并提取全局和细粒度特征
        """
        
        # 使用冻结的 CLIP 模型将文本 ID 转换为初始向量表示
        with torch.no_grad():
            # x 是动作类别的嵌入,维度为 [Batch, 77, 512]
            # .type() 确保数据精度与 CLIP 模型一致,通常是半精度或单精度
            x = self.clip_text.token_embedding(text).type(self.clip_text.dtype)
            
            # pre_text_tensor 是其他 6 个静态属性的嵌入,维度为 [Batch, 6, 77, 512]
            pre_text_tensor = self.clip_text.token_embedding(pre_text_tensor).type(self.clip_text.dtype)

        #构建混合提示:处理 Subject 部分
        # 获取可学习参数,维度为 [M, 512]
        learnable_prompt_subject = self.ctx_subject
        # 增加 Batch 维度并广播,维度变为 [Batch, M, 512]
        learnable_prompt_subject = learnable_prompt_subject.view(1, self.ctx_subject.shape[0], self.ctx_subject.shape[1])
        learnable_prompt_subject = learnable_prompt_subject.repeat(x.shape[0], 1, 1)
        # 将可学习向量与原始文本的前几个 token 拼接
        # 这里取 pre_text_tensor 的第 0 个属性
        learnable_prompt_subject = torch.cat((learnable_prompt_subject, pre_text_tensor[:, 0, :self.remain_len, :]), dim=1)

        # 构建混合提示:处理 Verb 部分
        # 注意这里的硬提示来自输入的动态文本 x
        learnable_prompt_verb = self.ctx_verb
        learnable_prompt_verb = learnable_prompt_verb.view(1, self.ctx_verb.shape[0], self.ctx_verb.shape[1])
        learnable_prompt_verb = learnable_prompt_verb.repeat(x.shape[0], 1, 1)
        learnable_prompt_verb = torch.cat((learnable_prompt_verb, x[:, :self.remain_len, :]), dim=1)

        # 构建混合提示:处理 Speed 部分
        learnable_prompt_speed = self.ctx_speed
        learnable_prompt_speed = learnable_prompt_speed.view(1, self.ctx_speed.shape[0], self.ctx_speed.shape[1])
        learnable_prompt_speed = learnable_prompt_speed.repeat(x.shape[0], 1, 1)
        # 这里取 pre_text_tensor 的第 1 个属性
        learnable_prompt_speed = torch.cat((learnable_prompt_speed, pre_text_tensor[:, 1, :self.remain_len, :]), dim=1)

        # 构建混合提示:处理身体部位 Head, Body, Arm, Leg 
        # 分别取 pre_text_tensor 的第 2, 3, 4, 5 个属性进行拼接
        
        learnable_prompt_head = self.ctx_head
        learnable_prompt_head = learnable_prompt_head.view(1, self.ctx_head.shape[0], self.ctx_head.shape[1])
        learnable_prompt_head = learnable_prompt_head.repeat(x.shape[0], 1, 1)
        learnable_prompt_head = torch.cat((learnable_prompt_head, pre_text_tensor[:, 2, :self.remain_len, :]), dim=1)

        learnable_prompt_body = self.ctx_body
        learnable_prompt_body = learnable_prompt_body.view(1, self.ctx_body.shape[0], self.ctx_body.shape[1])
        learnable_prompt_body = learnable_prompt_body.repeat(x.shape[0], 1, 1)
        learnable_prompt_body = torch.cat((learnable_prompt_body, pre_text_tensor[:, 3, :self.remain_len, :]), dim=1)

        learnable_prompt_arm = self.ctx_arm
        learnable_prompt_arm = learnable_prompt_arm.view(1, self.ctx_arm.shape[0], self.ctx_arm.shape[1])
        learnable_prompt_arm = learnable_prompt_arm.repeat(x.shape[0], 1, 1)
        learnable_prompt_arm = torch.cat((learnable_prompt_arm, pre_text_tensor[:, 4, :self.remain_len, :]), dim=1)

        learnable_prompt_leg = self.ctx_leg
        learnable_prompt_leg = learnable_prompt_leg.view(1, self.ctx_leg.shape[0], self.ctx_leg.shape[1])
        learnable_prompt_leg = learnable_prompt_leg.repeat(x.shape[0], 1, 1)
        learnable_prompt_leg = torch.cat((learnable_prompt_leg, pre_text_tensor[:, 5, :self.remain_len, :]), dim=1)

   
        # 将 7 个部分的混合提示拼接成一个超长的序列
        # x 的维度变为 [Batch, Total_Length, 512]
        x = torch.cat((learnable_prompt_subject, learnable_prompt_verb, learnable_prompt_speed, learnable_prompt_head, learnable_prompt_body, learnable_prompt_arm, learnable_prompt_leg), dim=1)

        with torch.no_grad():
            # 注入位置信息:CLIP 的位置编码包含了序列顺序信息
            x = x + self.clip_text.positional_embedding.type(self.clip_text.dtype)
            
            # 维度调整:PyTorch Transformer 要求输入形状为 [Seq_Len, Batch, Dim]
            x = x.permute(1, 0, 2) 
            
            # 深度特征交互:让所有的 Prompt 部分利用注意力机制互相交流
            x = self.clip_text.transformer(x)
            
            # 层归一化
            x = self.clip_text.ln_final(x).type(self.clip_text.dtype)

        
        # 这里的 text_pre_proj 是一个占位符,不改变数据
        x = self.text_pre_proj(x)
        
        # 使用可训练的 Transformer 编码器进行领域适配
        # 将 CLIP 的通用图像文本特征转化为 3D 姿态领域的文本特征
        xf_out = self.textTransEncoder(x)
        xf_out = self.text_ln(xf_out)
        
        # 提取全局宏观指令
        # text.argmax 找到每句话结束符号的位置,因为结束符 ID 最大
        # 提取该位置的向量作为整句话的全局摘要
        global_feature = xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])]
        
        # 通过线性层映射,用于后续注入到时间步嵌入中
        xf_proj = self.text_proj(global_feature)
        
        # 准备微观细粒度指令
        # 调整维度回 [Batch, Seq_Len, Dim],保留完整序列用于后续的跨模态交互
        xf_out = xf_out.permute(1, 0, 2)
        
        return xf_proj, xf_out

2. PTS 模块核心

# 定义风格化模块
# 作用:将条件嵌入 (emb) 注入到特征 (h) 中,调节特征的分布 (scale & shift)
class StylizationBlock(nn.Module):

    def __init__(self, latent_dim, time_embed_dim, dropout):
        """
        Args:
            latent_dim (int): 输入特征 h 的维度。
            time_embed_dim (int): 条件嵌入 emb 的维度 (例如时间步编码)。
            dropout (float): Dropout 概率。
        """
        super().__init__()
        
        # 1. 条件映射层 
        # 将条件向量映射为两倍的特征维度 (因为要生成 scale 和 shift 两个参数)
        self.emb_layers = nn.Sequential(
            nn.SiLU(), # 激活函数
            nn.Linear(time_embed_dim, 2 * latent_dim), # 线性层: dim -> 2*dim
        )
        
        # 2. 归一化层 
        # 对输入特征 h 进行标准化
        self.norm = nn.LayerNorm(latent_dim)
        
        # 3. 输出层 
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # 使用 zero_module 初始化,意味着初始输出为 0 (不改变主干流)
            zero_module(nn.Linear(latent_dim, latent_dim)),
        )

    def forward(self, h, emb):
        """
        前向传播
        输入 h: [Batch, Sequence (T), Dim (D)] - 主干特征
        输入 emb: [Batch, Dim (D)] 或 [Batch, 1, Dim] - 条件嵌入 (例如时间编码)
        """
        B, T, D = h.shape
        
        # 调整 emb 形状以匹配 h
        emb = emb.view(B, T, D)
        
        # 1. 计算 Scale 和 Shift
        # 通过 MLP 处理条件向量
        # [:, 0:1, :] 表示只取第一个时间步的特征
        emb_out = self.emb_layers(emb)[:,0:1,:] # 形状: [B, 1, 2*D]
        
        # 将结果切分为两份:缩放因子 (scale) 和 平移因子 (shift)
        # chunk(..., 2, dim=2) -> [B, 1, D], [B, 1, D]
        scale, shift = torch.chunk(emb_out, 2, dim=2)
        
        # 2. 适应性归一化
        # 公式: y = norm(x) * (1 + scale) + shift
        # 这是条件生成的核心:利用外部条件来“调制”特征的均值和方差
        h = self.norm(h) * (1 + scale) + shift
        
        # 3. 输出处理 
        h = self.out_layers(h)
        
        return h

3. FPC 模块核心


# 定义时序交叉注意力模块
# 作用:计算动作序列 (Query) 与文本序列 (Key/Value) 之间的注意力,实现多模态融合
class TemporalCrossAttention(nn.Module):

    def __init__(self, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        """
        Args:
            latent_dim (int): 动作特征的维度 (D)。
            text_latent_dim (int): 文本特征的维度 (L)。
            num_head (int): 注意力头数。
            time_embed_dim (int): 时间嵌入维度 (用于 StylizationBlock)。
        """
        super().__init__()
        self.num_head = num_head
        # 归一化层
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        # 生成 Q, K, V 的线性层
        # Query 来自动作特征,维度 D
        self.query = nn.Linear(latent_dim, latent_dim)
        # Key 和 Value 来自文本特征,维度从 L 映射到 D
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        # 输出投影层:使用 StylizationBlock,不仅融合信息,还注入时间步信息
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)
    
    def forward(self, x, xf, emb):
        """
        前向传播
        输入 x: [Batch, T (Frames), D] - 动作特征序列
        输入 xf: [Batch, N (Text Tokens), L] - 文本特征序列 (来自 CLIP)
        输入 emb: [Batch, D] - 时间步嵌入
        """
        B, T, D = x.shape
        N = xf.shape[1] # 文本长度 (例如 77)
        H = self.num_head
        # 1. 生成 Query (来自动作 x) 
        # [B, T, D] -> [B, T, D] -> [B, T, 1, D]
        query = self.query(self.norm(x)).unsqueeze(2)
        # 2. 生成 Key (来自文本 xf) 
        # [B, N, L] -> [B, N, D] -> [B, 1, N, D]
        key = self.key(self.text_norm(xf)).unsqueeze(1)
        # 这里有个 repeat 操作,可能是为了处理文本 Batch 和动作 Batch 大小不一致的情况
        key = key.repeat(int(B/key.shape[0]), 1, 1, 1)
        # 3. 重塑维度以适应多头 (Multi-head)
        # Query: [B, T, H, D/H]
        query = query.view(B, T, H, -1)
        # Key: [B, N, H, D/H]
        key = key.view(B, N, H, -1)
        # 4. 计算注意力分数
        # 使用 einsum 进行张量乘法
        # 'bnhd, bmhd -> bnmh' 含义:
        # b: Batch, n: T (frames), m: N (text tokens), h: Head, d: Dim
        # 计算每个帧 (n) 与每个文本 token (m) 的相关性
        attention = torch.einsum('bnhd,bmhd->bnmh', query, key) / math.sqrt(D // H)
        # Softmax 归一化 (在文本长度 m 维度上)
        weight = self.dropout(F.softmax(attention, dim=2))
        # 5. 生成 Value (来自文本 xf) 
        value = self.value(self.text_norm(xf)).unsqueeze(1)
        value = value.repeat(int(B/value.shape[0]), 1, 1, 1)
        value = value.view(B, N, H, -1)
        # 6. 加权求和 
        # 'bnmh, bmhd -> bnhd'
        # 用算出来的权重 (weight) 去加权文本的 Value
        # 结果是每个帧都融合了相关的文本信息
        y = torch.einsum('bnmh,bmhd->bnhd', weight, value).reshape(B, T, D)
        # 7. 输出处理与残差连接 
        # 通过 StylizationBlock 进一步注入时间步信息 emb
        # 残差连接
        y = x + self.proj_out(y, emb)
        return y

三、 主模型架构 (MixSTE2)

1. 模型初始化

class MixSTE2(nn.Module):
    def __init__(self, num_frame=9, num_joints=17, in_chans=2, embed_dim_ratio=32, depth=4,
                 num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.2, norm_layer=None, is_train=True):
        """
        Args:
            num_frame (int): 输入序列的帧数。
            num_joints (int): 每帧姿态的关节点数量。
            in_chans (int): 输入通道数,对于 2D 关键点,通常是 2 (x, y)。
            embed_dim_ratio (int): 嵌入维度的大小。
            depth (int): Transformer 编码器的深度(Block 的数量)。
            num_heads (int): 多头注意力机制的头数。
            mlp_ratio (float): MLP 层的隐藏维度与嵌入维度的比率。
            qkv_bias (bool): 是否为 QKV 线性层启用偏置。
            qk_scale (float): QK 缩放因子。
            drop_rate (float): Dropout 率。
            attn_drop_rate (float): Attention Dropout 率。
            drop_path_rate (float): Stochastic Depth 的衰减率。
            norm_layer (nn.Module): 归一化层,默认为 LayerNorm。
            is_train (bool): 标记模型是否处于训练模式。
        """
        super().__init__()

        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        embed_dim = embed_dim_ratio  # 嵌入维度
        out_dim = 3  # 输出维度,3D 坐标 (x, y, z)
        self.is_train = is_train

        # 姿态嵌入模块
        # 线性层,将每个关节点的输入特征(2D坐标+带噪3D坐标)映射到高维嵌入空间。
        # in_chans + 3 表示输入是 2D(x,y) + 3D(x,y,z) = 5 维
        self.Spatial_patch_to_embedding = nn.Linear(in_chans + 3, embed_dim_ratio)

        # 可学习的空间位置编码,为每个关节点学习一个特定的位置嵌入。
        # 维度: (1, num_joints, embed_dim)
        self.Spatial_pos_embed = nn.Parameter(torch.zeros(1, num_joints, embed_dim_ratio))

        # 可学习的时间位置编码,为序列中的每一帧学习一个特定的位置嵌入。
        # 维度: (1, num_frame, embed_dim)
        self.Temporal_pos_embed = nn.Parameter(torch.zeros(1, num_frame, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        # 时间步和文本提示嵌入模块 
        # MLP 网络,用于将扩散过程的时间步 t 转换为高维嵌入。
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(embed_dim_ratio), # 将标量 t 转换为正弦位置编码
            nn.Linear(embed_dim_ratio, embed_dim_ratio * 2),
            nn.GELU(),
            nn.Linear(embed_dim_ratio * 2, embed_dim_ratio),
        )

        #  Transformer 核心模块 
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # 随机深度衰减规则
        self.block_depth = depth

        # 空间Transformer编码器块 (Spatial Transformer Encoder)
        # 负责在单帧内对所有关节点之间的空间关系进行建模。
        self.STEblocks = nn.ModuleList([
            Block(
                dim=embed_dim_ratio, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        # 时间Transformer编码器块 (Temporal Transformer Encoder)
        # 负责对单个关节点在所有帧之间的时间动态进行建模。
        self.TTEblocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        self.Spatial_norm = norm_layer(embed_dim_ratio)
        self.Temporal_norm = norm_layer(embed_dim)

        # 输出头 
        # 最终的线性层,将处理后的高维特征映射回 3D 坐标空间。
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, out_dim),
        )

        # 文本处理与交叉注意力模块 
        # 时间交叉注意力模块,实现图中的 FPC (Fine-grained Prompt-Pose Communication)
        # Query 来自姿态特征,Key 和 Value 来自文本提示特征。
        self.temporal_cross_attn = TemporalCrossAttention(512, 512, num_heads, drop_rate, 512)
        
        # 文本编码器,用于进一步处理从 CLIP 获得的文本嵌入。
        self.text_pre_proj = nn.Identity()
        textTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=512, nhead=num_heads, dim_feedforward=2048,
            dropout=drop_rate, activation="gelu")
        self.textTransEncoder = nn.TransformerEncoder(
            textTransEncoderLayer, num_layers=4)
        self.text_ln = nn.LayerNorm(512)
        
        # 线性层,用于生成最终注入到姿态特征中的全局文本提示。
        self.text_proj = nn.Sequential(
            nn.Linear(512, 512)
        )

        # --- 可学习提示 (FPP) 模块 ---
        # 加载预训练的 CLIP 模型,用于将文本转换为初始嵌入。
        self.clip_text, _ = clip.load('ViT-B/32', "cpu")
        set_requires_grad(self.clip_text, False)  # 冻结 CLIP 模型参数

        self.remain_len = 4 # 保留原始 token 的长度

        # 定义一系列可学习的向量 (context vectors),这些就是 FPP 的核心。
        # 它们与原始的文本 token 嵌入拼接在一起,形成新的、可优化的 prompt。
        # 每个 `ctx_*` 对应图中的一类细粒度提示(如 person, action class, head, body 等)。
        ctx_vectors_subject = torch.empty((7 - self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_subject, std=0.02)
        self.ctx_subject = nn.Parameter(ctx_vectors_subject)

        ctx_vectors_verb = torch.empty((12 - self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_verb, std=0.02)
        self.ctx_verb = nn.Parameter(ctx_vectors_verb)

        # ... (为 speed, head, body, arm, leg 定义类似的可学习向量) ...
        ctx_vectors_speed = torch.empty((10-self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_speed, std=0.02)
        self.ctx_speed = nn.Parameter(ctx_vectors_speed)

        ctx_vectors_head = torch.empty((10-self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_head, std=0.02)
        self.ctx_head = nn.Parameter(ctx_vectors_head)
        
        ctx_vectors_body = torch.empty((10-self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_body, std=0.02)
        self.ctx_body = nn.Parameter(ctx_vectors_body)
        
        ctx_vectors_arm = torch.empty((14-self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_arm, std=0.02)
        self.ctx_arm = nn.Parameter(ctx_vectors_arm)

        ctx_vectors_leg = torch.empty((14-self.remain_len), 512, dtype=self.clip_text.dtype)
        nn.init.normal_(ctx_vectors_leg, std=0.02)
        self.ctx_leg = nn.Parameter(ctx_vectors_leg)

2. Spatial MHSA模块核心

实现全局条件(时间+文本)的注入。

def STE_forward(self, x_2d, x_3d, t, xf_proj):
        """空间编码器前向传播"""
        if self.is_train:
            # 训练时,输入维度 (batch, frame, joint, channel)
            x = torch.cat((x_2d, x_3d), dim=-1) # 拼接 2D 和 3D 姿态
            b, f, n, c = x.shape
            x = rearrange(x, 'b f n c -> (b f) n c') # (b*f, n, c) 方便进行空间处理
            x = self.Spatial_patch_to_embedding(x) # 嵌入到高维
            x += self.Spatial_pos_embed # 添加空间位置编码
            time_embed = self.time_mlp(t)[:, None, None, :] # (b, 1, 1, dim)
            xf_proj = xf_proj.view(xf_proj.shape[0], 1, 1, xf_proj.shape[1])
            time_embed = time_embed + xf_proj # 注入全局文本提示
            time_embed = time_embed.repeat(1, f, n, 1)
            time_embed = rearrange(time_embed, 'b f n c -> (b f) n c')
            x += time_embed # 将融合后的嵌入加到每个关节点特征上
        else: # 测试时,因为有 num_proposals,所以多一个维度
            # 维度 (batch, proposal, frame, joint, channel)
            x_2d = x_2d[:, None].repeat(1, x_3d.shape[1], 1, 1, 1)
            x = torch.cat((x_2d, x_3d), dim=-1)
            b, h, f, n, c = x.shape
            x = rearrange(x, 'b h f n c -> (b h f) n c')
            # 后续操作与训练时类似
            x = self.Spatial_patch_to_embedding(x)
            x += self.Spatial_pos_embed
            time_embed = self.time_mlp(t)[:, None, None, None, :]
            xf_proj = xf_proj.view(xf_proj.shape[0], 1, 1, 1, xf_proj.shape[1])
            time_embed = time_embed + xf_proj
            time_embed = time_embed.repeat(1, h, f, n, 1)
            time_embed = rearrange(time_embed, 'b h f n c -> (b h f) n c')
            x += time_embed

        x = self.pos_drop(x)
        # 通过第一个空间 Transformer Block
        blk = self.STEblocks[0]
        x = blk(x)
        x = self.Spatial_norm(x)
        # 维度重排,为时间处理做准备: (b*f, n, dim) -> (b*n, f, dim)
        x = rearrange(x, '(b f) n cw -> (b n) f cw', f=f)
        return x, time_embed

3.Temporal MHSA模块核心

在通过 FPC 模块吸收了文本信息后,数据的视角需要从特征融合切换回时序建模。这一步的核心任务是给数据打上时间戳,并进行初步的轨迹平滑。

def TTE_foward(self, x):
    """
    输入 x: [Batch, Frame, Dim] 

    """
    b, f, _ = x.shape
    
    # 1. 注入时间位置编码
    # Transformer 不懂顺序,必需加上 Positional Embedding 才能理解 第1帧 和 第2帧 的区别
    x += self.Temporal_pos_embed 
    x = self.self.pos_drop(x)

    # 2. 初始时间建模
    # 调用第 0 个时间 Block
    # 这一步模型在 Frame 维度进行 Attention,平滑单个关节的运动轨迹
    blk = self.TTEblocks[0]
    x = blk(x)
    
    x = self.Temporal_norm(x)
    return x

4.Spatial-Temporal MHSA模块核心

模型进入了一个循环,通过不断地切换空间视角(骨架结构)时间视角(运动轨迹),对 3D 姿态进行精细打磨。

def ST_foward(self, x):
    """
    交替进行时空编码
    输入 x: [Batch, Frame, Joint, Dim]
    """
    b, f, n, cw = x.shape
    
    # 从第 1 个 Block 开始循环 (因为第 0 个已经被 STE 和 TTE 用掉了)
    for i in range(1, self.block_depth):
        
        #  A. 空间步 
        # 维度重排: 'b f n c -> (b f) n c'
        # 将 Frame 维度合并到 Batch 中,强迫模型忽略时间,只看单帧内的骨架
        x = rearrange(x, 'b f n cw -> (b f) n cw')
        
        steblock = self.STEblocks[i]
        x = steblock(x)
        x = self.Spatial_norm(x)
        
        # B. 时间步
        # 维度重排: '(b f) n c -> (b n) f c'
        # 将 Joint 维度合并到 Batch 中,强迫模型忽略骨架,只看单关节的轨迹
        x = rearrange(x, '(b f) n cw -> (b n) f cw', f=f)
        
        tteblock = self.TTEblocks[i]
        x = tteblock(x)
        x = self.Temporal_norm(x)
        
        # 恢复维度,准备进入下一次循环
        x = rearrange(x, '(b n) f cw -> b f n cw', n=n)
        
    return x

您可能感兴趣的与本文相关的镜像

Qwen-Image-Edit-2509

Qwen-Image-Edit-2509

图片编辑
Qwen

Qwen-Image-Edit-2509 是阿里巴巴通义千问团队于2025年9月发布的最新图像编辑AI模型,主要支持多图编辑,包括“人物+人物”、“人物+商品”等组合玩法

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值