Hands-On Deep Transfer Learning: Fine-Tuning and Domain Adaptation with ResNet-50

Project: transferlearning — papers, code, datasets, applications, and tutorials for transfer learning, domain adaptation, domain generalization, multi-task learning, and more. Repository: https://gitcode.com/gh_mirrors/tr/transferlearning

Introduction

Deep transfer learning is an important research direction in machine learning: it transfers knowledge learned in one domain (the source domain) to a related but different domain (the target domain). Building on the classic ResNet-50 network, this tutorial walks through working code for the two most common transfer learning approaches: fine-tuning (finetune) and domain adaptation.

Environment Setup

First, make sure the required Python libraries are installed:

!pip install torch torchvision
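
The code in the rest of this tutorial assumes the following imports:

import os

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms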

Dataset Preparation

We use the classic Office-31 dataset, which contains images from three domains (Amazon, Webcam, and DSLR), each covering the same 31 object categories.

The dataset is organized as follows:

office31/
├── amazon
├── dslr
└── webcam

Data Loading and Preprocessing

We define a data loading function that applies different preprocessing to the source and target domains:

def load_data(root_path, domain, batch_size, phase):
    # 'src': training-style augmentation; 'tar': deterministic resize only
    transform_dict = {
        'src': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        'tar': transforms.Compose([
            transforms.Resize([224, 224]),  # a [h, w] list forces square outputs so batches collate
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    }
    data = datasets.ImageFolder(root=os.path.join(root_path, domain),
                                transform=transform_dict[phase])
    # Shuffle source batches; drop the last (possibly smaller) target batch so
    # source and target batch sizes stay aligned for the alignment losses later
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                              shuffle=(phase == 'src'),
                                              drop_last=(phase == 'tar'),
                                              num_workers=4)
    return data_loader
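
As a quick sanity check, the loaders can be built like this (the paths and batch size are example values; following a common unsupervised setting, the target loader also serves as the validation loader that the finetune loop below expects):

batch_size = 32
dataloaders = {
    'src': load_data('./office31', 'amazon', batch_size, phase='src'),
    'tar': load_data('./office31', 'webcam', batch_size, phase='tar'),
}
# With no labeled target split available, reuse the target loader for validation
dataloaders['val'] = dataloaders['tar']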

The Fine-Tuning (Finetune) Approach

Model Definition

We build the transfer model on a pretrained ResNet-50, replacing the final fully connected layer to fit the new classification task:

class TransferModel(nn.Module):
    def __init__(self, base_model='resnet50', pretrain=True, n_class=31):
        super(TransferModel, self).__init__()
        self.base_model = base_model
        self.pretrain = pretrain
        self.n_class = n_class
        if self.base_model == 'resnet50':
            # Load ImageNet weights when pretrain=True
            self.model = torchvision.models.resnet50(pretrained=self.pretrain)
            # Replace the 1000-way ImageNet head with an n_class-way head
            n_features = self.model.fc.in_features
            self.model.fc = torch.nn.Linear(n_features, n_class)
        # Re-initialize the new head with small random weights
        self.model.fc.weight.data.normal_(0, 0.005)
        self.model.fc.bias.data.fill_(0.1)

    def forward(self, x):
        return self.model(x)

    def predict(self, x):
        return self.forward(x)
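
A quick shape check on a random batch (a sketch; the tensor values are arbitrary) confirms the new head produces 31 logits:

model = TransferModel(base_model='resnet50', pretrain=True, n_class=31)
dummy = torch.randn(2, 3, 224, 224)  # random batch of two 224x224 RGB images
print(model(dummy).shape)            # torch.Size([2, 31])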

Training Procedure

The core fine-tuning training logic is as follows:

def finetune(model, dataloaders, optimizer):
    best_acc = 0
    stop = 0  # epochs since the last validation improvement
    for epoch in range(n_epoch):
        stop += 1
        # Train on source, then evaluate on the validation and target sets
        for phase in ['src', 'val', 'tar']:
            if phase == 'src':
                model.train()
            else:
                model.eval()
            total_loss, correct = 0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.cuda(), labels.cuda()
                optimizer.zero_grad()
                # Gradients are only needed during the source (training) phase
                with torch.set_grad_enabled(phase == 'src'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = torch.max(outputs, 1)[1]
                if phase == 'src':
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * inputs.size(0)
                correct += torch.sum(preds == labels.data)
            epoch_loss = total_loss / len(dataloaders[phase].dataset)
            epoch_acc = correct.double() / len(dataloaders[phase].dataset)
            print(f'Epoch [{epoch}/{n_epoch}] {phase} loss: {epoch_loss:.4f}, acc: {epoch_acc:.4f}')
            # Keep the best checkpoint by validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                stop = 0
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'model.pkl')
        # Early stopping after `early_stop` epochs without improvement
        if stop >= early_stop:
            break

Optimizer Setup

We use an SGD optimizer and assign a higher learning rate to the new fully connected layer:

param_group = []
learning_rate = 0.0001
momentum = 0.9        # conventional SGD momentum; 5e-4 is instead a typical weight decay
weight_decay = 5e-4
for k, v in model.named_parameters():
    if 'fc' not in k:  # backbone layers: base learning rate
        param_group += [{'params': v, 'lr': learning_rate}]
    else:              # new fc head: 10x learning rate
        param_group += [{'params': v, 'lr': learning_rate * 10}]
optimizer = torch.optim.SGD(param_group, momentum=momentum, weight_decay=weight_decay)
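
The training loop also references three module-level names (n_epoch, criterion, early_stop) that the snippets above leave undefined; a minimal wiring sketch with example values:

n_epoch = 100                      # maximum number of epochs (example value)
early_stop = 20                    # early-stopping patience (example value)
criterion = nn.CrossEntropyLoss()  # classification loss used inside finetune()

model = TransferModel(n_class=31).cuda()
# build param_group and optimizer from this model as shown above, then:
finetune(model, dataloaders, optimizer)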

The Domain Adaptation Approach

Domain adaptation extends fine-tuning with a domain-alignment loss that reduces the distribution discrepancy between the source and target domains.

Domain Alignment Losses

We implement two widely used alignment losses:

MMD (Maximum Mean Discrepancy) Loss

class MMD_loss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num  # number of RBF kernels in the mixture
        self.kernel_mul = kernel_mul  # bandwidth multiplier between kernels
        self.fix_sigma = None         # set to a value to fix the bandwidth manually
        self.kernel_type = kernel_type

    def gaussian_kernel(self, source, target):
        # Multi-kernel RBF over the joint source+target batch
        n_samples = int(source.size(0)) + int(target.size(0))
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(total.size(0), total.size(0), total.size(1))
        total1 = total.unsqueeze(1).expand(total.size(0), total.size(0), total.size(1))
        L2_distance = ((total0 - total1) ** 2).sum(2)  # pairwise squared distances
        if self.fix_sigma:
            bandwidth = self.fix_sigma
        else:
            # Heuristic: mean pairwise distance as the base bandwidth
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        bandwidth /= self.kernel_mul ** (self.kernel_num // 2)
        bandwidth_list = [bandwidth * (self.kernel_mul ** i)
                          for i in range(self.kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        # Note: assumes source and target batches have the same size
        batch_size = int(source.size(0))
        kernels = self.gaussian_kernel(source, target)
        XX = kernels[:batch_size, :batch_size]  # source-source
        YY = kernels[batch_size:, batch_size:]  # target-target
        XY = kernels[:batch_size, batch_size:]  # source-target
        YX = kernels[batch_size:, :batch_size]  # target-source
        loss = torch.mean(XX + YY - XY - YX)
        return loss
CORAL (Correlation Alignment) Loss

class CORAL_loss(nn.Module):
    def __init__(self):
        super(CORAL_loss, self).__init__()

    def forward(self, source, target):
        d = source.size(1)
        ns, nt = source.size(0), target.size(0)

        # Covariance matrix of the source features
        tmp_s = torch.ones((1, ns), device=source.device) @ source
        cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

        # Covariance matrix of the target features
        tmp_t = torch.ones((1, nt), device=target.device) @ target
        ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

        # Squared Frobenius norm, per the CORAL definition: ||Cs - Ct||_F^2 / (4 d^2)
        loss = (cs - ct).pow(2).sum()
        loss = loss / (4 * d * d)
        return loss
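
Both losses take two feature batches of the same dimensionality; note that this MMD implementation additionally assumes equal batch sizes. A quick sketch on random features:

mmd_criterion = MMD_loss()
coral_criterion = CORAL_loss()
src_feat = torch.randn(32, 2048)  # e.g. pooled ResNet-50 features
tar_feat = torch.randn(32, 2048)
print(mmd_criterion(src_feat, tar_feat).item())
print(coral_criterion(src_feat, tar_feat).item())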

Domain Adaptation Model

We extend the base model with an explicit feature extractor so that intermediate features can be fed to the alignment losses:

class TransferModel_DA(nn.Module):
    def __init__(self, base_model='resnet50', pretrain=True, n_class=31):
        super(TransferModel_DA, self).__init__()
        # Backbone setup mirrors TransferModel: ResNet-50 with a new fc head
        self.model = torchvision.models.resnet50(pretrained=pretrain)
        n_features = self.model.fc.in_features
        self.model.fc = nn.Linear(n_features, n_class)
        self.model.fc.weight.data.normal_(0, 0.005)
        self.model.fc.bias.data.fill_(0.1)
        # Feature extractor: every layer except the final fc
        self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])

    def forward(self, x, mode='source'):
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)
        if mode == 'source':
            return self.model.fc(features)  # class logits
        else:
            return features                 # pooled features for alignment losses
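
The two forward modes can be checked with a random input (a sketch):

da_model = TransferModel_DA(n_class=31)
x = torch.randn(2, 3, 224, 224)
print(da_model(x, mode='source').shape)  # logits: torch.Size([2, 31])
print(da_model(x, mode='target').shape)  # features: torch.Size([2, 2048])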

Training Procedure

Domain adaptation training processes source and target data simultaneously:

def train_da(model, dataloaders, optimizer, lambda_=0.1):
    best_acc = 0
    for epoch in range(n_epoch):
        model.train()
        # zip() stops at the shorter loader, pairing one source batch
        # with one target batch per step
        for (src_data, src_labels), (tar_data, _) in zip(dataloaders['src'],
                                                         dataloaders['tar']):
            src_data, src_labels = src_data.cuda(), src_labels.cuda()
            tar_data = tar_data.cuda()

            # Forward pass: logits for classification, features for alignment
            src_pred = model(src_data, 'source')
            src_features = model(src_data, 'target')
            tar_features = model(tar_data, 'target')

            # Classification loss plus weighted domain-alignment loss
            cls_loss = criterion(src_pred, src_labels)
            mmd_loss = mmd_criterion(src_features, tar_features)
            total_loss = cls_loss + lambda_ * mmd_loss

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        # ...(validation logic analogous to the finetune loop)...
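
Putting it all together, a minimal sketch (hyperparameters are example values; CORAL_loss can be swapped in for MMD_loss unchanged):

criterion = nn.CrossEntropyLoss()
mmd_criterion = MMD_loss()  # or CORAL_loss()
da_model = TransferModel_DA(n_class=31).cuda()
optimizer = torch.optim.SGD(da_model.parameters(), lr=1e-4, momentum=0.9)
train_da(da_model, dataloaders, optimizer, lambda_=0.1)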

Results and Analysis

From the experiments we can observe that:

  1. Fine-tuning performs well on the source domain, but performance drops on the target domain
  2. Domain adaptation significantly improves performance on the target domain
  3. MMD and CORAL each have strengths and weaknesses; the right choice depends on the task

Summary and Outlook

This tutorial covered deep transfer learning with ResNet-50, including:

  1. The basic fine-tuning method and its implementation
  2. The domain adaptation method and its core alignment techniques
  3. Two commonly used domain-alignment loss functions

Directions for future work include:

  1. Trying more advanced domain-alignment methods
  2. Incorporating semi-supervised learning to exploit a small amount of labeled target data
  3. Exploring more efficient strategies for adapting the network architecture

With this tutorial, readers should be able to quickly grasp the core ideas and practice of deep transfer learning and apply them to their own research or engineering projects.
