A Hands-On Tutorial on Deep Transfer Learning: Fine-Tuning and Domain Adaptation with ResNet-50
Introduction
Deep transfer learning is an important research direction in machine learning: it transfers knowledge learned in one domain (the source domain) to a related but different domain (the target domain). Building on the classic ResNet-50 network, this article walks through the two most common transfer learning approaches with working code: fine-tuning (finetune) and domain adaptation.
Environment Setup
First, make sure the required Python libraries are installed:
!pip install torch torchvision
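All code examples below assume the following imports (standard PyTorch and torchvision modules, plus os for path handling):

import os

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms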
Dataset Preparation
We use the classic Office-31 dataset, which contains images from three domains (Amazon, Webcam, and DSLR), each covering the same 31 object categories.
The dataset is organized as follows:
office31/
├── amazon
├── dslr
└── webcam
Data Loading and Preprocessing
We define a data loading function that applies different preprocessing to the source and target domains:
def load_data(root_path, domain, batch_size, phase):
    transform_dict = {
        # Source domain: random augmentation for training
        'src': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        # Target domain: deterministic resize for evaluation.
        # Resize([224, 224]) forces a fixed square size; Resize(224) alone
        # keeps the aspect ratio and can yield non-square tensors that
        # fail to batch
        'tar': transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
    }
    data = datasets.ImageFolder(root=os.path.join(root_path, domain),
                                transform=transform_dict[phase])
    # Shuffle only the source loader; drop the last incomplete batch for
    # the target loader so source/target batch sizes stay equal
    data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                              shuffle=(phase == 'src'),
                                              drop_last=(phase == 'tar'),
                                              num_workers=4)
    return data_loader
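As a usage sketch (the path, domain names, and batch size below are illustrative), the loaders can be assembled as follows. Since transform_dict only defines 'src' and 'tar' entries, the target loader also serves as the validation loader expected by the training loop later:

# Illustrative settings; adjust root_path to where Office-31 is stored
root_path = './office31'
batch_size = 32
src_loader = load_data(root_path, 'amazon', batch_size, phase='src')
tar_loader = load_data(root_path, 'webcam', batch_size, phase='tar')
# The fine-tuning loop expects 'src', 'val', and 'tar' keys; without a
# dedicated validation split, the target loader doubles as 'val'
dataloaders = {'src': src_loader, 'val': tar_loader, 'tar': tar_loader}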
The Fine-Tuning (Finetune) Approach
Model Definition
We build the transfer learning model on top of a pretrained ResNet-50, replacing the final fully connected layer to match the new classification task:
class TransferModel(nn.Module):
    def __init__(self, base_model='resnet50', pretrain=True, n_class=31):
        super(TransferModel, self).__init__()
        self.base_model = base_model
        self.pretrain = pretrain
        self.n_class = n_class
        if self.base_model == 'resnet50':
            # Honor the pretrain flag instead of hard-coding pretrained=True
            self.model = torchvision.models.resnet50(pretrained=self.pretrain)
            n_features = self.model.fc.in_features
            # Replace the 1000-way ImageNet head with an n_class-way classifier
            fc = torch.nn.Linear(n_features, n_class)
            self.model.fc = fc
        # Re-initialize the new head with small random weights
        self.model.fc.weight.data.normal_(0, 0.005)
        self.model.fc.bias.data.fill_(0.1)

    def forward(self, x):
        return self.model(x)

    def predict(self, x):
        return self.forward(x)
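A quick sanity check of the model (illustrative): a batch of 224×224 images should map to 31-way logits.

# Verify that the replaced head produces n_class-way outputs
model = TransferModel().cuda()
dummy = torch.randn(4, 3, 224, 224).cuda()  # 4 fake RGB images
print(model(dummy).shape)  # expected: torch.Size([4, 31])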
Training Procedure
The core fine-tuning training loop is as follows:
def finetune(model, dataloaders, optimizer):
    # n_epoch, early_stop, and criterion are assumed to be defined globally
    best_acc = 0
    stop = 0
    for epoch in range(n_epoch):
        stop += 1
        # Train on the source domain, then evaluate on val and target
        for phase in ['src', 'val', 'tar']:
            if phase == 'src':
                model.train()
            else:
                model.eval()
            total_loss, correct = 0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.cuda(), labels.cuda()
                optimizer.zero_grad()
                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'src'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = torch.max(outputs, 1)[1]
                if phase == 'src':
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * inputs.size(0)
                correct += torch.sum(preds == labels.data)
            epoch_loss = total_loss / len(dataloaders[phase].dataset)
            epoch_acc = correct.double() / len(dataloaders[phase].dataset)
            # Keep the best checkpoint by validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                stop = 0
                best_acc = epoch_acc
                torch.save(model.state_dict(), 'model.pkl')
        # Early stopping: quit if val accuracy has not improved recently
        if stop >= early_stop:
            break
Optimizer Configuration
We use an SGD optimizer and assign a 10× higher learning rate to the newly added fully connected layer:
param_group = []
learning_rate = 0.0001
momentum = 5e-4  # note: unusually small for momentum (0.9 is typical; 5e-4 is more common as weight decay)
for k, v in model.named_parameters():
    if 'fc' not in k:
        # Pretrained backbone layers: small learning rate
        param_group += [{'params': v, 'lr': learning_rate}]
    else:
        # Newly initialized head: 10x learning rate
        param_group += [{'params': v, 'lr': learning_rate * 10}]
optimizer = torch.optim.SGD(param_group, momentum=momentum)
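Putting the pieces together, a minimal driver might look like this. The epoch count, early-stopping patience, and cross-entropy criterion are our assumptions; the original values are not shown in this section:

# Assumed hyperparameters (illustrative values)
n_epoch = 100
early_stop = 20
criterion = nn.CrossEntropyLoss()

finetune(model, dataloaders, optimizer)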
Domain Adaptation
Domain adaptation extends fine-tuning with a domain-alignment loss that reduces the distribution discrepancy between the source and target domains.
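Concretely, the training objective adds a weighted alignment term to the usual source-domain classification loss, where λ (the lambda_ argument in the training code below) trades off the two terms:

total_loss = cls_loss + λ · align_loss(source_features, target_features)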
Domain Alignment Loss Functions
We implement two widely used alignment losses:
MMD (Maximum Mean Discrepancy) Loss
class MMD_loss(nn.Module):
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def gaussian_kernel(self, source, target):
        # Pairwise squared L2 distances between all source/target samples
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(int(total.size(0)),
                                           int(total.size(0)),
                                           int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)),
                                           int(total.size(0)),
                                           int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        # Bandwidth from the mean pairwise distance, then a pyramid of
        # kernel_num scales around it
        if self.fix_sigma:
            bandwidth = self.fix_sigma
        else:
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        bandwidth /= self.kernel_mul ** (self.kernel_num // 2)
        bandwidth_list = [bandwidth * (self.kernel_mul ** i)
                          for i in range(self.kernel_num)]
        # Sum of RBF kernels at multiple bandwidths (multi-kernel MMD)
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def forward(self, source, target):
        # Assumes source and target batches have the same size
        batch_size = int(source.size()[0])
        kernels = self.gaussian_kernel(source, target)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = torch.mean(XX + YY - XY - YX)
        return loss
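A small sanity check (illustrative): MMD between two batches drawn from the same distribution should be small, and noticeably larger when one batch is shifted.

mmd = MMD_loss()
a = torch.randn(32, 256)        # "source" features
b = torch.randn(32, 256)        # "target" features, same distribution
c = torch.randn(32, 256) + 2.0  # shifted "target" features
print(mmd(a, b).item())  # small positive value
print(mmd(a, c).item())  # noticeably larger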
CORAL (Correlation Alignment) Loss
class CORAL_loss(nn.Module):
    def __init__(self):
        super(CORAL_loss, self).__init__()

    def forward(self, source, target):
        d = source.size(1)
        ns, nt = source.size(0), target.size(0)
        # Covariance matrix of the source features
        tmp_s = torch.ones((1, ns)).cuda() @ source
        cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)
        # Covariance matrix of the target features
        tmp_t = torch.ones((1, nt)).cuda() @ target
        ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)
        # Frobenius norm of the covariance difference, scaled by 1/(4d^2)
        loss = (cs - ct).pow(2).sum().sqrt()
        loss = loss / (4 * d * d)
        return loss
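CORAL is used the same way as MMD (the sketch below assumes a GPU, since the implementation allocates its helper tensors with .cuda()):

coral = CORAL_loss()
src_feat = torch.randn(32, 256).cuda()
tar_feat = torch.randn(48, 256).cuda()  # batch sizes may differ, unlike MMD
print(coral(src_feat, tar_feat).item())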
Domain Adaptation Model
We extend the base model with a feature extractor so that intermediate features can be returned for domain alignment:
class TransferModel_DA(nn.Module):
    def __init__(self, base_model='resnet50', pretrain=True, n_class=31):
        super(TransferModel_DA, self).__init__()
        # Backbone and classifier head, initialized as in TransferModel
        self.model = torchvision.models.resnet50(pretrained=pretrain)
        n_features = self.model.fc.in_features
        self.model.fc = torch.nn.Linear(n_features, n_class)
        # Feature extractor: all ResNet-50 layers except the final fc
        self.feature_extractor = nn.Sequential(*list(self.model.children())[:-1])

    def forward(self, x, mode='source'):
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)
        if mode == 'source':
            # Classification logits for the labeled source data
            return self.model.fc(features)
        else:
            # Raw features for domain alignment
            return features
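A quick check of the two forward modes (illustrative): 'source' returns classification logits, while any other mode returns the 2048-dimensional ResNet-50 features.

da_model = TransferModel_DA().cuda()
x = torch.randn(4, 3, 224, 224).cuda()
print(da_model(x, 'source').shape)  # torch.Size([4, 31])
print(da_model(x, 'target').shape)  # torch.Size([4, 2048])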
Training Procedure
Domain adaptation training processes source and target data simultaneously:
def train_da(model, dataloaders, optimizer, lambda_=0.1):
    # n_epoch, criterion, and mmd_criterion are assumed to be defined globally
    best_acc = 0
    for epoch in range(n_epoch):
        model.train()
        # Iterate over source and target batches in lockstep
        for (src_data, src_labels), (tar_data, _) in zip(dataloaders['src'],
                                                         dataloaders['tar']):
            src_data, src_labels = src_data.cuda(), src_labels.cuda()
            tar_data = tar_data.cuda()
            # Forward pass
            src_pred = model(src_data, 'source')
            src_features = model(src_data, 'target')
            tar_features = model(tar_data, 'target')
            # Classification loss plus weighted domain-alignment loss
            cls_loss = criterion(src_pred, src_labels)
            mmd_loss = mmd_criterion(src_features, tar_features)
            total_loss = cls_loss + lambda_ * mmd_loss
            # Backward pass
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        # Validation
        model.eval()
        # ...(validation logic similar to fine-tuning; see the sketch below)...
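The elided validation step can mirror the evaluation phases in finetune. As a minimal sketch (our own addition, not part of the original code), an accuracy helper for the DA model might look like:

def evaluate(model, data_loader):
    # Accuracy on a labeled loader, using the classification ('source') path
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            preds = torch.max(model(inputs, 'source'), 1)[1]
            correct += torch.sum(preds == labels).item()
            total += labels.size(0)
    return correct / total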
Experimental Results
From the experiments we can observe the following:
- Fine-tuning performs well on the source domain, but its accuracy drops on the target domain
- Domain adaptation significantly improves performance on the target domain
- MMD and CORAL each have strengths and weaknesses, and the choice should depend on the task at hand
Summary and Outlook
This article presented deep transfer learning methods based on ResNet-50, covering:
- The basic fine-tuning approach and its implementation
- Domain adaptation and its core alignment techniques
- Two widely used domain-alignment loss functions
Directions for future work include:
- Trying more advanced domain-alignment methods
- Incorporating semi-supervised learning to exploit a small amount of labeled target-domain data
- Exploring more efficient strategies for adapting the network architecture
With this tutorial, readers should be able to quickly grasp the core ideas and practice of deep transfer learning and apply them to their own research or engineering projects.