I have spent quite a while reading the MAML source code and learning TensorFlow along the way, so I wanted to write an article summarizing the MAML source code and demystifying the algorithm.
Paper

Abstract
We propose an algorithm for meta-learning that is model-agnostic, in the sense that it is compatible with any model trained with gradient descent and applicable to a variety of different learning problems, including classification, regression, and reinforcement learning. The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples. In our approach, the parameters of the model are explicitly trained such that a small number of gradient steps with a small amount of training data from a new task will produce good generalization performance on that task. In effect, our method trains the model to be easy to fine-tune. We demonstrate that this approach leads to state-of-the-art performance on two fewshot image classification benchmarks, produces good results on few-shot regression, and accelerates fine-tuning for policy gradient reinforcement learning with neural network policies.
Summary: MAML finds a good set of initial parameters, rather than starting from scratch, which greatly reduces the training time and number of samples needed on a new task.
Algorithm


As the figure shows, we have initial parameters θ and three tasks, each with its own optimal parameters θ*. From θ there are three possible gradient-descent directions, one per task; but instead of descending toward any single task, we take a step in a direction shared by all three. The new θ obtained this way needs only a few gradient steps to reach any task's θ*.
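The meta-update described above can be sketched in plain NumPy on toy 1-D regression tasks. This is an illustrative first-order sketch (second derivatives are ignored, as in first-order MAML), not the repo's TensorFlow implementation; all names such as `inner_lr`, `meta_lr`, and `maml_step` are made up for this example:

```python
import numpy as np

def loss_and_grad(theta, x, y):
    """MSE loss of the linear model y_hat = theta * x, and its gradient w.r.t. theta."""
    err = theta * x - y
    return np.mean(err ** 2), np.mean(2 * err * x)

def maml_step(theta, tasks, inner_lr=0.01, meta_lr=0.001, k=5):
    """One meta-update: adapt theta per task with k inner gradient steps on
    the support set, then move theta along the sum of the tasks'
    post-adaptation query-set gradients (first-order approximation)."""
    meta_grad = 0.0
    for (sx, sy), (qx, qy) in tasks:        # each task = (support, query)
        phi = theta
        for _ in range(k):                  # K inner gradient steps toward this task's theta*
            _, g = loss_and_grad(phi, sx, sy)
            phi = phi - inner_lr * g
        _, gq = loss_and_grad(phi, qx, qy)  # evaluate the adapted params on the query set
        meta_grad += gq
    return theta - meta_lr * meta_grad

# Three tasks y = a * x with slopes 1.0, 2.0, 3.0.
rng = np.random.default_rng(0)
tasks = []
for a in (1.0, 2.0, 3.0):
    x = rng.normal(size=20)
    tasks.append(((x[:5], a * x[:5]), (x[5:], a * x[5:])))

theta = 0.0
for _ in range(200):
    theta = maml_step(theta, tasks)
# theta drifts toward an initialization from which each task's slope
# is reachable in a few inner steps, instead of fitting any one task.
```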
Source Code
The code I studied is not the code released with the paper, but "dragen1860"'s simplified reimplementation based on the official code (link). You can look at its highlights:
- adopted from cbfin’s official implementation with equivalent performance on mini-imagenet
- clean, tiny code style and very easy to follow, with comments on almost every line
- faster, with trivial improvements, e.g. 0.335s per epoch compared with 0.563s per epoch, saving up to 3.8 hours over the total 60,000-iteration training process
File Structure

The meta-learning algorithm flow
- Enter from the main function
- Set the hyperparameters
- nway: 5, the number of classes per task, e.g. cats, dogs, horses, …
- kshot: 1, the number of support samples per class
- kquery: 15, the number of query samples per class
- meta_batchsz: 4, the batch size of meta-learning, i.e. the number of tasks per meta-update
- K: 5, the number of inner gradient-descent steps MAML takes to find each task's best θ*; it is not fixed at a single step
- Generate the data (as tensors)
First, a quick review of the support set and query set. Each task is a conventional machine-learning problem with its own training set and test set, but those names are easy to confuse with the meta-level splits, so we call them the support set and query set instead (how large should these two sets be?). Then 4 tasks are used for meta-training, collectively called the train set, and 4 tasks are used for meta-testing, called the test set.

The following builds the support and query sets of the 4 tasks for the meta-training phase:
# image_tensor: [4, 80, 84*84*3]
support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x')
query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x')
support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y')
query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y')
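With nway=5, kshot=1, kquery=15 and meta_batchsz=4, the tensor packs, for each of the 4 tasks, the 5 support images first and the 75 query images after them along axis 1 (5 × (1 + 15) = 80, matching the `[4, 80, 84*84*3]` comment). A NumPy sketch of the same split, with random-free placeholder data in place of real images (shapes are the point here, not values):

```python
import numpy as np

nway, kshot, kquery, meta_batchsz = 5, 1, 15, 4
img_dim = 84 * 84 * 3  # flattened 84x84 RGB image

# Same layout as image_tensor: support images first, then query images, on axis 1.
image_tensor = np.zeros((meta_batchsz, nway * (kshot + kquery), img_dim))

# NumPy equivalent of the two tf.slice calls above.
support_x = image_tensor[:, : nway * kshot, :]
query_x = image_tensor[:, nway * kshot :, :]

print(support_x.shape)  # (4, 5, 21168)
print(query_x.shape)    # (4, 75, 21168)
```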

