Masked Autoencoder论文中 fine-tuning 和 linear probing含义

自监督学习任务中,将预训练模型转移到下游任务时,通常需要进行fine-tuning。

常见的方法有两种:

  1. full fine-tuning(更新所有模型参数),具体来说,冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层,因为这些层保留了大量底层信息)甚至不冻结任何网络层,训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层

  2. linear probing (只更新最后一个linear layer参数),具体来说,训练后,要评价模型的好坏,通过将最后的一层替换成线性层。 预训练模型的表征层的特征固定,参数固化后未发生改变,只通过监督数据去训练分类器(通常是Softmax分类器或者SVM分类器等等)。只训练这个线性层就是linear probe。

### Masked Autoencoder 的概念 Masked Autoencoder (MAE) 是一种自监督学习方法,主要用于处理图像数据。其核心思想是在输入端随机遮挡部分像素或补丁(patch),并通过模型预测这些被遮挡的部分来重建完整的图像[^2]。这种方法通过强制模型关注未观察到的信息,从而有效地学习数据的潜在表示。 在 MAE 中,编码器仅接收可见部分作为输入,而解码器则负责基于编码器的隐状态以及已知未知区域的位置信息来恢复整个图像。这种设计不仅提高了计算效率,还增强了模型对局部特征的学习能力[^1]。 ### 实现方法 以下是 MAE 的基本实现框架: #### 数据预处理 为了训练 MAE 模型,通常会将输入图像分割成多个固定大小的小块(称为 patches)。接着按照一定比例随机屏蔽其中一些 patch 并标记它们的位置以便后续重构阶段利用。 #### 编码过程 编码网络接受未经掩蔽的patches子集并将其映射至低维空间形成紧凑表征向量集合;由于只有一部分patch参与前馈传播操作因此能够显著减少资源消耗同时保留重要结构特性不变[^3]。 ```python class Encoder(nn.Module): def __init__(self, embed_dim=768): super().__init__() self.patch_embed = nn.Conv2d(in_channels=3, out_channels=embed_dim, kernel_size=(16, 16), stride=(16, 16)) def forward(self, x_visible): embeddings = self.patch_embed(x_visible) return embeddings.flatten(2).transpose(1, 2) # Example usage of encoder class above within a larger pipeline. ``` #### 解码流程 给定由编码步骤产生的隐藏态连同对应于缺失位置指示符一起送入解码模块完成最终还原任务。具体而言先拼接两者再经过多层感知机变换得到目标尺寸输出最后应用回归损失函数度量二者差异程度进而指导参数调整方向直至收敛为止[^4]。 ```python class Decoder(nn.Module): def __init__(self, num_patches=196, hidden_dim=512, output_channel=3): super().__init__() self.expand_dims = nn.Linear(embed_dim, hidden_dim * 16 ** 2) self.decoder_blocks = nn.Sequential(*[ Block(dim=hidden_dim) for _ in range(num_decoder_layers)]) self.output_layer = nn.ConvTranspose2d(hidden_dim, output_channel, kernel_size=(16, 16), stride=(16, 16)) def forward(self, latent_vector_with_mask_info): expanded_tokens = self.expand_dims(latent_vector_with_mask_info) reshaped_tokens = expanded_tokens.view(-1, num_patches, *(int(expanded_tokens.shape[-1])**0.5,) * 2) decoded_features = self.decoder_blocks(reshaped_tokens.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) reconstructed_image = self.output_layer(decoded_features) return reconstructed_image ``` ### 应用场景 - **视觉领域**: 如物体检测、语义分割等下游任务中表现优异因为可以捕捉更丰富的上下文关系有助于提升整体性能水平. - **异常检测**: 利用预先训练好的MAEs快速识别不符合正常模式的新样本特别适合工业监控视频流分析等工作环境当中.
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值