MAML模型无关的元学习代码完整复现(Pytorch版)

本文详细介绍了MAML元学习模型的原理,并提供了Pytorch实现的代码复现过程,包括数据预处理、构建训练集和测试集、Base-Learner与Meta-Learner的构造。实验在Omniglot数据集上进行,虽然在训练过程中遇到过拟合问题,最高测试准确率为92%,低于原论文的98%。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1 引言

元学习是今年来新起的一种深度学习任务,它主要是想训练出具有强学习能力的神经网络。元学习领域一开始是一个小众的领域,之前很多年都没有很好的进展,直到Finn, C.在就读博士期间发表了一篇元学习的论文,也就是大名鼎鼎的MAML,它在回归,分类,强化学习三个任务上都达到了当时最好的性能。

我曾经在半年前发表过一篇MAML的学习笔记,博文地址点这里

MAML出现之后算是掀起来了一波研究元学习的浪潮,此后改编MAML的论文层出不穷,但都没有实质性的突破。接下来我将引用我学习笔记中的语句来简要介绍一下MAML。

MAML主要是学习出模型的初始参数,使得这个参数在新任务上经过少量的迭代更新之后就能使模型达到最好的效果。过去的方法一般是学习出一个迭代函数或者一个学习规则。MAML没有新增参数,也没有对模型提出任何约束。MAML可以看作是最大化损失函数在新任务上的灵敏度,从而当参数只有很小的改编时,损失函数也能大幅减小。

由于元学习模型天然的快速训练出好的模型,所以其主要用于小样本学习之中。元学习的论文中也大多将小样本学习任务作为论文实验。

2 数据集

本文的复现用到的数据集小样本领域的通用数据集Omniglot,数据集的地址可以在我的github中找到omniglot_standard.zip
以下引用@心之宙对omniglot的介绍:

Omniglot 一般会被戏称为 MNIST 的转置,大家可以想想为什么?Omniglot 数据集包含来自 50个不同国家的字母表的 1623 个不同手写字符。每一个字符都是由 20个不同的人通过亚马逊的 Mechanical Turk 在线绘制的。
Omniglot 数据集总共包含 50个不同国家的字母表。我们通常将这些分成一组包含 30个字母表的背景(background)集和一组包含 20 个字母表的评估(evaluation)集。
更具挑战性的表示学习任务是使用较小的背景集 “background small 1” 和 “background small 2”。每一个都只包含 5个字母, 更类似于一个成年人在学习一般的字符时可能遇到的经验。

本文的复现主要基于omniglot的标准集。

3 代码分段详解

3.1 数据预处理

首先对zip文件进行解压,解压后可以在python子文件夹中获得如下数据集:

import torch
import numpy as np
import os
import zipfile

root_path = './../datasets'
processed_folder =  os.path.join(root_path)

zip_ref = zipfile.ZipFile(os.path.join(root_path,'omniglot_standard.zip'), 'r')
zip_ref.extractall(root_path)
zip_ref.close()
然后对图片进行预处理
# 数据预处理
root_dir = './../datasets/omniglot/python'

import torchvision.transforms as transforms
from PIL import Image

'''
an example of img_items:
( '0709_17.png',
  'Alphabet_of_the_Magi/character01',
  './../datasets/omniglot/python/images_background/Alphabet_of_the_Magi/character01')
'''
def find_classes(root_dir):
    img_items = []
    for (root, dirs, files) in os.walk(root_dir): 
        for file in files:
            if (file.endswith("png")):
                r = root.split('/')
                img_items.append((file, r[-2] + "/" + r[-1], root))
    print("== Found %d items " % len(img_items))
    return img_items

## 构建一个词典{class:idx}
def index_classes(items):
    class_idx = {
   
   }
    count = 0
    for item in items:
        if item[1] not in class_idx:
            class_idx[item[1]] = count
            count += 1
    print('== Found {} classes'.format(len(class_idx)))
    return class_idx
        

img_items =  find_classes(root_dir)
class_idx = index_classes(img_items)


temp = dict()
for imgname, classes, dirs in img_items:
    img = '{}/{}'.format(dirs, imgname)
    label = class_idx[classes]
    transform = transforms.Compose([lambda img: Image.open(img).convert('L'),
                              lambda img: img.resize((28,28)),
                              lambda img: np.reshape(img, (28,28,1)),
                              lambda img
评论 40
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值