Reptile元学习算法复现实战:在Omniglot数据集上的少样本学习探索

在深度学习领域中,元学习(Meta-Learning)一直是一个充满挑战性的研究方向。最近我尝试复现了OpenAI提出的Reptile算法,这是一个相对简单但有效的一阶元学习方法。虽然最终的实验结果与论文原始数据存在一定差距,但这个复现过程让我对元学习有了更深入的理解。

初识Reptile算法

当我第一次接触到《On First-Order Meta-Learning Algorithms》这篇论文时,就被Reptile算法的简洁性所吸引。与需要计算二阶导数的MAML算法相比,Reptile的核心思想异常简单:在每个任务上进行几步梯度下降,然后将模型参数朝着任务优化后的参数方向移动。这种"先学习再调整"的策略,让模型能够快速适应新的任务。

元学习的魅力在于它试图解决一个根本性问题:如何让机器像人类一样快速学习新知识。人类在看到几个新字符的样本后,往往能够快速识别相似的字符,这正是少样本学习想要达到的效果。Reptile算法通过在多个相关任务上的训练,让模型学会一个良好的初始化参数,使其能够通过少量梯度步骤快速适应新任务。

代码架构设计

整个项目的代码结构相当清晰,主要由四个Python文件组成。utils.py提供了一些基础的工具函数,包括文件列表获取和最新检查点查找等功能。models.py实现了Reptile算法的核心模型类,其中最关键的是point_grad_to方法,它将梯度设置为当前模型与目标模型参数的差值,这正是Reptile算法的精髓所在。

omniglot.py负责数据集的处理,这个文件让我印象深刻的地方在于它不仅实现了数据加载,还包含了自动下载Omniglot数据集的功能。当我第一次运行代码时,系统自动从GitHub下载了background和evaluation两个数据集,并智能地合并了数据结构。这种自动化的设计大大简化了环境搭建的复杂度。

最后的train_omniglot.py是整个训练流程的核心,包含了完整的元学习训练循环。值得注意的是,代码还贴心地提供了TensorBoard支持和断点恢复功能,这在长时间训练中非常有用。

# utils.py
import os
import re


# Those two functions are taken from torchvision code because they are not available on pip as of 0.2.0
def list_dir(root, prefix=False):
    """List all directories at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the directories found
    """
    root = os.path.expanduser(root)
    directories = list(
        filter(
            lambda p: os.path.isdir(os.path.join(root, p)),
            os.listdir(root)
        )
    )

    if prefix is True:
        directories = [os.path.join(root, d) for d in directories]

    return directories


def list_files(root, suffix, prefix=False):
    """List all files ending with a suffix at a given root
    Args:
        root (str): Path to directory whose folders need to be listed
        suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
            It uses the Python "str.endswith" method and is passed directly
        prefix (bool, optional): If true, prepends the path to each result, otherwise
            only returns the name of the files found
    """
    root = os.path.expanduser(root)
    files = list(
        filter(
            lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
            os.listdir(root)
        )
    )

    if prefix is True:
        files = [os.path.join(root, d) for d in files]

    return files


def find_latest_file(folder):
    files = []
    for fname in os.listdir(folder):
        s = re.findall(r'\d+', fname)
        if len(s) == 1:
            files.append((int(s[0]), fname))
    if files:
        return max(files)[1]
    else:
        return None
# models.py
import torch
from torch import nn


class ReptileModel(nn.Module):

    def __init__(self):
        nn.Module.__init__(self)

    def point_grad_to(self, target):
        '''
        Set .grad attribute of each parameter to be proportional
        to the difference between self and target
        '''
        for p, target_p in zip(self.parameters(), target.parameters()):
            if p.grad is None:
                if self.is_cuda():
                    p.grad = torch.zeros(p.size(), device=p.device)
                else:
                    p.grad = torch.zeros(p.size())
            p.grad.data.zero_()  # not sure this is required
            p.grad.data.add_(p.data - target_p.data)

    def is_cuda(self):
        return next(self.parameters()).is_cuda


class OmniglotModel(ReptileModel):
    """
    A model for Omniglot classification.
    """
    def __init__(self, num_classes):
        ReptileModel.__init__(self)

        self.num_classes = num_classes

        self.conv = nn.Sequential(
            # 28 x 28 - 1
            nn.Conv2d(1, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 14 x 14 - 64
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 7 x 7 - 64
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 4 x 4 - 64
            nn.Conv2d(64, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            # 2 x 2 - 64
        )

        self.classifier = nn.Sequential(
            # 2 x 2 x 64 = 256
            nn.Linear(256, num_classes),
            nn.LogSoftmax(1)
        )

    def forward(self, x):
        out = x.view(-1, 1, 28, 28)
        out = self.conv(out)
        out = out.view(len(out), -1)
        out = self.classifier(out)
        return out

    def predict(self, prob):
        __, argmax = prob.max(1)
        return argmax

    def clone(self):
        clone = OmniglotModel(self.num_classes)
        clone.load_state_dict(self.state_dict())
        if self.is_cuda():
            clone.cuda()
        return clone


if __name__ == '__main__':
    model = OmniglotModel(20)
    x = torch.zeros(5, 28*28)
    y = model(x)
    print('x', x.size())
    print('y', y.size())
# omniglot.py
from torch.utils import data
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import urllib.request
import zipfile
import shutil

from utils import list_files, list_dir

# 自动下载数据集的URLs
BACKGROUND_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip"
EVALUATION_URL = "https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip"


def download_and_extract_omniglot(root='omniglot'):
    """
    自动下载和提取Omniglot数据集
    """
    if os.path.exists(root) and 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

智算菩萨

欢迎阅读最新融合AI编程内容

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值