基于伪标签的半监督学习——Pytorch框架识别MNIST数据集

部署运行你感兴趣的模型镜像

概述

在训练模型的时候,同时使用有标签数据和无标签数据进行训练,利用伪标记的方法给无标签数据赋予伪标签,再将无标签数据当作有标签数据进行训练,即利用无标签数据进行半监督学习。

伪标记

利用模型现有的预测能力,将无标签样本的预测值作为伪标签。例如将MNIST数据集输入到模型中,得到相应的0~9类别得分,将得分最高的类别作为伪标签,该伪标签当作一般标签使用,和原样本计算损失,迭代模型参数。

整体思路

模型需要先进行预训练,即先用少量的有标签数据训练模型,使得模型获得一定的准确率,之后再输入无标签数据和有标签一起训练模型,其中无标签数据的损失权重逐渐增加。整体两个阶段如下:

  1. 有标签数据预训练
  2. 加入无标签数据一起训练

关键代码

在无标签数据训练时,每到Nunlabel_batch_size就开始一个有标签数据“中途插入”批量训练,这个N是可以自己调控的超参数,N越大,无标签数据中有标签数据穿插训练的频率越快。

    for epoch in range(EPOCHS):
        for batch_idx, unlabeled_batch in enumerate(unlabeled_loader):
            # Forward Pass to get the pseudo labels
            x_unlabeled, y_unlabeled = unlabeled_batch[0],unlabeled_batch[1]
  
            output_unlabeled = model(x_unlabeled)
            _, pseudo_labeled = torch.max(output_unlabeled, 1)

            # Now calculate the unlabeled loss using the pseudo label
            output = model(x_unlabeled)
            # alpha = alpha_weight(step, T1, T2, af)
            unlabeled_loss = alpha * F.nll_loss(output, pseudo_labeled)

            # Backpropogate
            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()

            # For every 50 batches train one epoch on labeled data
            if batch_idx % 30 == 0:

                # Normal training procedure
                for batch_idx, label_batch in enumerate(label_loader):
                    X_batch, y_batch = label_batch[0], label_batch[1] 
                    output = model(X_batch)
                    predicted = torch.max(output, 1)[1]
                    labeled_loss = F.nll_loss(output, y_batch)
                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()
                # Now we increment step by 1
                step += 1

重点是理解伪标记的方法以及如何使用伪标签,此处伪标记的方法就是简单的利用模型预测出最有可能的类别作为伪标签,这是一种非常暴力简单的赋予伪标签的方法,当然,实际上这样的半监督学习方法就能够很有效地利用大量无标签的数据。
感兴趣的朋友可以看看这篇文献,整体思路与这篇文献一致。欢迎朋友们交流探讨。
Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks

您可能感兴趣的与本文相关的镜像

PyTorch 2.9

PyTorch 2.9

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 2
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值