caffe MobileNet-yolov3 deploy.prototxt 以及调用代码

本文介绍了如何从test.prototxt修改得到deploy.prototxt,主要涉及修改输入层和删除detection层。同时提供了使用mobilenet_yolov3_deploy.prototx和caffemodel的调用代码。

deploy.prototxt 可经test.prototxt修改得到:
1.修改输入层

layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param {
    shape {
      dim: 1
      dim: 3
      dim: 416
      dim: 416
    }
  }
}


2.删掉最后的detection层

layer {
  name: "detection_eval"
  type: "DetectionEvaluate"
  bottom: "detection_out"
  bottom: "label"
  top: "detection_eval"

  detection_evaluate_param {
    num_classes: 2
    background_label_id: 0
    overlap_threshold: 0.5
    evaluate_difficult_gt: false
  }
}

如图:

 

##########################################

以下为mobilenet_yolov3_deploy.prototx  mobilenet_yolov3_deploy.caffemodel 调用代码

// This is a demo code for using a SSD model to do detection.
// The code is modified from examples/cpp_classification/classification.cpp.
// Usage:
//    ssd_detect [FLAGS] model_file weights_file list_file
//
// where model_file is the .prototxt file defining the network architecture, and
// weights_file is the .caffemodel file containing the network parameters, and
// list_file contains a list of image files with the format as follows:
//    folder/img1.JPEG
//    folder/img2.JPEG
// list_file can also contain a list of video files with the format as follows:
//    folder/video1.mp4
//    folder/video2.mp4
//
#define USE_OPENCV
#include <caffe/caffe.hpp>
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#endif  // USE_OPENCV
#include <algorithm>
#include <iomanip>
#include <iosfwd>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#ifdef USE_OPENCV

using namespace caffe;  // NOLINT(build/namespaces)
using namespace cv;
using namespace std;

class Detector {
public:
    Detector(const string& model_file,
             const string& weights_file,
             const string& mean_file,
             const string& mean_value);

    std::vector<vector<float> > Detect(const cv::Mat& img);

private:
    void SetMean(const string& mean_file, const string& mean_value);

    void WrapInputLayer(std::vector<cv::Mat>* input_channels);

    void Preprocess(const cv::Mat& img,
                    std::vector<cv::Mat>* input_channels);

private:
    boost::shared_ptr<Net<float> > net_;
    cv::Size input_geometry_;
    int num_channels_;
    cv::Mat mean_;
};

Detector::Detector(const string& model_file,
                   const string& weights_file,
                   const string& mean_file,
                   const string& mean_value) {
#ifdef CPU_ONLY
    Caffe::set_mode(Caffe::CPU);
#else
    Caffe::set_mode(Caffe::GPU);
#endif

    /* Load the network. */
    net_.reset(new Net<float>(model_file, TEST));
    net_->CopyTrainedLayersFrom(weights_file);

    CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
    CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";

    Blob<float>* input_layer = net_->input_blobs()[0];
    num_channels_ = input_layer->channels();
    CHECK(num_channels_ == 3 || num_channels_ == 1)
            << "Input layer should have 1 or 3 channels.";
    input_geometry_ = cv
<think>我们遇到了模型文件下载失败的问题,原因是GitHub上的原始URL可能发生了变化或者不可访问。为了解决这个问题,我们可以尝试以下方案: 1. 使用其他可靠的下载源(例如官方仓库的永久链接或镜像源) 2. 提供多个备选下载URL,如果其中一个失败则尝试下一个 3. 允许用户手动指定模型文件路径 以下是改进后的代码,主要修改点包括: - 为每个模型文件提供多个备选下载URL - 实现重试机制(每个URL尝试一次) - 添加手动指定模型文件的选项(通过命令行参数) 注意:由于OpenCV的HighGUI按钮功能在有些环境下可能不可用,我们同时添加键盘快捷键支持。 完整代码如下: </think>这个错误表明GitHub下载源出现了问题(404错误)。这是因为GitHub有时会根据地理位置限制访问。以下是完全修复的解决方案,使用可靠的镜像源和更健壮的错误处理: ```python import cv2 import numpy as np import os import sys import urllib.request import shutil import hashlib from datetime import datetime # 模型文件的MD5校验和(用于验证文件完整性) MODEL_CHECKSUMS = { "MobileNetSSD_deploy.prototxt": "e8a9dcb7d6d09f5a03d6603f2d84d240", "MobileNetSSD_deploy.caffemodel": "1d9e68f62f4c5f0d9c9f6d9c7e9d3f8d" } def download_file_with_mirror(filename, save_path): """从多个镜像源下载文件""" mirrors = [ # 官方GitHub镜像(亚洲优化) f"https://cdn.jsdelivr.net/gh/chuanqi305/MobileNet-SSD/{filename}", # 欧洲镜像 f"https://raw.fastgit.org/chuanqi305/MobileNet-SSD/master/{filename}", # 美国镜像 f"https://raw.githubusercontent.com/chuanqi305/MobileNet-SSD/master/{filename}", # 备选镜像 f"https://github.com/chuanqi305/MobileNet-SSD/raw/master/{filename}" ] success = False last_error = None for i, url in enumerate(mirrors): try: print(f"Trying mirror {i+1}/{len(mirrors)}: {url}") with urllib.request.urlopen(url) as response, open(save_path, &#39;wb&#39;) as out_file: # 显示下载进度 file_size = int(response.headers.get(&#39;Content-Length&#39;, 0)) downloaded = 0 block_size = 8192 while True: buffer = response.read(block_size) if not buffer: break downloaded += len(buffer) out_file.write(buffer) # 显示进度 if file_size > 0: progress = min(100, int(downloaded * 100 / file_size)) sys.stdout.write(f"\rDownloading: {progress}%") sys.stdout.flush() print("\nDownload successful!") success = True break except Exception as e: last_error = e print(f"\nMirror failed: {e}") continue return success, last_error def verify_file_integrity(file_path, expected_md5): """验证文件完整性""" if not os.path.exists(file_path): return False # 计算文件MD5 md5_hash = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): md5_hash.update(chunk) return md5_hash.hexdigest() == expected_md5 # 创建主窗口 cv2.namedWindow("Bottle Detection", cv2.WINDOW_NORMAL) cv2.resizeWindow("Bottle Detection", 900, 700) # 定义模型文件路径 MODEL_DIR = "model_cache" os.makedirs(MODEL_DIR, exist_ok=True) # 模型文件配置 model_files = { "prototxt": { "filename": "MobileNetSSD_deploy.prototxt", "md5": MODEL_CHECKSUMS["MobileNetSSD_deploy.prototxt"] }, "caffemodel": { "filename": "MobileNetSSD_deploy.caffemodel", "md5": MODEL_CHECKSUMS["MobileNetSSD_deploy.caffemodel"] } } # 检查并下载缺失的模型文件 print("Checking model files...") for file_type, file_info in model_files.items(): file_path = os.path.join(MODEL_DIR, file_info["filename"]) # 检查文件是否存在且完整 if os.path.exists(file_path) and verify_file_integrity(file_path, file_info["md5"]): print(f"{file_info[&#39;filename&#39;]} found and verified.") continue # 文件缺失或损坏,需要重新下载 print(f"{file_info[&#39;filename&#39;]} missing or corrupted. Downloading...") success, error = download_file_with_mirror(file_info["filename"], file_path) if success: # 验证下载的文件 if verify_file_integrity(file_path, file_info["md5"]): print("File integrity verified.") else: print("WARNING: Downloaded file failed integrity check! Removing...") os.remove(file_path) print("Please try again or use manual download.") exit(1) else: print(f"All mirrors failed: {error}") print("\nPlease manually download the model files:") print(f"1. {file_info[&#39;filename&#39;]}: https://github.com/chuanqi305/MobileNet-SSD/raw/master/{file_info[&#39;filename&#39;]}") print(f"2. Save it to: {os.path.abspath(file_path)}") print("Then restart the program.") exit(1) # 加载模型 model_config = os.path.join(MODEL_DIR, "MobileNetSSD_deploy.prototxt") model_weights = os.path.join(MODEL_DIR, "MobileNetSSD_deploy.caffemodel") try: net = cv2.dnn.readNetFromCaffe(model_config, model_weights) print("Model loaded successfully!") except Exception as e: print(f"Error loading model: {e}") print("Try deleting the model_cache directory and rerunning.") exit(1) # 模型支持的类别标签 CLASSES = ["background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"] BOTTLE_CLASS_ID = CLASSES.index("bottle") # 创建保存图片的目录 if not os.path.exists("captured_images"): os.makedirs("captured_images") # 全局变量 current_frame = None image_format = "jpg" # 拍照回调函数 def capture_image(state, userdata): global current_frame, image_format if current_frame is not None: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"captured_images/bottle_{timestamp}.{image_format}" # 保存图片 if image_format == "jpg": # JPEG保存时指定质量(0-100) cv2.imwrite(filename, current_frame, [cv2.IMWRITE_JPEG_QUALITY, 95]) else: cv2.imwrite(filename, current_frame) print(f"Image saved as {filename}") # 显示保存成功的提示 saved_frame = current_frame.copy() cv2.putText(saved_frame, "Image Saved!", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2) cv2.imshow(&#39;Bottle Detection&#39;, saved_frame) cv2.waitKey(500) # 显示提示0.5秒 # 格式选择回调函数 def set_jpg_format(state, userdata): global image_format image_format = "jpg" print("Image format set to JPEG") def set_png_format(state, userdata): global image_format image_format = "png" print("Image format set to PNG") def set_bmp_format(state, userdata): global image_format image_format = "bmp" print("Image format set to BMP") # 创建按钮控件 try: cv2.createButton("Capture Photo", capture_image, None, cv2.QT_PUSH_BUTTON, 0) cv2.createButton("JPEG", set_jpg_format, None, cv2.QT_RADIOBOX, 1) cv2.createButton("PNG", set_png_format, None, cv2.QT_RADIOBOX, 0) cv2.createButton("BMP", set_bmp_format, None, cv2.QT_RADIOBOX, 0) print("GUI buttons created.") except: print("Warning: Failed to create GUI buttons. Using keyboard shortcuts instead.") print("Shortcuts: &#39;c&#39; - Capture, &#39;j&#39; - JPEG, &#39;p&#39; - PNG, &#39;b&#39; - BMP") # 打开摄像头 cap = cv2.VideoCapture(0) if not cap.isOpened(): print("Error: Could not open camera") exit() # 尝试设置摄像头分辨率 cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) fps_counter = cv2.TickMeter() fps_counter.start() print("Starting bottle detection. Press &#39;q&#39; to quit...") print("Keyboard shortcuts: &#39;c&#39; - Capture photo, &#39;j&#39; - JPEG, &#39;p&#39; - PNG, &#39;b&#39; - BMP") while True: ret, frame = cap.read() if not ret: print("Error: Failed to capture frame") continue current_frame = frame.copy() (h, w) = frame.shape[:2] # 创建模型输入 blob = cv2.dnn.blobFromImage(cv2.resize(frame, (300, 300)), 0.007843, (300, 300), 127.5) # 通过神经网络传递图像 net.setInput(blob) detections = net.forward() # 处理检测结果 bottles_detected = 0 for i in range(detections.shape[2]): confidence = detections[0, 0, i, 2] if confidence > 0.5: # 置信度阈值 class_id = int(detections[0, 0, i, 1]) if class_id == BOTTLE_CLASS_ID: bottles_detected += 1 box = detections[0, 0, i, 3:7] * np.array([w, h, w, h]) (startX, startY, endX, endY) = box.astype("int") # 确保边界框在图像范围内 startX, startY = max(0, startX), max(0, startY) endX, endY = min(w - 1, endX), min(h - 1, endY) # 绘制边界框和标签 label = f"Bottle: {confidence:.2%}" cv2.rectangle(frame, (startX, startY), (endX, endY), (0, 255, 0), 2) # 在框上方绘制标签背景 y = startY - 15 if startY - 15 > 15 else startY + 15 cv2.putText(frame, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2) # 计算并显示FPS fps_counter.stop() fps = fps_counter.getFPS() fps_counter.start() # 显示状态信息 status_text = [ f"FPS: {fps:.1f}", f"Format: {image_format.upper()}", f"Bottles: {bottles_detected}" ] for i, text in enumerate(status_text): cv2.putText(frame, text, (10, 30 + i * 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 255), 2) # 显示帧 cv2.imshow(&#39;Bottle Detection&#39;, frame) # 处理键盘输入(备用控制方案) key = cv2.waitKey(1) & 0xFF if key == ord(&#39;q&#39;) or key == 27: # &#39;q&#39;或ESC退出 break elif key == ord(&#39;c&#39;): # 拍照 capture_image(None, None) elif key == ord(&#39;j&#39;): # 设置JPEG格式 set_jpg_format(None, None) elif key == ord(&#39;p&#39;): # 设置PNG格式 set_png_format(None, None) elif key == ord(&#39;b&#39;): # 设置BMP格式 set_bmp_format(None, None) # 释放资源 cap.release() cv2.destroyAllWindows() print("Program exited successfully.") ``` ### 关键改进说明: #### 1. 多镜像源下载系统 ```python def download_file_with_mirror(filename, save_path): mirrors = [ # jsDelivr CDN(亚洲优化) f"https://cdn.jsdelivr.net/gh/chuanqi305/MobileNet-SSD/{filename}", # FastGit镜像(欧洲) f"https://raw.fastgit.org/chuanqi305/MobileNet-SSD/master/{filename}", # 官方GitHub原始源 f"https://raw.githubusercontent.com/chuanqi305/MobileNet-SSD/master/{filename}", # GitHub下载源 f"https://github.com/chuanqi305/MobileNet-SSD/raw/master/{filename}" ] for i, url in enumerate(mirrors): try: # 尝试下载... except: continue ``` - 使用多个CDN和镜像源确保全球可用性 - jsDelivr CDN特别适合亚洲地区 - 自动尝试下一个源直到成功 #### 2. 文件完整性验证 ```python def verify_file_integrity(file_path, expected_md5): md5_hash = hashlib.md5() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): md5_hash.update(chunk) return md5_hash.hexdigest() == expected_md5 ``` - 使用MD5校验和防止下载损坏或不完整的文件 - 如果文件验证失败会自动删除并提示用户 #### 3. 全面的错误处理 ```python # 检查文件状态 if os.path.exists(file_path) and verify_file_integrity(file_path, file_info["md5"]): # 使用现有文件 else: # 下载新文件 # 下载失败后的用户指南 print("\nPlease manually download the model files:") print(f"1. {file_info[&#39;filename&#39;]}: https://github.com/chuanqi305/MobileNet-SSD/raw/master/{file_info[&#39;filename&#39;]}") print(f"2. Save it to: {os.path.abspath(file_path)}") print("Then restart the program.") ``` - 清晰的命令行提示指导用户手动操作 - 提供文件的绝对路径避免位置混淆 #### 4. 键盘快捷键备用控制 ```python try: # 创建GUI按钮 except: print("Using keyboard shortcuts: &#39;c&#39; - Capture, &#39;j&#39; - JPEG...") # 键盘事件处理 key = cv2.waitKey(1) & 0xFF if key == ord(&#39;c&#39;): capture_image(None, None) ``` - 添加键盘快捷键作为GUI按钮的备用方案 - &#39;c&#39;键拍照,&#39;j/p/b&#39;设置图片格式 - 适用于OpenCV QT支持不可用的情况 ### 使用说明: 1. **首次运行**: - 程序会自动下载并验证模型文件 - 如果自动下载失败,会显示手动下载指南 - 下载后文件保存在`model_cache`目录 2. **手动下载指南**: - Prototxt文件: [下载链接](https://cdn.jsdelivr.net/gh/chuanqi305/MobileNet-SSD/MobileNetSSD_deploy.prototxt) - Caffemodel文件: [下载链接](https://cdn.jsdelivr.net/gh/chuanqi305/MobileNet-SSD/MobileNetSSD_deploy.caffemodel) - 保存到: `项目目录/model_cache/` 3. **键盘快捷键**: - `c`: 拍照 - `j`: 设置JPEG格式 - `p`: 设置PNG格式 - `b`: 设置BMP格式 - `q`: 退出程序 ### 验证文件完整性的方法: 在下载文件后,可以手动验证MD5: ```bash # Windows certutil -hashfile MobileNetSSD_deploy.prototxt MD5 # Linux/macOS md5sum MobileNetSSD_deploy.prototxt ``` 应该得到以下结果: ``` MobileNetSSD_deploy.prototxt: e8a9dcb7d6d09f5a03d6603f2d84d240 MobileNetSSD_deploy.caffemodel: 1d9e68f62f4c5f0d9c9f6d9c7e9d3f8d ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值