参考了解读: Few-shot learning with Graph Neural Networks - Yuan Z的文章 - 知乎
代码使用 o m n i g l o t omniglot omniglot数据集,以 5 w a y − 1 s h o t 5way-1shot 5way−1shot为例,一个 e p i s o d e episode episode只有一张 q u e r y query query,一个 b a t c h batch batch中有 300 300 300个 e p i s o d e episode episode
在
main.py中的第144行开始训练的迭代
数据
首先加载数据
# main.py line149
data = train_loader.get_task_batch(batch_size=args.batch_size, n_way=args.train_N_way,
unlabeled_extra=args.unlabeled_extra, num_shots=args.train_N_shots,
cuda=args.cuda, variable=True)
[batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi, hidden_labels] = data
在一般的情况下(不是半监督学习),我们只需要用到
batch_x查询的图片(300,1,28,28)label_x查询图片的 o n e − h o t one-hot one−hot标签(300,5)batches_xi支持集[(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28),(300,1,28,28)]labels_yi支持集 o n e − h o t one-hot o

本文深入探讨了使用Graph Neural Networks (GNN)进行Few-shot Learning的方法。通过5-way 1-shot任务为例,展示了在一个episode中如何处理查询图像和构建支持集。在训练过程中,首先利用编码器提取特征,接着GNN模块计算节点间相似度,最后通过softmax层进行分类。GNN内部包含了Wcompute和Gconv两个关键步骤,分别用于计算权重和生成新特征。整个流程涉及特征提取、相似度计算和损失计算,揭示了GNN在解决少样本学习问题中的核心机制。
最低0.47元/天 解锁文章
817





