在深度学习领域中,元学习(Meta-Learning)一直是一个充满挑战性的研究方向。最近我尝试复现了OpenAI提出的Reptile算法,这是一个相对简单但有效的一阶元学习方法。虽然最终的实验结果与论文原始数据存在一定差距,但这个复现过程让我对元学习有了更深入的理解。
初识Reptile算法
当我第一次接触到《On First-Order Meta-Learning Algorithms》这篇论文时,就被Reptile算法的简洁性所吸引。与需要计算二阶导数的MAML算法相比,Reptile的核心思想异常简单:在每个任务上进行几步梯度下降,然后将模型参数朝着任务优化后的参数方向移动。这种"先学习再调整"的策略,让模型能够快速适应新的任务。
元学习的魅力在于它试图解决一个根本性问题:如何让机器像人类一样快速学习新知识。人类在看到几个新字符的样本后,往往能够快速识别相似的字符,这正是少样本学习想要达到的效果。Reptile算法通过在多个相关任务上的训练,让模型学会一个良好的初始化参数,使其能够通过少量梯度步骤快速适应新任务。
代码架构设计
整个项目的代码结构相当清晰,主要由四个Python文件组成。utils.py
提供了一些基础的工具函数,包括文件列表获取和最新检查点查找等功能。models.py
实现了Reptile算法的核心模型类,其中最关键的是point_grad_to
方法,它将梯度设置为当前模型与目标模型参数的差值,这正是Reptile算法的精髓所在。
omniglot.py
负责数据集的处理,这个文件让我印象深刻的地方在于它不仅实现了数据加载,还包含了自动下载Omniglot数据集的功能。当我第一次运行代码时,系统自动从GitHub下载了background和evaluation两个数据集,并智能地合并了数据结构。这种自动化的设计大大简化了环境搭建的复杂度。
最后的train_omniglot.py
是整个训练流程的核心,包含了完整的元学习训练循环。值得注意的是,代码还贴心地提供了TensorBoard支持和断点恢复功能,这在长时间训练中非常有用。
# utils.py
import os
import re
# Those two functions are taken from torchvision code because they are not available on pip as of 0.2.0
def list_dir(root, prefix=False):
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories
def list_files(root, suffix, prefix=False):
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files
def find_latest_file(folder):
files = []
for fname in os.listdir(folder):
s = re.findall(r'\d+', fname)
if len(s) == 1:
files.append((int(s[0]), fname))
if files:
return max(files)[1]
else:
return None
# models.py
import torch
from torch import nn
class ReptileModel(nn.Module):
def __init__(self):
nn.Module.__init__(self)
def point_grad_to(self, target):
'''
Set .grad attribute of each parameter to be proportional
to the difference between self and target
'''
for p, target_p in zip(self.parameters(), target.parameters()):
if p.grad is None:
if self.is_cuda():
p.grad = torch.zeros(p.size(), device=p.device)
else:
p.grad = torch.zeros(p.size())
p.grad.data.zero_() # not sure this is required
p.grad.data.add_(p.data - target_p.data)
def is_cuda(self):
return next(self.parameters()).is_cuda
class OmniglotModel(ReptileModel):
"""
A model for Omniglot classification.
"""
def __init__(self, num_classes):
ReptileModel.__init__(self)
self.num_classes = num_classes
self.conv = nn.Sequential(
# 28 x 28 - 1
nn.Conv2d(1, 64, 3, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 14 x 14 - 64
nn.Conv2d(64, 64, 3, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 7 x 7 - 64
nn.Conv2d(64, 64, 3, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 4 x 4 - 64
nn.Conv2d(64, 64, 3, 2, 1),
nn.BatchNorm2d(64),
nn.ReLU(True),
# 2 x 2 - 64
)
self.classifier = nn.Sequential(
# 2 x 2 x 64 = 256
nn.Linear(256, num_classes),
nn.LogSoftmax(1)
)
def forward(self, x):
out = x.view(-1, 1, 28, 28)
out = self.conv(out)
out = out.view(len(out), -1)
out = self.classifier(out)
return out
def predict(self, prob):
__, argmax = prob.max(1)
return argmax
def clone(self):
clone = OmniglotModel(self.num_classes)
clone.load_state_dict(self.state_dict())
if self.is_cuda():
clone.cuda()
return clone
if __name__ == '__main__':
model = OmniglotModel(20)
x = torch.zeros(5, 28*28)
y = model(x)
print('x', x.size())
print('y', y.size())
# omniglot.py
from torch.utils import data
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import urllib.request
import zipfile
import shutil
from utils import list_files, list_dir
# 自动下载数据集的URLs
BACKGROUND_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip"
EVALUATION_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip"
def download_and_extract_omniglot(root='omniglot'):
"""
自动下载和提取Omniglot数据集
"""
if os.path.exists(root) and