python 训练Bilinear CNN检验模型

本文讲述了如何在运行Inference.py时因内存不足而遇到错误,通过安装Anaconda环境、PyTorch及相关库,并调整pip安装选项,解决MemoryError问题。还介绍了如何使用tkinter进行图像文件选择和模型预测。

运行Inference.py

import argparse
import sys
import os
import struct
import socket
import torch
import torchvision
from BCNN_fc import BCNN_fc
import config
from PIL import Image
import tkinter as tk
from tkinter import filedialog

def main(filename):
    parser = argparse.ArgumentParser(description='network_select')
    parser.add_argument('--net_select',
                        dest='net_select',
                        default='BCNN_fc',
                        help='select which net to train/test.')
    args = parser.parse_args()
    # 配置GPU
    device = torch.device('cpu')

    #加载模型
    if args.net_select == 'BCNN_fc':
        net = BCNN_fc().to(device)
        modelpath = os.path.join(config.PATH['model'],config.PATH['model_fc'])
        net.load_state_dict(torch.load(modelpath, map_location='cpu'),
                            strict=False)

    #设置图像转换参数
    test_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=448),
        torchvision.transforms.RandomCrop(size=448),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                         std=(0.229, 0.224, 0.225))
    ])

    #通过对话框指定图像文件

    # Inference 提供/添加接口 诸如 setPictureFile 
    # imfile 由服务器端调用模型接口来进行修改

    #读入图像,并转换为torch.cuda.FloatTensor
    #imfile = os.path.join(config.PATH['cub200_test'], '001.biotite/biotite_0001.jpg')
    imageRGB = Image.open(filename).convert('RGB')
    image = test_transforms(imageRGB).unsqueeze(0)
    imageTest=image.to(device, torch.float)#转换为与显卡相关的,cpu或者CUDA

    #imageTest = imageTest.reshape([1, 3, 448, 448])
    net.eval()
    output = net(imageTest)
    _, prediction = torch.max(output.data, 1)
    # label = torch.autograd.Variable(image1)
    la = prediction[0]
    # 多项分支 (多选一)
    if la == 0:
        print("001.biotite")
    elif la == 1:
        print("002.bornite")
    elif la == 2:
        print("003.chrysocolla")
    elif la == 3:
        print("004.malachite")
    elif la == 4:
        print("005.muscovite")
    elif la == 5:
        print("006.pyrite")
    elif la == 6:
        print("007.quartz")
    else:
        print("快走吧脑弟,一会好赶不上二路汽车了7")
    print(prediction)

if __name__ == '__main__':
    main('test.jpg')

1、使用python3运行程序,发现报错,缺少pytorch环境

在这里插入图片描述

2、安装aconda环境(可以不装,这里没有用到)

(1)下载aconda

https://repo.anaconda.com/archive/index.html

image-20210428165055715

(2)上传到服务器并安装

# 安装命令,一直按回车,安装过程需要同意将安装路径加入到环境变量的配置文件中。 source ~.bashrc使其生效。
bash Anaconda2-2019.07-Linux-x86_64.sh

3.安装pytorch环境

到pytorch官网下载对应版本的pytorch 即可

https://pytorch.org/

image-20210428170627386
在这里插入图片描述

根据最后一行红色字体我们知道出现了MemoryError,根据字面意思我们可知此问题与内存有关。因为pip安装的缓存机制想要先把整个文件读取到内存以后才开始安装,因此可能导致内存不足。所以我们在安装时指示不启用缓存即可,可以使用 --no-cache-dir 命令

pip3 install --no-cache-dir torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

image-20210428170445683

4.再次运行Inference.py

image-20210428170747339

尝试直接安装tkinter,结果没有发现tkinter包

image-20210428170832942

后查询资料得知:

tkinter其实是Python调用tcl程序的标准Python程序,可以通过这个interface调用tcl的程序,因为在大多数的unix系统中都内置了很多的tcl程序和命令。

Tcl 是“工具控制语言(Tool Command Language)”的缩写,其面向对象为otcl语言。Tk 是 Tcl“图形工具箱”的扩展,它提供各种标准的 GUI 接口项,以利于迅速进行高级应用程序开发。

于是,执行terminal 命令:sudo apt install python3-tk

image-20210428171149048

再次运行Inference.py,正在下载vgg16的预训练模型,成功运行
在这里插入图片描述
image-20210428175918962

在知识蒸馏(Knowledge Distillation, KD)中使用基于卷积神经网络(CNN)的目标检测模型作为教师模型,是一种有效提升学生模型性能的方法。由于目标检测任务不仅涉及分类,还包含定位信息,因此传统的针对分类任务的KD方法无法直接迁移至目标检测场景[^1]。 ### 基于CNN的目标检测知识蒸馏的关键策略 #### 1. 标签分配蒸馏(Label Assignment Distillation, LAD) LAD 是一种适用于大多数目标检测器的知识蒸馏方法,其核心思想是让学生模型通过模仿教师模型的标签分配过程来学习更精确的分类与定位能力。不同于传统KD中仅传递最终预测结果,LAD 利用教师模型生成软标签,并将这些软标签用于学生的训练过程中,从而提升整体性能。 #### 2. 定位蒸馏(Localization Distillation, LD) 为了提高目标定位的准确性,LD 方法被引入到目标检测任务中。该方法通过从教师模型中提取边界框的位置分布,并将其蒸馏到学生模型中,以优化定位分支的输出。这种方式可以缓解因标签分配不一致导致的定位模糊问题[^1]。 #### 3. 有价值定位区域(Valuable Localization Region, VLR) VLR 的提出是为了更好地利用教师模型提供的位置信息。不同于传统的基于标签分配的蒸馏区域,VLR 关注的是对定位性能提升最有价值的区域。通过设计特定算法提取这些区域,并结合区域加权机制进行蒸馏,可以显著提升学生模型的定位精度[^1]。 #### 4. 信息差异感知策略(Information Discrepancy-Aware Strategy, IDa-Det) 当学生模型为1比特检测器时,传统KD方法因忽略教师模型与学生模型之间的信息差异而效果不佳。IDa-Det 提出了一种根据信息差异选择代表性候选框的策略,从而更有针对性地进行蒸馏,提升了1比特检测器的性能[^1]。 #### 5. 特征图蒸馏 Yang 等人提出了基于焦点特征图和全局特征图的特征蒸馏方法,通过分别从教师模型的颈部获取不同尺度的特征图,并在学生模型上计算特征蒸馏损失,使得学生能够学习到更具判别性的特征表示。 ### 实现示例:目标检测中的特征图蒸馏 以下是一个简单的特征图蒸馏实现框架,使用PyTorch实现: ```python import torch import torch.nn as nn import torch.nn.functional as F class FeatureDistillationLoss(nn.Module): def __init__(self): super(FeatureDistillationLoss, self).__init__() self.mse_loss = nn.MSELoss() def forward(self, student_features, teacher_features): # 对齐特征图尺寸 if student_features.shape != teacher_features.shape: teacher_features = F.interpolate(teacher_features, size=student_features.shape[2:], mode='bilinear', align_corners=True) loss = self.mse_loss(student_features, teacher_features.detach()) return loss # 示例使用 student_backbone = ... # 学生模型的backbone teacher_backbone = ... # 教师模型的backbone(需冻结) # 获取特征图 student_feat = student_backbone(images) teacher_feat = teacher_backbone(images) # 计算特征蒸馏损失 distill_loss = FeatureDistillationLoss()(student_feat, teacher_feat) ``` ### 结论 在基于CNN的目标检测模型中应用知识蒸馏技术,可以通过多种方式提升学生模型的性能,包括标签分配、定位信息、特征表示等方面。结合具体任务需求选择合适的蒸馏策略,并设计合理的损失函数,是实现高效蒸馏的关键。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dumbking

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值