模型无关元学习/Model-Agnostic Meta-Learning(MAML)源码解读

本文深入解析MAML(Model-Agnostic Meta-Learning)算法,揭示其如何通过元学习找到良好的初始化参数,显著减少新任务的学习时间和所需样本数量。文章详细介绍了MAML的工作原理、源代码实现及实验结果。

最近看了好久的MAML源代码,也学了TF,想写篇文章将MAML源码总结出来,揭开MAML的神秘面纱。

论文

在这里插入图片描述

摘要

We propose an algorithm for meta-learning that is model-agnostic, in the sense that it is compatible with any model trained with gradient descent and applicable to a variety of different learning problems, including classification, regression, and reinforcement learning. The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples. In our approach, the parameters of the model are explicitly trained such that a small number of gradient steps with a small amount of training data from a new task will produce good generalization performance on that task. In effect, our method trains the model to be easy to fine-tune. We demonstrate that this approach leads to state-of-the-art performance on two fewshot image classification benchmarks, produces good results on few-shot regression, and accelerates fine-tuning for policy gradient reinforcement learning with neural network policies.

总结:MAML找到一个好的初始参数,而不是0,极大减少训练时间和样本数。

算法

在这里插入图片描述
在这里插入图片描述
如图,我们有初始参数θ,有3个task,每个task有最好的参数θ*。此时θ可以有3个梯度下降的方向,但是我们没有选择梯度下降,而是往这3个点共有的方向前进了一步。这样新得到的θ只需要几步就能达到其他任务的 θ*

具体地:算法第6步计算出每个task的θ*(也不是θ*,因为只走了几步没走完),假设θ走到了这一步,变成了θi*,再求这一点的梯度,加起来平均,得到一个中和的方向,往这个方向走,即第8步更新原始的θ,就得到了一个meta-learner

源码

我找的源码并不是这篇论文给的源码,而是“dragen1860”根据官方源码实现的简易版源码,链接,可以看人家的highlight:

  • adopted from cbfin’s official implementation with equivalent performance on mini-imagenet
  • clean, tiny code style and very easy-to-follow from comments almost every lines
  • faster and trivial improvements, eg. 0.335s per epoch comparing with 0.563s per epoch, saving up to 3.8 hours for total 60,000 training process

文件结构

在这里插入图片描述

meta-learn的算法流程

  1. 从main函数入口进去
  2. 设定参数
  • nway:5,分类数,如猫的、狗的、马的……
  • kshot:1,样本数
  • kquery:15,?
  • meta_batchsz:4,元学习中的batch数,即task的数目
  • K:5,为了找到每个task最好的θ*,MAML可以进行K次梯度下降,并不是固定一次
  1. 生成数据(张量)

这里先回顾下 support set和query set:每个task就是传统机器学习的任务,包含了训练集和测试集,但是容易搞混,我们不这么叫了,叫support set和query set(那这两个set 设置多大比较好?),然后4个task作为meta-train用的,叫做train set,4个task作为meta-test用的,叫做test set

在这里插入图片描述

以下就是4个任务在meta-train阶段的support set和query set

# image_tensor: [4, 80, 84*84*3]
support_x = tf.slice(image_tensor, [0, 0, 0], [-1,  nway *  kshot, -1], name='support_x')
query_x = tf.slice(image_tensor, [0,  nway *  kshot, 0], [-1, -1, -1], name='query_x')
support_y = tf.slice(label_tensor, [0, 0, 0], [-1,  nway *  kshot, -1], name='support_y')
query_y = tf.slice(label_tensor, [0,  nway *  kshot, 0], [-1,<
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值