- Omniglot数据集
- 整个数据集包含1623个不同的手写字符,每个手写字符都由20个不同的人参与绘制
- 在Pytorch-meta的实现当中,其训练集包含了1028个手写字符,涉及到的数据增强有3种,分别为旋转90度、旋转180度以及旋转270度,这样一来,训练样本的总个数就从1028扩充到了1028*4=4112个
- Omniglot数据集的加载
- 将4112张扩充后的图片视为一个整体,并进行洗牌,然后在洗牌之后的数据集进行样本的挑选
- 采用5-way 5-shot的训练方式,在每一个epoch中,挑选5个手写字符,每个手写字符中随机选取5个样本作为support set,另取、选取15个样本作为query set
- 训练过程
- 在训练过程中,我们先是在support set上进行训练,并计算梯度,然后再在query set上进行测试,并计算loss。值得注意的是,在query set上进行测试的过程中,我们会根据训练过程中计算的梯度来计算模型的输出(但并不会更新模型的参数)
- 模型参数的更新最终依赖的是query set计算得到的loss,而不是在support set上计算得到的loss
- 由此基本上可以得到一个结论,MAML主要通过计算每个task的损失来计算loss,而不是依赖于每个训练样本来更新模型参数