TF版FasterRCNN:train_val.py代码解读笔记

本文详细解读了使用TensorFlow实现的Faster R-CNN模型训练与验证过程的代码,涵盖了关键步骤和核心算法的理解。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

个人代码阅读笔记。

# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Xinlei Chen and Zheqi He
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from model.config import cfg
import roi_data_layer.roidb as rdl_roidb
from roi_data_layer.layer import RoIDataLayer
from utils.timer import Timer
try:
  import cPickle as pickle
except ImportError:
  import pickle
import numpy as np
import os
import sys
import glob
import time

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
'''
这是训练的第二层入口,第一层是train_faster_rcnn.sh
'''
class SolverWrapper(object):
  """
    A wrapper class for the training process
	训练过程的封装
  """
  #实例初始化
  def __init__(self, sess, network, imdb, roidb, valroidb, output_dir, tbdir, pretrained_model=None):
    self.net = network
    self.imdb = imdb#image database
    self.roidb = roidb#roi database
    self.valroidb = valroidb#valid roi database
    self.output_dir = output_dir#输出路径
    self.tbdir = tbdir
    # Simply put '_val' at the end to save the summaries from the validation set
    self.tbvaldir = tbdir + '_val'#?
    if not os.path.exists(self.tbvaldir):
      os.makedirs(self.tbvaldir)
    self.pretrained_model = pretrained_model#pre训练模型

  def snapshot(self, sess, iter):#保存参数snapshot,多久保存一次tf的会话
    net = self.net

    if not os.path.exists(self.output_dir):
      os.makedirs(self.output_dir)

    # Store the model snapshot
	#TRAIN.SNAPSHOT_PREFIX为名字前缀‘res101_faster_rcnn'’
    filename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.ckpt'
    filename = os.path.join(self.output_dir, filename)
    self.saver.save(sess, filename)
    print('Wrote snapshot to: {:s}'.format(filename))

    # Also store some meta information, random state, etc.
	#可以看出来ckpt才是主要的权重,pkl是保存了一些原始信息和随机状态
    nfilename = cfg.TRAIN.SNAPSHOT_PREFIX + '_iter_{:d}'.format(iter) + '.pkl'
    nfilename = os.path.join(self.output_dir, nfilename)
    # current state of numpy random
	#当前的numpy.random状态(随机种子什么的
    st0 = np.random.get_state()
    # current position in the database
	#当前载入数据库
    cur = self.data_layer._cur
    # current shuffled indexes of the database
	#当前打乱的数据索引
    perm = self.data_layer._perm
    # current position in the validation database
	#当前数据在validation database中的位置
    cur_val = self.data_layer_val._cur
    # current shuffled indexes of the validation database
	#当前打乱的验证数据集索引
    perm_val = self.data_layer_val._perm

    # Dump the meta info
    with open(nfilename, 'wb') as fid:#打开文件
      pickle.dump(st0, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(cur_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(perm_val, fid, pickle.HIGHEST_PROTOCOL)
      pickle.dump(iter, fid, pickle.HIGHEST_PROTOCOL)

    return filename, nfilename

  def from_snapshot(self, sess, sfile, nfile):#载入快照
    p
<think>我们面对的错误信息是:`dgl._ffi.base.DGLError:[10:36:44]...Checkfailed:allow_missing:DeviceAPIgpuisnotenabled.Pleaseinstallthecudaversionofdgl.`根据引用[1][2][3]的分析,这个错误的原因是当前安装的DGL库不支持CUDA(即不是CUDA本)。因此,我们需要安装与当前环境CUDA本匹配的DGLCUDA本。###解决步骤####1.**确认CUDA本**在安装DGL的CUDA本之前,需要确认系统CUDA本和PyTorch使用的CUDA本(两者应一致)。**检查系统CUDA本**:```bashnvcc--version```或```bashcat/usr/local/cuda/version.txt```**检查PyTorch的CUDA本**:```pythonimporttorchprint(torch.version.cuda)#输出PyTorch编译时使用的CUDA本```####2.**安装匹配的DGLCUDA本**根据上述确认的CUDA本,选择对应的DGL安装命令。DGL官方提供了针对不同CUDA本的安装包。可以通过以下命令安装(以PyPI为例):|CUDA本|安装命令||----------|----------||CUDA11.0|`pipinstalldgl-cuda11.0-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.1|`pipinstalldgl-cuda11.1-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.2|`pipinstalldgl-cuda11.2-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.3|`pipinstalldgl-cuda11.3-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.4|`pipinstalldgl-cuda11.4-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.5|`pipinstalldgl-cuda11.5-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.6|`pipinstalldgl-cuda11.6-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.7|`pipinstalldgl-cuda11.7-fhttps://data.dgl.ai/wheels/repo.html`||CUDA11.8|`pipinstalldgl-cuda11.8-fhttps://data.dgl.ai/wheels/repo.html`||CUDA12.1|`pipinstalldgl-cuda12.1-fhttps://data.dgl.ai/wheels/repo.html`|>**注意**:如果系统CUDA本与PyTorch的CUDA本不一致,建议重新安装PyTorch使其匹配系统CUDA本[^2]。####3.**验证安装**安装后运行以下代码验证DGL是否支持GPU:```pythonimportdglprint(dgl.__version__)#打印DGL本print(dgl.backend.backend_name)#后端框架(应为pytorch)#检查DGL是否支持CUDAprint('DGLCUDA支持:',dgl.backend.build_with_cuda())#尝试在GPU上创建图importtorchiftorch.cuda.is_available():u=torch.tensor([0,1,2]).cuda()v=torch.tensor([2,3,4]).cuda()g=dgl.graph((u,v))print('图设备:',g.device)#应输出cuda设备else:print("CUDA不可用")```####4.**常见问题处理**-**本冲突**:如果之前安装了CPU本的DGL,先卸载它:```bashpipuninstalldgl```再安装对应CUDA本的DGL。-**环境变量问题**:确保环境变量`CUDA_HOME`正确指向CUDA安装路径(通常在`/usr/local/cuda`)。-**PyTorch与DGL的CUDA本不匹配**:必须保证PyTorch和DGL使用相同的主要CUDA本(例如都是11.x)。###原理说明DGL库分为CPU本和CUDA本。错误信息明确提示当前安装的是CPU本,因此无法使用GPU功能。安装CUDA本的DGL后,库将链接CUDA运行时,从而支持GPU操作[^3]。引用[4][5]中的代码示例展示了如何将图数据移动到GPU,这只有在安装了CUDA本的DGL时才能成功执行。###相关问题1.如何检查PyTorch是否正确识别CUDA设备?2.DGL支持哪些图神经网络模型?3.如何在多GPU环境下使用DGL进行训练?4.DGL与PyTorchGeometric(PyG)的主要区别是什么?</think>### 解决 `Device API gpu is not enabled. Please install the cuda version of dgl` 错误 这个错误表示您安装的 DGL 库是 CPU 本,无法使用 GPU 功能。以下是系统性的解决方案: #### 1. **验证环境一致性**(关键步骤) ```python import torch, dgl print(f"PyTorch CUDA可用: {torch.cuda.is_available()}") # 必须为True print(f"PyTorch CUDA: {torch.version.cuda}") # 记录本号 print(f"DGL: {dgl.__version__}") # 记录本号 print(f"DGL CUDA支持: {dgl.backend.build_with_cuda()}") # 必须为True ``` - 如果 `build_with_cuda()` 返回 `False`,说明安装的是 CPU DGL[^3] #### 2. **重新安装匹配的 DGL CUDA 本** 根据 PyTorch 的 CUDA 本选择对应 DGL 本: ```bash # 先卸载现有本 pip uninstall dgl # 安装匹配本(示例) # CUDA 11.1 pip install dgl-cuda11.1 -f https://data.dgl.ai/wheels/repo.html # CUDA 11.3 pip install dgl-cuda11.3 -f https://data.dgl.ai/wheels/repo.html # CUDA 11.6 pip install dgl-cuda11.6 -f https://data.dgl.ai/wheels/repo.html ``` > **本匹配原则**: > - PyTorch CUDA 本必须与 DGL CUDA 本**完全一致** > - 查看 PyTorch CUDA 本:`print(torch.version.cuda)` #### 3. **完整环境重建步骤** ```bash # 1. 创建新环境(推荐) conda create -n dgl_env python=3.8 conda activate dgl_env # 2. 安装匹配的PyTorch pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html # 3. 安装对应DGL pip install dgl-cuda11.3 -f https://data.dgl.ai/wheels/repo.html # 4. 验证安装 python -c "import dgl; print(dgl.backend.build_with_cuda())" ``` #### 4. **特殊场景处理** **笔记本用户注意事项**: - 使用 `nvidia-smi` 确认驱动本支持 CUDA - 若 CUDA 本过高(>11.8),需降级: ```bash conda install cudatoolkit=11.3 -c nvidia ``` **Windows 系统额外步骤**: 1. 安装 Microsoft C++ Build Tools 2. 设置环境变量: ```cmd set CUDA_HOME=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.3 ``` 3. 重启终端后再安装 #### 5. **验证 GPU 功能** 运行官方测试代码[^5]: ```python import dgl import torch as th # 测试GPU图创建 u, v = th.tensor([0, 1, 2]).cuda(), th.tensor([2, 3, 4]).cuda() g = dgl.graph((u, v)) print(f"图设备: {g.device}") # 应输出 cuda:0 # 测试数据迁移 cpu_g = dgl.graph(([0, 1], [1, 2])) gpu_g = cpu_g.to('cuda:0') print(f"迁移后设备: {gpu_g.device}") ``` ### 错误原因分析 | 阶段 | 问题 | 解决方案 | |------|------|----------| | 安装 | DGL CPU 被安装 | 安装 `dgl-cudaXXX` 包 | | 环境 | PyTorch/DGL CUDA 本不匹配 | 严格本对齐 | | 驱动 | CUDA 未正确配置 | 验证 `nvcc --version` | | 系统 | 环境变量缺失 | 设置 `CUDA_HOME` | > 关键原则:**PyTorch CUDA 本、DGL CUDA 本、系统 CUDA 运行时三者必须一致**。不一致时优先对齐 PyTorch 和 DGL[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值