A Summary of SAM's MaskDecoder (Personal Study)

An In-Depth Look at the MaskDecoder in SAM and Its Transformer Structure
This article walks through the MaskDecoder component of the SAM model, which uses a Transformer structure to map image embeddings and prompt embeddings to an output mask. It covers the modified Transformer decoder, the upsampling path, the MLP heads, and IoU prediction, focusing on module initialization, the forward pass, and how the two inputs are combined for prediction.

SAM Model Overview

Preface

        The SAM model consists of three main components. This article aims to describe and summarize the mask decoder in detail, working through its code; corrections and suggestions are welcome.

        The mask decoder module efficiently maps the image embeddings and prompt embeddings to an output mask. To combine these two inputs, the authors, inspired by the Transformer, modified the standard Transformer decoder into the Mask Decoder described here.

Mask Decoder Definition

The following code can be read alongside the mask decoder figure in the paper:

def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)  # learnable token used to predict the quality (IoU) of each mask
        self.num_mask_tokens = num_multimask_outputs + 1  # 3 multimask outputs + 1 single-mask output
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)  # learnable output tokens, one per predicted mask

        # ---upscaled---
        # --- 4x upsampling of the image embedding ---
        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
            activation(),
        )
        # ---upscaled end---
        
        # ---MLP---
        # --- one MLP per mask token ---
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )
        # --- one MLP per mask token end ---
        
        # --- MLP for IoU prediction ---
        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )
        # --- MLP for IoU prediction end ---
        # ---MLP end---
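Before moving on to the arguments, a quick sanity check of the upsampling path. The sketch below pushes a dummy tensor shaped like SAM's image embedding (batch 1, 256 channels, 64x64; these sizes are assumptions based on the released SAM configuration) through two ConvTranspose2d layers with kernel_size=2 and stride=2, which is the same spatial arithmetic as output_upscaling (the LayerNorm2d and activation layers do not change shapes):

import torch
import torch.nn as nn

transformer_dim = 256  # assumed value; matches the released SAM configurations
x = torch.randn(1, transformer_dim, 64, 64)  # dummy image embedding (B, C, H, W)

up1 = nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2)
up2 = nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2)

print(up2(up1(x)).shape)  # torch.Size([1, 32, 256, 256]): spatial size x4, channels reduced to transformer_dim // 8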

Now look at the constructor arguments:

transformer_dim: the channel dimension used inside the transformer.

transformer: the transformer module that is passed in (a TwoWayTransformer in SAM).

num_multimask_outputs: the number of masks the mask decoder outputs to disambiguate ambiguous prompts; the default in SAM is 3.

activation: the activation function used in the mask decoder's upsampling path.

iou_head_depth: the depth of the MLP used to predict the IoU quality of the masks.

iou_head_hidden_dim: the hidden dimension of the MLP used to predict the IoU quality of the masks.

In the line below, one extra token is added on top of num_multimask_outputs. Note that this extra token is not the iou_token (which is a separate nn.Embedding(1, transformer_dim)); it is an additional mask token for the single-mask output returned when multimask_output=False, giving 3 + 1 = 4 mask tokens in total.

self.num_mask_tokens = num_multimask_outputs + 1
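For reference, a minimal sketch of how a MaskDecoder might be instantiated, assuming the default values used by the official SAM builders (transformer_dim=256, a TwoWayTransformer with depth=2, mlp_dim=2048, num_heads=8) and that both classes are importable from segment_anything.modeling:

from segment_anything.modeling import MaskDecoder, TwoWayTransformer

prompt_embed_dim = 256  # assumed to match the released SAM checkpoints
mask_decoder = MaskDecoder(
    num_multimask_outputs=3,
    transformer=TwoWayTransformer(
        depth=2,
        embedding_dim=prompt_embed_dim,
        mlp_dim=2048,
        num_heads=8,
    ),
    transformer_dim=prompt_embed_dim,
    iou_head_depth=3,
    iou_head_hidden_dim=256,
)
print(mask_decoder.num_mask_tokens)  # 4 = 3 multimask outputs + 1 single-mask output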

Mask Decoder Forward Pass

def forward(
        self,
        image_embeddings: torch.Tensor,  # image embedding produced by the image encoder
        image_pe: torch.Tensor,  # positional encoding of the image embedding
        sparse_prompt_embeddings: torch.Tensor,  # sparse prompt embeddings (points/boxes) from the prompt encoder
        dense_prompt_embeddings: torch.Tensor,  # dense prompt embeddings (masks) from the prompt encoder
        multimask_output: bool,  # whether to return multiple masks, which resolves ambiguous prompts
    ) -> Tuple[torch.Tensor, torch.Tensor]:
       
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Slice masks and iou_pred according to multimask_output
        if multimask_output:
            mask_slice = slice(1, None)  # True: keep masks 1..num_multimask_outputs, skipping the single-mask output at index 0
        else:
            mask_slice = slice(0, 1)  # False: keep only the single-mask output at index 0
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # return the masks together with their predicted IoU scores
        return masks, iou_pred
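The effect of the slicing can be checked independently of the model. A minimal sketch with dummy tensors (shapes assumed: 4 mask channels at 256x256 low-resolution logits):

import torch

masks = torch.randn(1, 4, 256, 256)  # dummy predict_masks output: 1 single-mask channel + 3 multimask channels
iou_pred = torch.randn(1, 4)

for multimask_output in (True, False):
    mask_slice = slice(1, None) if multimask_output else slice(0, 1)
    print(masks[:, mask_slice, :, :].shape, iou_pred[:, mask_slice].shape)
# True  -> torch.Size([1, 3, 256, 256]) torch.Size([1, 3])
# False -> torch.Size([1, 1, 256, 256]) torch.Size([1, 1])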

predict_masks in the Mask Decoder

def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)  # output tokens (iou + mask) followed by the prompt tokens

        # Expand per-image data in batch direction to be per-mask
        # Expand the batch dimension of image_embeddings: e.g. with box prompts, n boxes on one image make the effective batch size batch_size * n
        if image_embeddings.shape[0] != tokens.shape[0]:
            # torch.repeat_interleave() repeats elements of a tensor along the given dimension:
            # image_embeddings is the tensor whose elements are repeated,
            # tokens.shape[0] is the number of repetitions
            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        else:
            src = image_embeddings
        src = src + dense_prompt_embeddings  # image embedding + dense prompt embeddings (mask prompt)
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        # hs holds the transformer's output token states; src is the updated (flattened) image embedding
        hs, src = self.transformer(src, pos_src, tokens)
        # hs has shape (batch, num_tokens, transformer_dim).
        # hs[:, 0, :] is the output of the iou_token (iou_token_out);
        # hs[:, 1 : (1 + self.num_mask_tokens), :] are the outputs of the num_mask_tokens mask tokens (mask_tokens_out).
        # Slicing like this cleanly separates the different kinds of output tokens.
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        # stores the MLP-processed output of each mask token; these outputs are used to build the final predicted masks
        hyper_in_list: List[torch.Tensor] = []
        # ---MLP---
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
        # stack the tensors in hyper_in_list along dim=1 into a single tensor hyper_in
        hyper_in = torch.stack(hyper_in_list, dim=1)
        # ---MLP End---
        
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # ---MLP---
        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)
        # ---MLP End---

        return masks, iou_pred
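Before walking through the individual steps, here is a quick shape check of the final mask construction at the end of predict_masks, using dummy tensors with assumed sizes (transformer_dim=256, so the upscaled embedding has 256 // 8 = 32 channels at 256x256 resolution):

import torch

b, num_mask_tokens = 1, 4
c, h, w = 32, 256, 256  # assumed channels and spatial size of upscaled_embedding
hyper_in = torch.randn(b, num_mask_tokens, c)  # one weight vector per mask token, produced by the hypernetwork MLPs
upscaled_embedding = torch.randn(b, c, h, w)

masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
print(masks.shape)  # torch.Size([1, 4, 256, 256]): each mask token yields one spatial mask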

The first part of this code, "Concatenate output tokens", simply concatenates output_tokens with the prompt tokens. For a visual illustration, see the CSDN blog post "【图像分割】【深度学习】SAM官方Pytorch代码-Mask decoder模块MaskDeco网络解析".
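A minimal sketch of that first step, with dummy embeddings (sizes assumed: transformer_dim=256, 3+1 mask tokens, and 3 sparse prompt tokens per prompt):

import torch
import torch.nn as nn

transformer_dim, num_mask_tokens = 256, 4
iou_token = nn.Embedding(1, transformer_dim)
mask_tokens = nn.Embedding(num_mask_tokens, transformer_dim)
sparse_prompt_embeddings = torch.randn(2, 3, transformer_dim)  # batch of 2 prompts, 3 sparse prompt tokens each

output_tokens = torch.cat([iou_token.weight, mask_tokens.weight], dim=0)                     # (5, 256)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)  # (2, 5, 256)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
print(tokens.shape)  # torch.Size([2, 8, 256]): 1 iou token + 4 mask tokens + 3 prompt tokens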

Transformer

  The Transformer data flow is shown in the figure from the paper. Reading from the bottom up, the output tokens and prompt tokens first pass through a self-attention layer; cross-attention is then applied in both directions, token-to-image and image-to-token. In the first cross-attention (token to image), the tokens act as q and the image embedding provides k and v. In the second cross-attention (image to token), the image embedding acts as q and the tokens provide k and v. [See the TwoWayAttentionBlock code for the full flow.]
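The key point is that the output of an attention call always has the sequence length of q, so the direction of each cross-attention decides which side gets updated. A minimal sketch with plain scaled dot-product attention (shapes assumed: 8 tokens, 64x64=4096 image positions, 256 channels):

import math
import torch

tokens = torch.randn(1, 8, 256)    # output tokens + prompt tokens
image = torch.randn(1, 4096, 256)  # flattened 64x64 image embedding

def sdpa(q, k, v):
    attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return attn @ v

print(sdpa(tokens, image, image).shape)   # token-to-image: the tokens are updated -> torch.Size([1, 8, 256])
print(sdpa(image, tokens, tokens).shape)  # image-to-token: the image is updated   -> torch.Size([1, 4096, 256])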

class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,  # rate by which the attention layers downsample the channel dimension
    ) -> None:
        """
        A transformer decoder 尝试对一个输入图片使用带有位置embedding的查询
        由多个transformer block组成, 每个block包含两个attention模块.
        输入是图像的embedding、图像的position embedding和 点的embedding,
        输出是处理后的点的embedding和处理后的图像的embedding。
        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),  # the first block (i == 0) starts with self-attention on the raw queries, without adding the positional encoding
                )
            )

        
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,  # the tokens: output_tokens concatenated with prompt_tokens
    ) -> Tuple[Tensor, Tensor]:
        """
        前向传播过程:
        (1) 将图像的embedding和position embedding 分别经过一个线性层,
            得到image_embedding 和 image_pe。
        (2) 将点嵌入的embedding经过一个线性层,得到point_embedding。
        (3) 对 image_embedding 和 point_embedding 进行 transformer block处理,
            得到经过处理的 image_embedding 和 point_embedding。
        (4) 对经过处理的 image_embedding 和 point_embedding 进行交叉注意力,
            得到经过处理的 point_embedding 和 image_embedding。
        
        Args:
            image_embedding (torch.Tensor): 图像嵌入张量,形状为 B x embedding_dim x h x w。
            image_pe (torch.Tensor): 图像的位置编码张量,与 image_embedding 具有相同的形状。
            point_embedding (torch.Tensor): 查询点的嵌入张量,形状为 B x N_points x embedding_dim。
        
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: 经过处理的 point_embedding 和 image_embedding。
        
        """
        # Flatten image embedding to B x N_image_tokens x C
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)  # positional encoding matching the flattened image embedding

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,  # in the first block, queries and query_pe are identical
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        # final cross-attention layer from the points (tokens) to the image
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
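A minimal sketch of running the TwoWayTransformer on dummy inputs, assuming the class is importable from segment_anything.modeling and using assumed default sizes (embedding_dim=256, a 64x64 image embedding, 8 query tokens):

import torch
from segment_anything.modeling import TwoWayTransformer

transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)

image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
image_pe = torch.randn(1, 256, 64, 64)         # same shape as the image embedding
point_embedding = torch.randn(1, 8, 256)       # B x N_tokens x C (output tokens + prompt tokens)

queries, keys = transformer(image_embedding, image_pe, point_embedding)
print(queries.shape, keys.shape)  # torch.Size([1, 8, 256]) torch.Size([1, 4096, 256])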

TwoWayAttentionBlock

        

class TwoWayAttentionBlock(nn.Module):
    #  TwoWayAttentionBlock = self-attention + token-to-image cross-attention + MLP + image-to-token cross-attention, each followed by a LayerNorm
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers:
        (1) self-attention on the sparse inputs,
        (2) cross-attention from the sparse inputs to the dense inputs,
        (3) an MLP block on the sparse inputs,
        (4) cross-attention from the dense inputs to the sparse inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        
        # First, the self-attention block.
        # In the first block, queries itself equals query_pe, so adding the positional encoding again is skipped.
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # First cross-attention block: tokens attending to the image embedding.
        # q, k, v no longer come from the same sequence: queries + query_pe acts as q, while keys provide both k and v.
        # "Tokens to image embedding" means the tokens act as q and the image embedding acts as k and v.
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Second cross-attention block: image embedding attending to the tokens (the image acts as q; the tokens provide k and v)
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
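A minimal sketch of a single block applied to dummy queries and keys (assuming the class can be imported from segment_anything.modeling.transformer; shapes as in the previous sketch):

import torch
from segment_anything.modeling.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8, mlp_dim=2048, skip_first_layer_pe=True)

queries = torch.randn(1, 8, 256)  # tokens
keys = torch.randn(1, 4096, 256)  # flattened image embedding
query_pe, key_pe = torch.randn_like(queries), torch.randn_like(keys)

queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
print(queries.shape, keys.shape)  # shapes unchanged: torch.Size([1, 8, 256]) torch.Size([1, 4096, 256])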

Attention

class Attention(nn.Module):
    """
    一个允许下采样embedding size的attention层
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head, where C_per_head is the number of channels per head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
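A minimal sketch showing the effect of downsample_rate (assuming Attention can be imported from segment_anything.modeling.transformer): with embedding_dim=256 and downsample_rate=2, q, k, and v are projected to 128 internal channels, attention runs in that smaller space, and out_proj maps the result back to 256:

import torch
from segment_anything.modeling.transformer import Attention

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
print(attn.internal_dim)  # 128: attention is computed in the downsampled channel space

q = torch.randn(1, 8, 256)         # tokens
k = v = torch.randn(1, 4096, 256)  # flattened image embedding
print(attn(q=q, k=k, v=v).shape)   # torch.Size([1, 8, 256]): out_proj maps back to embedding_dim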
