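Python Project: Developing an Image Processing and Recognition System in Python

Introduction

Image processing and recognition have become core parts of applied artificial intelligence, with uses ranging from autonomous driving and medical imaging diagnosis to security surveillance and industrial quality inspection. Python, with its concise syntax and extensive library ecosystem, is one of the most widely used languages for building such systems. This article walks through developing a complete image processing and recognition system in Python: environment setup, core techniques, implementation steps, and practical application cases.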

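1. Environment Setup

1.1 Configuring the Python Environment

First, set up a Python environment suited to image processing and recognition. Python 3.8 or later is recommended, with a virtual environment to manage dependencies.

# Create a virtual environment
python -m venv image_processing_env

# Activate the virtual environment (Windows)
image_processing_env\Scripts\activate

# Activate the virtual environment (Linux/Mac)
# source image_processing_env/bin/activate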

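1.2 Installing the Core Libraries

Developing an image processing and recognition system requires the following core libraries:

# Install the core libraries (Keras ships with TensorFlow 2.x as tensorflow.keras)
pip install numpy opencv-python pillow matplotlib scikit-image tensorflow

# GPU acceleration (recommended for deep learning models): the separate
# tensorflow-gpu package is deprecated. GPU support is part of the main
# tensorflow package (TensorFlow 2.1 and later); see the TensorFlow install
# guide for the platform-specific CUDA and cuDNN requirements.

What the main libraries do:

  • NumPy: efficient multi-dimensional array operations
  • OpenCV: a comprehensive computer vision library
  • Pillow: image loading, saving, and basic manipulation
  • Matplotlib: data visualization
  • scikit-image: a collection of image processing algorithms
  • TensorFlow/Keras: deep learning framework used to build recognition models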
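
Before moving on, it can help to confirm that the libraries are importable in the active environment. The following is a minimal sketch using only the standard library; the helper name `find_missing` is hypothetical, and the list assumes the usual import names (`cv2` for opencv-python, `PIL` for pillow, `skimage` for scikit-image).

```python
import importlib.util

# Import names differ from the pip package names for some libraries
# (opencv-python -> cv2, pillow -> PIL, scikit-image -> skimage).
REQUIRED_MODULES = ["numpy", "cv2", "PIL", "matplotlib", "skimage", "tensorflow"]


def find_missing(modules):
    """Return the top-level module names that cannot be found for import."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    print("Missing modules:", missing if missing else "none")
```

The check only looks for the modules without importing them, so it runs quickly even when TensorFlow is installed.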

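2. Image Processing Basics

2.1 Reading and Displaying Images

Reading and displaying an image with OpenCV is the first step in image processing:

import cv2
import matplotlib.pyplot as plt

# Read the image (cv2.imread returns None instead of raising if the file cannot be read)
image = cv2.imread('sample.jpg')
if image is None:
    raise FileNotFoundError('Could not read sample.jpg')

# Convert the color space (OpenCV loads images as BGR, not RGB)
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Display the image
plt.figure(figsize=(10, 8))
plt.imshow(image_rgb)
plt.axis('off')
plt.show()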
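
The BGR-to-RGB conversion above is a reversal of the channel axis. As a small illustration using NumPy only (the helper name `bgr_to_rgb` is hypothetical and is not an OpenCV function), the same reordering can be sketched as:

```python
import numpy as np


def bgr_to_rgb(image_bgr):
    """Reverse the channel axis of an H x W x 3 array (BGR -> RGB)."""
    return np.ascontiguousarray(image_bgr[..., ::-1])


# A 1 x 2 test image: one pure-blue and one pure-red pixel, stored in BGR order
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
rgb = bgr_to_rgb(bgr)
print(rgb[0, 0], rgb[0, 1])
```

Displaying a BGR array directly with Matplotlib swaps red and blue, which is why the conversion step matters.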

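2.2 Basic Image Processing Operations

2.2.1 Resizing and Cropping

# Resize the image; the target size is (width, height). Example values.
width, height = 640, 480
resized_image = cv2.resize(image, (width, height))

# Crop with array slicing: rows (y) first, then columns (x). Example region.
x, y, w, h = 50, 50, 200, 100
cropped_image = image[y:y+h, x:x+w]

2.2.2 Filtering and Enhancement

import numpy as np

# Gaussian blur
blurred = cv2.GaussianBlur(image, (5, 5), 0)

# Sharpening (the kernel sums to 1, so overall brightness is preserved)
kernel = np.array([[-1,-1,-1], [-1,9,-1], [-1,-1,-1]])
sharpened = cv2.filter2D(image, -1, kernel)

# Contrast enhancement: output = |alpha * pixel + beta|, saturated to 0..255
alpha = 1.5  # contrast factor
beta = 0     # brightness offset
enhanced = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)

2.2.3 Edge Detection

# Canny edge detection
edges = cv2.Canny(image, 100, 200)

# Sobel edge detection
sobelx = cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize=5)
sobely = cv2.Sobel(image, cv2.CV_64F, 0, 1, ksize=5)
sobel = cv2.magnitude(sobelx, sobely)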

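2.3 Image Segmentation

Image segmentation partitions an image into multiple regions and underlies many higher-level image processing tasks:

# Thresholding operates on a single-channel image
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# Global thresholding
_, binary = cv2.threshold(gray_image, 127, 255, cv2.THRESH_BINARY)

# Adaptive thresholding
adaptive_thresh = cv2.adaptiveThreshold(gray_image, 255, 
                                        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                        cv2.THRESH_BINARY, 11, 2)

# Watershed algorithm: markers must be an int32 label image in which each seed
# region has a positive label and unknown pixels are 0. Here the seeds come from
# the connected components of the binary mask (a simple illustrative choice).
_, markers = cv2.connectedComponents(binary)
markers = cv2.watershed(image, markers)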

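3. Feature Extraction and Description

3.1 Traditional Feature Extraction Methods

3.1.1 HOG Features

The histogram of oriented gradients (HOG) is a feature descriptor used for object detection:

from skimage.feature import hog
from skimage import exposure

# Compute HOG features (gray_image is a single-channel image)
fd, hog_image = hog(gray_image, orientations=8, pixels_per_cell=(16, 16),
                    cells_per_block=(1, 1), visualize=True)

# Rescale the HOG visualization for display
hog_image_rescaled = exposure.rescale_intensity(hog_image, in_range=(0, 10))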
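
The length of the HOG vector `fd` follows from the image size and the three parameters above. The sketch below works through that arithmetic in plain Python, assuming a 2-D input and a block stride of one cell (the behavior of scikit-image's `hog`); the helper name `hog_feature_length` and the 128 x 64 image size are illustrative assumptions.

```python
def hog_feature_length(image_shape, orientations, pixels_per_cell, cells_per_block):
    """Length of the flattened HOG vector for a 2-D image, with a block stride of one cell."""
    rows, cols = image_shape
    cells_r = rows // pixels_per_cell[0]
    cells_c = cols // pixels_per_cell[1]
    blocks_r = cells_r - cells_per_block[0] + 1
    blocks_c = cells_c - cells_per_block[1] + 1
    return blocks_r * blocks_c * cells_per_block[0] * cells_per_block[1] * orientations


# Parameters from the call above, on a hypothetical 128 x 64 grayscale image
print(hog_feature_length((128, 64), 8, (16, 16), (1, 1)))   # 256
# The classic pedestrian window: 8 x 8 cells, 2 x 2 blocks, 9 bins
print(hog_feature_length((128, 64), 9, (8, 8), (2, 2)))     # 3780
```

Knowing the length in advance is useful when sizing the feature matrix passed to a classifier.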
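
3.1.2 SIFT Features

The scale-invariant feature transform (SIFT) is an algorithm for detecting and describing local image features:

# SIFT is available in the main OpenCV module from version 4.4 onward
sift = cv2.SIFT_create()
keypoints, descriptors = sift.detectAndCompute(gray_image, None)

# Draw the keypoints
sift_image = cv2.drawKeypoints(image, keypoints, None)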

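3.2 Deep Learning Feature Extraction

Extract deep features with a pretrained convolutional neural network:

from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.applications.vgg16 import preprocess_input
import numpy as np

# Load the pretrained model without its classification head
model = VGG16(weights='imagenet', include_top=False)

# Preprocess the image
img = keras_image.load_img('sample.jpg', target_size=(224, 224))
x = keras_image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

# Extract features; the output shape is (1, 7, 7, 512) for a 224 x 224 input
features = model.predict(x)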
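
A feature map of this kind is usually reduced to a single vector before it is compared or passed to a classifier. The sketch below is one possible approach, not part of the original pipeline: it applies global average pooling and cosine similarity with NumPy, and uses random stand-in arrays of shape (1, 7, 7, 512) in place of real VGG16 outputs.

```python
import numpy as np


def pool_features(feature_map):
    """Global-average-pool a (1, H, W, C) feature map into a length-C vector."""
    return feature_map.mean(axis=(1, 2))[0]


def cosine_similarity(a, b):
    """Cosine of the angle between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Stand-in arrays with the shape VGG16 (include_top=False) yields for 224 x 224 input
rng = np.random.default_rng(0)
features_a = rng.random((1, 7, 7, 512), dtype=np.float32)
features_b = rng.random((1, 7, 7, 512), dtype=np.float32)

vec_a = pool_features(features_a)
vec_b = pool_features(features_b)
print(vec_a.shape, cosine_similarity(vec_a, vec_a), cosine_similarity(vec_a, vec_b))
```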

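4. Building the Image Recognition System

4.1 System Architecture

A complete image recognition system typically consists of the following components:

  1. Image acquisition: obtains images from a camera, the file system, or the network
  2. Preprocessing: image enhancement, noise removal, resizing, and similar steps
  3. Feature extraction: extracts the key features of the image
  4. Classification/recognition: classifies or recognizes based on the features
  5. Post-processing: result refinement, visualization, and similar steps
  6. User interface: the interface through which users interact with the system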
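
The components above can be chained as a sequence of stages, each consuming the previous stage's output. The following is a minimal sketch of that idea; `run_pipeline` and the placeholder stages are hypothetical and operate on a short list of pixel values rather than a real image.

```python
from typing import Any, Callable, List, Tuple

Stage = Tuple[str, Callable[[Any], Any]]


def run_pipeline(data: Any, stages: List[Stage]) -> Any:
    """Pass data through each named stage in order and return the final result."""
    for _name, stage in stages:
        data = stage(data)
    return data


# Placeholder stages standing in for preprocessing, feature extraction, and classification
stages: List[Stage] = [
    ("preprocess", lambda pixels: [p / 255.0 for p in pixels]),
    ("extract_features", lambda pixels: sum(pixels) / len(pixels)),
    ("classify", lambda mean: "bright" if mean > 0.5 else "dark"),
]

label = run_pipeline([200, 220, 240], stages)
print(label)
```

Keeping each stage behind a plain callable makes it straightforward to swap one implementation for another, for example a traditional feature extractor for a deep one.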

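4.2 Object Detection and Recognition

The following example performs object detection with YOLOv5:

import cv2
import torch

# Load the YOLOv5 model (the repository and weights are downloaded on first use)
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')

# Run detection; for NumPy arrays the model expects RGB, so convert from OpenCV's BGR
image = cv2.imread('sample.jpg')
results = model(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

# Show the results
results.print()  
results.show()  

# Get the detections
detections = results.pandas().xyxy[0]  # bounding-box coordinates, confidence, and class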
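
Detections in `xyxy` form are commonly compared with intersection over union (IoU), the overlap measure behind the IoU threshold used in non-maximum suppression. The sketch below computes it in plain Python; the function name `iou` is an illustrative choice and is not part of the YOLOv5 API.

```python
def iou(box_a, box_b):
    """Intersection over union of two boxes given as (xmin, ymin, xmax, ymax)."""
    ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10 x 10 boxes overlapping in a 5 x 5 region: 25 / (100 + 100 - 25)
print(iou((0, 0, 10, 10), (5, 5, 15, 15)))
```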

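4.3 Face Recognition

Face recognition with OpenCV and dlib:

import cv2
import dlib
import numpy as np

# Load the face detector, landmark predictor, and descriptor model
# (the two .dat model files must be downloaded separately)
detector = dlib.get_frontal_face_detector()
sp = dlib.shape_predictor('shape_predictor_68_face_landmarks.dat')
facerec = dlib.face_recognition_model_v1('dlib_face_recognition_resnet_model_v1.dat')

# dlib expects RGB input, while cv2.imread returns BGR
image = cv2.cvtColor(cv2.imread('sample.jpg'), cv2.COLOR_BGR2RGB)

# Face detection
faces = detector(image)

# Feature extraction: one 128-dimensional descriptor per detected face
face_descriptors = []
for face in faces:
    shape = sp(image, face)
    face_descriptor = facerec.compute_face_descriptor(image, shape)
    face_descriptors.append(np.array(face_descriptor))

# Face comparison
def compare_faces(known_face_encodings, face_encoding_to_check, tolerance=0.6):
    # dlib descriptors are compared by Euclidean distance; the 0.6 tolerance is a
    # distance threshold (smaller means more similar), not a cosine similarity
    distances = np.linalg.norm(
        np.asarray(known_face_encodings) - np.asarray(face_encoding_to_check), axis=1)
    return list(distances <= tolerance), distances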

4.4 Optical Character Recognition (OCR)

OCR with Tesseract and Python:

import pytesseract
from PIL import Image

# Set the Tesseract path (needed on Windows)
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

# Read the image
image = Image.open('text_image.jpg')

# Run OCR (requires the chi_sim and eng Tesseract language data to be installed)
text = pytesseract.image_to_string(image, lang='chi_sim+eng')
print(text)

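5. Practical Application Cases

5.1 License Plate Recognition for Smart Parking

License plate recognition is a typical application of image processing and recognition. The following is a simplified implementation:

import cv2
import numpy as np
import pytesseract

def license_plate_recognition(image_path):
    # Read the image
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f'Could not read {image_path}')
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    # Edge detection
    edges = cv2.Canny(gray, 170, 200)
    
    # Find contours
    contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    
    # Keep contours whose bounding box could be a plate
    license_plates = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        aspect_ratio = w / float(h)
        
        # Plate aspect ratios usually fall between 2.0 and 5.5
        if 2.0 < aspect_ratio < 5.5 and w > 100 and h > 20:
            license_plate = gray[y:y+h, x:x+w]
            license_plates.append((license_plate, (x, y, w, h)))
    
    # Run OCR on each candidate region
    results = []
    for plate, coords in license_plates:
        # Enhance contrast
        plate = cv2.equalizeHist(plate)
        
        # Binarize with Otsu's automatically chosen threshold
        _, plate = cv2.threshold(plate, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        
        # OCR (--psm 7 treats the image as a single line of text)
        text = pytesseract.image_to_string(plate, config='--psm 7').strip()
        
        # Discard text too short to be a plate (plates usually have at least 5 characters)
        if len(text) >= 5:
            results.append((text, coords))
    
    return results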
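
The `cv2.THRESH_OTSU` flag used in the function above picks the threshold that maximizes the between-class variance of the two pixel groups. The sketch below reimplements that selection with NumPy purely for illustration; it assumes an 8-bit grayscale array, and the function name `otsu_threshold` is hypothetical.

```python
import numpy as np


def otsu_threshold(gray):
    """Return the threshold t maximizing between-class variance (pixels <= t vs. > t)."""
    hist = np.bincount(gray.ravel(), minlength=256).astype(np.float64)
    levels = np.arange(256)
    w0 = np.cumsum(hist)                 # pixel count with value <= t
    w1 = hist.sum() - w0                 # pixel count with value > t
    sum0 = np.cumsum(hist * levels)      # intensity sum of pixels <= t
    valid = (w0 > 0) & (w1 > 0)
    mu0 = sum0 / np.maximum(w0, 1)
    mu1 = (sum0[-1] - sum0) / np.maximum(w1, 1)
    between = np.where(valid, w0 * w1 * (mu0 - mu1) ** 2, 0.0)
    return int(np.argmax(between))


# A toy image with two clearly separated intensity groups
gray = np.array([[10] * 4, [10] * 4, [200] * 4, [200] * 4], dtype=np.uint8)
t = otsu_threshold(gray)
binary = np.where(gray > t, 255, 0).astype(np.uint8)
print(t, binary.tolist())
```

On real plate crops the histogram is rarely this clean, which is why the function above equalizes the histogram before thresholding.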

5.2 Medical Imaging Diagnostic Support

Medical image analysis with deep learning:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Build the CNN model
def build_medical_image_model(input_shape=(224, 224, 3), num_classes=2):
    model = Sequential([
        Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(512, activation='relu'),
        Dropout(0.5),
        Dense(num_classes, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    return model

# Data augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Load the data (one subdirectory per class under medical_images/train)
train_generator = train_datagen.flow_from_directory(
    'medical_images/train',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

# Train the model
model = build_medical_image_model()
model.fit(train_generator, epochs=20)
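
It is worth checking how large the tensor entering `Flatten` becomes, since that determines the size of the first `Dense` layer. The sketch below repeats the shape arithmetic in plain Python, assuming the Keras defaults used above: 3 x 3 convolutions with 'valid' padding and stride 1, and 2 x 2 max pooling. The helper names are illustrative.

```python
def conv_valid(size, kernel=3):
    """Output size of a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def pool(size, window=2):
    """Output size of non-overlapping max pooling (floor division)."""
    return size // window


size = 224
for _ in range(3):                       # three Conv2D + MaxPooling2D pairs
    size = pool(conv_valid(size))

flatten_units = size * size * 128        # 128 filters in the last Conv2D
dense_params = flatten_units * 512 + 512 # weights plus biases of Dense(512)
print(size, flatten_units, dense_params) # 26 86528 44302848
```

Roughly 44 million parameters sit in that single dense layer, so replacing `Flatten` with global average pooling is a common way to shrink such a model.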

5.3 Industrial Quality Inspection

Product defect detection with image processing:

import cv2
import numpy as np

def defect_detection(image_path, reference_path):
    # Read the image under test and the reference image
    test_image = cv2.imread(image_path)
    reference_image = cv2.imread(reference_path)
    if test_image is None or reference_image is None:
        raise FileNotFoundError('Could not read the test or reference image')
    
    # Make sure both images have the same size
    test_image = cv2.resize(test_image, (reference_image.shape[1], reference_image.shape[0]))
    
    # Convert to grayscale
    test_gray = cv2.cvtColor(test_image, cv2.COLOR_BGR2GRAY)
    reference_gray = cv2.cvtColor(reference_image, cv2.COLOR_BGR2GRAY)
    
    # Compute the absolute difference
    diff = cv2.absdiff(reference_gray, test_gray)
    
    # Binarize the difference image
    _, thresh = cv2.threshold(diff, 30, 255, cv2.THRESH_BINARY)
    
    # Morphological opening to remove noise
    kernel = np.ones((5,5), np.uint8)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
    
    # Find defect regions
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Mark the defects
    defect_image = test_image.copy()
    defects = []
    
    for contour in contours:
        area = cv2.contourArea(contour)
        if area > 100:  # ignore small specks
            x, y, w, h = cv2.boundingRect(contour)
            cv2.rectangle(defect_image, (x, y), (x+w, y+h), (0, 0, 255), 2)
            defects.append((x, y, w, h))
    
    return defect_image, defects
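
The function uses `cv2.absdiff` rather than plain subtraction because `uint8` arithmetic wraps around modulo 256. The NumPy sketch below illustrates the difference on a tiny example; `abs_diff_uint8` is a hypothetical stand-in for `cv2.absdiff`, and the threshold of 30 matches the one used above.

```python
import numpy as np


def abs_diff_uint8(a, b):
    """Per-pixel |a - b| for uint8 images, computed without wraparound."""
    return np.abs(a.astype(np.int16) - b.astype(np.int16)).astype(np.uint8)


reference = np.array([[100, 100], [100, 100]], dtype=np.uint8)
test = np.array([[100, 160], [90, 100]], dtype=np.uint8)

naive = reference - test                  # wraps modulo 256: 100 - 160 becomes 196
diff = abs_diff_uint8(reference, test)    # true absolute difference
mask = np.where(diff > 30, 255, 0).astype(np.uint8)
print(naive.tolist(), diff.tolist(), mask.tolist())
```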

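6. System Optimization and Deployment

6.1 Performance Optimization

Image processing and recognition systems typically handle large volumes of data, so performance optimization matters:

  1. GPU acceleration: use CUDA and cuDNN to speed up deep learning models
  2. Model quantization: reduce model size and computational cost
  3. Parallel processing: use multiple threads or processes to handle several images at once
  4. Memory management: optimize memory use while loading and processing images

# Process images in parallel with multiple processes
from multiprocessing import Pool

def process_image(image_path):
    # Image processing code goes here
    pass

if __name__ == '__main__':
    image_paths = ['image1.jpg', 'image2.jpg', 'image3.jpg']  # ...
    
    # Create a process pool
    with Pool(processes=4) as pool:
        results = pool.map(process_image, image_paths)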

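6.2 Web Application Deployment

A simple image recognition web application built with Flask:

from flask import Flask, request, jsonify, render_template
import cv2
import os
from werkzeug.utils import secure_filename

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'uploads/'
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

def process_image(image):
    # Placeholder: add your image processing and recognition code here.
    # The return value must be JSON-serializable.
    return {'height': int(image.shape[0]), 'width': int(image.shape[1])}

@app.route('/')
def index():
    # Requires templates/index.html
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    
    # Save the uploaded file
    filename = secure_filename(file.filename)
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)
    
    # Process the image (cv2.imread returns None if the file is not a readable image)
    image = cv2.imread(filepath)
    if image is None:
        return jsonify({'error': 'Unreadable image'}), 400
    result = process_image(image)
    
    return jsonify({'result': result})

if __name__ == '__main__':
    # debug=True is for local development only
    app.run(debug=True)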
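
An upload endpoint like this usually validates the file type before saving anything. The sketch below shows one simple extension check; the helper `allowed_file` is hypothetical, and the extension set is copied from `supported_extensions` in config/image_acquisition_config.json further below.

```python
import os

# Extensions listed under file_loading.supported_extensions in the project config
ALLOWED_EXTENSIONS = {"jpg", "jpeg", "png", "bmp", "tiff", "webp"}


def allowed_file(filename: str) -> bool:
    """Return True if the file name ends in one of the allowed image extensions."""
    ext = os.path.splitext(filename)[1].lower().lstrip(".")
    return ext in ALLOWED_EXTENSIONS


print(allowed_file("photo.JPG"), allowed_file("notes.txt"))
```

An extension check is only a first filter; the uploaded bytes still need to decode as an image before they are processed.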

6.3 Mobile Integration

Deploy a model to mobile devices with TensorFlow Lite:

import tensorflow as tf

# Convert the model to TensorFlow Lite format
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model_dir')
tflite_model = converter.convert()

# Save the model
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Quantize the model to reduce its size (default post-training optimization)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()

with open('model_quantized.tflite', 'wb') as f:
    f.write(tflite_quantized_model)

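7. Future Trends

7.1 New Deep Learning Techniques

  1. Self-supervised learning: reduces dependence on labeled data
  2. Few-shot learning: trains from a small number of samples
  3. Neural architecture search: automatically searches for well-performing network structures
  4. Graph neural networks: handle non-Euclidean data

7.2 Cross-Domain Integration

  1. Image processing + natural language processing: image caption generation
  2. Image processing + augmented reality: real-time scene understanding
  3. Image processing + robotics: visual navigation and manipulation

7.3 Edge Computing

Moving image processing and recognition from the cloud to edge devices reduces latency and improves privacy protection:

# Run inference on an edge device with TensorFlow Lite
import numpy as np
import tensorflow as tf

# Load the TFLite model
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

# Get the input and output tensor details
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Prepare input data (random values stand in for a preprocessed image)
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

# Set the input tensor
interpreter.set_tensor(input_details[0]['index'], input_data)

# Run inference
interpreter.invoke()

# Get the output tensor
output_data = interpreter.get_tensor(output_details[0]['index'])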
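
The raw output tensor still has to be turned into a prediction. The sketch below shows one way to do that with NumPy, assuming a classification model whose output has shape (1, num_classes) and contains unnormalized scores; if the model already ends in a softmax layer, the `softmax` step can be skipped. The sample values are made up for illustration.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def top_k(probabilities, k=3):
    """Return the k most likely (class_index, probability) pairs, best first."""
    order = np.argsort(probabilities)[::-1][:k]
    return [(int(i), float(probabilities[i])) for i in order]


# Stand-in for the interpreter output of a four-class model
output_data = np.array([[1.0, 3.0, 0.5, 2.0]], dtype=np.float32)
probs = softmax(output_data)[0]
best = top_k(probs, k=2)
print(best)
```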

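8. Summary

Developing an image processing and recognition system in Python is an engineering effort that spans several technical areas. With a sound system architecture, well-chosen algorithms, and an optimized implementation, it is possible to build a system that is both capable and efficient.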

Source Code

Directory Content Summary

Source Directory: ./python_image_processing

Directory Structure

python_image_processing/
  README.md
  requirements.txt
  config/
    classification_config.json
    feature_extraction_config.json
    image_acquisition_config.json
  data/
  models/
  src/
    classification.py
    feature_extraction.py
    image_acquisition.py
    image_preprocessing.py
  tests/
    test_classification.py
    test_image_acquisition.py
  utils/
    config_loader.py

File Contents

README.md

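# Python Image Processing and Recognition System

This project implements a complete image processing and recognition system in Python, with modules for image acquisition, preprocessing, feature extraction, and recognition/classification.

## Project Structure

python_image_processing/
├── config/                 # Configuration files
│   └── image_acquisition_config.json  # Image acquisition module configuration
├── data/                   # Data directory for image data
├── models/                 # Model directory for trained models
├── src/                    # Source code
│   └── image_acquisition.py  # Image acquisition module
├── tests/                  # Test code
│   └── test_image_acquisition.py  # Image acquisition module tests
├── utils/                  # Utility functions
│   └── config_loader.py    # Configuration loader
└── requirements.txt        # Project dependencies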

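Image Acquisition Module

The image acquisition module (image_acquisition.py) obtains image data from several sources:

  1. The local file system
  2. Live camera capture
  3. Network URLs
  4. Batch loading of the images in a directory

Main Features

  • Load an image from a file
  • Capture an image from a camera
  • Load an image from a URL
  • Load images in batches
  • Save images
  • Convert color spaces
  • Get image information

Usage Example

from src.image_acquisition import ImageAcquisition

# Create an image acquisition instance
image_acq = ImageAcquisition()

# Load an image from a file
image = image_acq.load_from_file("path/to/image.jpg")

# Capture an image from the camera
camera_image = image_acq.load_from_camera()

# Load an image from a URL
url_image = image_acq.load_from_url("https://example.com/image.jpg")

# Load images in a batch
images = image_acq.load_batch_from_directory("path/to/directory")

# Save an image
image_acq.save_image(image, "output.jpg")

# Convert the color space
import cv2
gray_image = ImageAcquisition.convert_color_space(image, cv2.COLOR_BGR2GRAY)

# Get image information
info = ImageAcquisition.get_image_info(image)
print(info)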

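Configuration System

The system uses JSON configuration files to manage each module's parameters. The configuration files live in the config directory.

Configuration Loader

The configuration loader (config_loader.py) provides a single interface for loading and accessing configuration files.

from utils.config_loader import ConfigLoader

# Create the configuration loader
config_loader = ConfigLoader()

# Load a configuration
config = config_loader.load_config("image_acquisition_config")

# Get a specific value by dotted key path, with a default
camera_id = config_loader.get_value("image_acquisition_config", "camera.default_camera_id", 0)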
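
The dotted key in `get_value` addresses a value inside nested JSON objects. The sketch below shows how such a lookup could work on a plain dictionary; `get_nested` is a hypothetical helper written for illustration and is not the actual implementation in config_loader.py.

```python
from typing import Any, Dict


def get_nested(config: Dict[str, Any], dotted_key: str, default: Any = None) -> Any:
    """Walk a nested dict along a dotted key path, returning default if any part is missing."""
    node: Any = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


config = {"camera": {"default_camera_id": 0, "fps": 30}}
print(get_nested(config, "camera.fps"))          # 30
print(get_nested(config, "camera.missing", -1))  # -1
```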

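Installation and Running

Installing Dependencies

pip install -r requirements.txt

Running the Tests

python tests/test_image_acquisition.py

Planned Development

  1. Image preprocessing module
  2. Feature extraction module
  3. Image classification and recognition module
  4. Image segmentation module
  5. User interface

requirements.txt

numpy>=1.19.0
opencv-python>=4.5.0
Pillow>=8.0.0
matplotlib>=3.3.0
requests>=2.25.0
scikit-image>=0.18.0
scikit-learn>=0.22
tensorflow>=2.4.0
torch>=1.8.0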

config\classification_config.json

{
    "traditional_ml": {
        "svm": {
            "kernel": "rbf",
            "C": 1.0,
            "gamma": "scale",
            "probability": true,
            "class_weight": "balanced"
        },
        "knn": {
            "n_neighbors": 5,
            "weights": "uniform",
            "algorithm": "auto",
            "leaf_size": 30,
            "p": 2
        },
        "random_forest": {
            "n_estimators": 100,
            "max_depth": null,
            "min_samples_split": 2,
            "min_samples_leaf": 1,
            "max_features": "sqrt",
            "bootstrap": true,
            "class_weight": "balanced"
        }
    },
    "deep_learning": {
        "model_name": "vgg16",
        "input_shape": [224, 224, 3],
        "num_classes": 10,
        "batch_size": 32,
        "epochs": 10,
        "learning_rate": 0.001,
        "dropout_rate": 0.5,
        "fine_tune_layers": 0,
        "data_augmentation": true,
        "augmentation_params": {
            "rotation_range": 20,
            "width_shift_range": 0.2,
            "height_shift_range": 0.2,
            "shear_range": 0.2,
            "zoom_range": 0.2,
            "horizontal_flip": true,
            "fill_mode": "nearest"
        }
    },
    "object_detection": {
        "model_name": "yolov5s",
        "confidence_threshold": 0.5,
        "iou_threshold": 0.45,
        "max_detections": 100,
        "custom_model_path": "",
        "classes": []
    },
    "logging": {
        "level": "INFO",
        "file_path": "logs/classification.log",
        "console_output": true
    },
    "paths": {
        "models_dir": "models/",
        "results_dir": "results/",
        "temp_dir": "temp/"
    }
}

config\feature_extraction_config.json

{
    "hog": {
        "win_size": [64, 128],
        "block_size": [16, 16],
        "block_stride": [8, 8],
        "cell_size": [8, 8],
        "nbins": 9
    },
    "sift": {
        "n_features": 0,
        "n_octave_layers": 3,
        "contrast_threshold": 0.04,
        "edge_threshold": 10,
        "sigma": 1.6
    },
    "orb": {
        "n_features": 500,
        "scale_factor": 1.2,
        "n_levels": 8
    },
    "color_histogram": {
        "n_bins": 256,
        "ranges": [0, 256]
    },
    "lbp": {
        "radius": 1,
        "n_points": 8
    },
    "deep_learning": {
        "model_name": "vgg16",
        "input_shape": [224, 224],
        "pooling": "avg"
    },
    "feature_fusion": {
        "weights": {
            "hog": 0.3,
            "color_histogram": 0.2,
            "lbp": 0.2,
            "deep": 0.3
        }
    }
}
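
The `feature_fusion` weights above suggest combining several descriptors into one vector. How the project applies them is not shown in this section, so the following is only a sketch of one plausible scheme: L2-normalize each descriptor, scale it by its weight, and concatenate. The function name `fuse_features` and the toy vectors are assumptions.

```python
import numpy as np

# Weights copied from the feature_fusion section of feature_extraction_config.json
FUSION_WEIGHTS = {"hog": 0.3, "color_histogram": 0.2, "lbp": 0.2, "deep": 0.3}


def fuse_features(features, weights):
    """L2-normalize each named feature vector, scale by its weight, and concatenate."""
    parts = []
    for name, weight in weights.items():
        vec = np.asarray(features[name], dtype=np.float64).ravel()
        norm = np.linalg.norm(vec)
        if norm > 0:                      # leave all-zero vectors unchanged
            vec = vec / norm
        parts.append(weight * vec)
    return np.concatenate(parts)


# Toy descriptors standing in for real HOG, histogram, LBP, and deep features
features = {
    "hog": np.ones(4),
    "color_histogram": np.array([3.0, 4.0]),
    "lbp": np.zeros(3),
    "deep": np.array([2.0]),
}
fused = fuse_features(features, FUSION_WEIGHTS)
print(fused.shape)
```

Normalizing before weighting keeps a long descriptor such as HOG from dominating a short one purely because of its scale.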

config\image_acquisition_config.json

{
    "default_image_size": {
        "width": 640,
        "height": 480
    },
    "camera": {
        "default_camera_id": 0,
        "frame_width": 1280,
        "frame_height": 720,
        "fps": 30,
        "auto_focus": true,
        "stabilization_frames": 5
    },
    "file_loading": {
        "supported_extensions": ["jpg", "jpeg", "png", "bmp", "tiff", "webp"],
        "max_file_size_mb": 10
    },
    "url_loading": {
        "timeout_seconds": 10,
        "max_retries": 3,
        "retry_delay_seconds": 1
    },
    "batch_processing": {
        "default_max_images": 100,
        "parallel_processing": true,
        "max_workers": 4
    },
    "saving": {
        "default_jpeg_quality": 95,
        "default_png_compression": 3,
        "preserve_exif": true
    }
}

src\classification.py

"""
分类/识别模块 (Classification/Recognition Module)

该模块负责基于提取的特征进行图像分类和识别,包括:
1. 传统机器学习分类方法 (SVM, KNN, 随机森林等)
2. 深度学习分类方法 (CNN, 迁移学习等)
3. 目标检测与识别 (YOLO, SSD, Faster R-CNN等)
4. 模型训练、评估和预测功能
"""

import cv2
import numpy as np
import logging
import os
import sys
import pickle
from typing import Union, Optional, Tuple, List, Dict, Any, Callable
import time
from pathlib import Path

# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# 导入自定义模块
from utils.config_loader import ConfigLoader

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class Classification:
    """分类/识别类,提供各种分类和识别方法"""
    
    def __init__(self, config_name: str = "classification_config"):
        """
        初始化分类/识别模块
        
        参数:
            config_name: 配置文件名称,默认为"classification_config"
        """
        # 加载配置
        self.config_loader = ConfigLoader()
        self.config = self.config_loader.load_config(config_name)
        
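        if self.config:
            logger.info(f"成功加载分类/识别模块配置: {config_name}")
            # classification_config.json nests the SVM/KNN/random-forest settings under
            # "traditional_ml", while the training methods read them from the top level
            # (self.config["svm"], ...). Expose them there without overwriting existing keys.
            for name, params in self.config.get("traditional_ml", {}).items():
                self.config.setdefault(name, params)
        else:
            logger.warning(f"无法加载配置: {config_name},使用默认配置")
            self.config = self._get_default_config()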
        
        # 初始化分类器
        self.classifiers = {}
        self.deep_learning_model = None
        self.object_detection_model = None
        
        logger.info("分类/识别模块初始化完成")
    
    def _get_default_config(self) -> Dict[str, Any]:
        """获取默认配置"""
        return {
            "svm": {
                "kernel": "rbf",
                "C": 1.0,
                "gamma": "scale"
            },
            "knn": {
                "n_neighbors": 5,
                "weights": "uniform",
                "algorithm": "auto"
            },
            "random_forest": {
                "n_estimators": 100,
                "max_depth": None,
                "min_samples_split": 2
            },
            "deep_learning": {
                "model_name": "vgg16",
                "input_shape": [224, 224, 3],
                "num_classes": 10,
                "learning_rate": 0.001,
                "batch_size": 32,
                "epochs": 10
            },
            "object_detection": {
                "model_name": "yolov5s",
                "confidence_threshold": 0.5,
                "iou_threshold": 0.45
            }
        }
    
    # 传统机器学习分类方法
    def train_svm_classifier(self, features: np.ndarray, labels: np.ndarray, 
                           kernel: str = None, C: float = None, gamma: Union[str, float] = None) -> Any:
        """
        训练SVM分类器
        
        参数:
            features: 特征矩阵,形状为 (n_samples, n_features)
            labels: 标签向量,形状为 (n_samples,)
            kernel: 核函数类型,可选 'linear', 'poly', 'rbf', 'sigmoid'
            C: 正则化参数
            gamma: 核系数
            
        返回:
            训练好的SVM分类器
        """
        try:
            from sklearn.svm import SVC
            from sklearn.preprocessing import StandardScaler
            
            # 获取参数
            if kernel is None:
                kernel = self.config["svm"]["kernel"]
            
            if C is None:
                C = self.config["svm"]["C"]
            
            if gamma is None:
                gamma = self.config["svm"]["gamma"]
            
            # 标准化特征
            scaler = StandardScaler()
            scaled_features = scaler.fit_transform(features)
            
            # 创建并训练SVM分类器
            svm_classifier = SVC(kernel=kernel, C=C, gamma=gamma, probability=True)
            svm_classifier.fit(scaled_features, labels)
            
            # 保存分类器和缩放器
            self.classifiers["svm"] = {
                "classifier": svm_classifier,
                "scaler": scaler
            }
            
            logger.info(f"SVM分类器训练完成: 内核={kernel}, C={C}, gamma={gamma}")
            return svm_classifier
            
        except ImportError as e:
            logger.error(f"无法导入scikit-learn: {e}")
            logger.warning("请安装scikit-learn: pip install scikit-learn")
            return None
        except Exception as e:
            logger.error(f"训练SVM分类器时出错: {e}")
            return None
    
    def train_knn_classifier(self, features: np.ndarray, labels: np.ndarray,
                           n_neighbors: int = None, weights: str = None, algorithm: str = None) -> Any:
        """
        训练KNN分类器
        
        参数:
            features: 特征矩阵,形状为 (n_samples, n_features)
            labels: 标签向量,形状为 (n_samples,)
            n_neighbors: 近邻数量
            weights: 权重类型,可选 'uniform', 'distance'
            algorithm: 算法类型,可选 'auto', 'ball_tree', 'kd_tree', 'brute'
            
        返回:
            训练好的KNN分类器
        """
        try:
            from sklearn.neighbors import KNeighborsClassifier
            from sklearn.preprocessing import StandardScaler
            
            # 获取参数
            if n_neighbors is None:
                n_neighbors = self.config["knn"]["n_neighbors"]
            
            if weights is None:
                weights = self.config["knn"]["weights"]
            
            if algorithm is None:
                algorithm = self.config["knn"]["algorithm"]
            
            # 标准化特征
            scaler = StandardScaler()
            scaled_features = scaler.fit_transform(features)
            
            # 创建并训练KNN分类器
            knn_classifier = KNeighborsClassifier(n_neighbors=n_neighbors, weights=weights, algorithm=algorithm)
            knn_classifier.fit(scaled_features, labels)
            
            # 保存分类器和缩放器
            self.classifiers["knn"] = {
                "classifier": knn_classifier,
                "scaler": scaler
            }
            
            logger.info(f"KNN分类器训练完成: n_neighbors={n_neighbors}, weights={weights}, algorithm={algorithm}")
            return knn_classifier
            
        except ImportError as e:
            logger.error(f"无法导入scikit-learn: {e}")
            logger.warning("请安装scikit-learn: pip install scikit-learn")
            return None
        except Exception as e:
            logger.error(f"训练KNN分类器时出错: {e}")
            return None
    
    def train_random_forest_classifier(self, features: np.ndarray, labels: np.ndarray,
                                     n_estimators: int = None, max_depth: Optional[int] = None,
                                     min_samples_split: int = None) -> Any:
        """
        训练随机森林分类器
        
        参数:
            features: 特征矩阵,形状为 (n_samples, n_features)
            labels: 标签向量,形状为 (n_samples,)
            n_estimators: 树的数量
            max_depth: 树的最大深度
            min_samples_split: 分裂内部节点所需的最小样本数
            
        返回:
            训练好的随机森林分类器
        """
        try:
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.preprocessing import StandardScaler
            
            # 获取参数
            if n_estimators is None:
                n_estimators = self.config["random_forest"]["n_estimators"]
            
            if max_depth is None:
                max_depth = self.config["random_forest"]["max_depth"]
            
            if min_samples_split is None:
                min_samples_split = self.config["random_forest"]["min_samples_split"]
            
            # 标准化特征
            scaler = StandardScaler()
            scaled_features = scaler.fit_transform(features)
            
            # 创建并训练随机森林分类器
            rf_classifier = RandomForestClassifier(
                n_estimators=n_estimators,
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                random_state=42
            )
            rf_classifier.fit(scaled_features, labels)
            
            # 保存分类器和缩放器
            self.classifiers["random_forest"] = {
                "classifier": rf_classifier,
                "scaler": scaler
            }
            
            logger.info(f"随机森林分类器训练完成: n_estimators={n_estimators}, max_depth={max_depth}")
            return rf_classifier
            
        except ImportError as e:
            logger.error(f"无法导入scikit-learn: {e}")
            logger.warning("请安装scikit-learn: pip install scikit-learn")
            return None
        except Exception as e:
            logger.error(f"训练随机森林分类器时出错: {e}")
            return None
    
    def predict(self, features: np.ndarray, classifier_type: str = "svm") -> Tuple[np.ndarray, np.ndarray]:
        """
        使用指定的分类器进行预测
        
        参数:
            features: 特征矩阵,形状为 (n_samples, n_features)
            classifier_type: 分类器类型,可选 "svm", "knn", "random_forest"
            
        返回:
            预测标签和预测概率
        """
        if classifier_type not in self.classifiers:
            logger.error(f"分类器 {classifier_type} 未训练")
            return np.array([]), np.array([])
        
        try:
            # 获取分类器和缩放器
            classifier = self.classifiers[classifier_type]["classifier"]
            scaler = self.classifiers[classifier_type]["scaler"]
            
            # 标准化特征
            scaled_features = scaler.transform(features)
            
            # 预测
            predictions = classifier.predict(scaled_features)
            
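            # Get prediction probabilities if the classifier supports them
            try:
                probabilities = classifier.predict_proba(scaled_features)
            except Exception:
                # A bare "except:" would also swallow KeyboardInterrupt and SystemExit
                probabilities = np.zeros((len(predictions), 1))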
            
            logger.info(f"使用 {classifier_type} 分类器进行预测: 样本数 {len(predictions)}")
            return predictions, probabilities
            
        except Exception as e:
            logger.error(f"预测时出错: {e}")
            return np.array([]), np.array([])
    
    def evaluate_classifier(self, features: np.ndarray, labels: np.ndarray, 
                           classifier_type: str = "svm") -> Dict[str, float]:
        """
        评估分类器性能
        
        参数:
            features: 特征矩阵,形状为 (n_samples, n_features)
            labels: 标签向量,形状为 (n_samples,)
            classifier_type: 分类器类型,可选 "svm", "knn", "random_forest"
            
        返回:
            包含评估指标的字典
        """
        if classifier_type not in self.classifiers:
            logger.error(f"分类器 {classifier_type} 未训练")
            return {}
        
        try:
            from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
            
            # 预测
            predictions, _ = self.predict(features, classifier_type)
            
            if len(predictions) == 0:
                return {}
            
            # 计算评估指标
            accuracy = accuracy_score(labels, predictions)
            
            # 对于多分类问题,使用macro平均
            precision = precision_score(labels, predictions, average='macro', zero_division=0)
            recall = recall_score(labels, predictions, average='macro', zero_division=0)
            f1 = f1_score(labels, predictions, average='macro', zero_division=0)
            
            # 计算混淆矩阵
            cm = confusion_matrix(labels, predictions)
            
            metrics = {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1_score": f1,
                "confusion_matrix": cm
            }
            
            logger.info(f"评估 {classifier_type} 分类器: 准确率={accuracy:.4f}, F1分数={f1:.4f}")
            return metrics
            
        except ImportError as e:
            logger.error(f"无法导入scikit-learn: {e}")
            return {}
        except Exception as e:
            logger.error(f"评估分类器时出错: {e}")
            return {}
    
    def save_classifier(self, classifier_type: str, file_path: str) -> bool:
        """
        保存分类器到文件
        
        参数:
            classifier_type: 分类器类型,可选 "svm", "knn", "random_forest"
            file_path: 保存路径
            
        返回:
            是否保存成功
        """
        if classifier_type not in self.classifiers:
            logger.error(f"分类器 {classifier_type} 未训练")
            return False
        
        try:
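            # Create the parent directory if needed; dirname is '' for a bare file
            # name, and os.makedirs('') would raise FileNotFoundError
            directory = os.path.dirname(file_path)
            if directory:
                os.makedirs(directory, exist_ok=True)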
            
            # 保存分类器
            with open(file_path, 'wb') as f:
                pickle.dump(self.classifiers[classifier_type], f)
            
            logger.info(f"分类器 {classifier_type} 已保存到 {file_path}")
            return True
            
        except Exception as e:
            logger.error(f"保存分类器时出错: {e}")
            return False
    
    def load_classifier(self, classifier_type: str, file_path: str) -> bool:
        """
        从文件加载分类器
        
        参数:
            classifier_type: 分类器类型,可选 "svm", "knn", "random_forest"
            file_path: 加载路径
            
        返回:
            是否加载成功
        """
        try:
            # 加载分类器
            with open(file_path, 'rb') as f:
                self.classifiers[classifier_type] = pickle.load(f)
            
            logger.info(f"分类器 {classifier_type} 已从 {file_path} 加载")
            return True
            
        except Exception as e:
            logger.error(f"加载分类器时出错: {e}")
            return False
    
    # 深度学习分类方法
    def _create_deep_learning_model(self, model_name: str = None, num_classes: int = None):
        """
        创建深度学习模型
        
        参数:
            model_name: 模型名称
            num_classes: 类别数量
        """
        if model_name is None:
            model_name = self.config["deep_learning"]["model_name"]
        
        if num_classes is None:
            num_classes = self.config["deep_learning"]["num_classes"]
        
        try:
            # 导入TensorFlow
            import tensorflow as tf
            from tensorflow.keras.applications import (
                VGG16, VGG19, ResNet50, InceptionV3, MobileNetV2, EfficientNetB0
            )
            from tensorflow.keras.models import Model
            from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
            
            # 设置日志级别,抑制TensorFlow警告
            tf.get_logger().setLevel('ERROR')
            
            # 获取输入形状
            input_shape = tuple(self.config["deep_learning"]["input_shape"])
            
            # 根据模型名称加载预训练模型
            if model_name.lower() == "vgg16":
                base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
            elif model_name.lower() == "vgg19":
                base_model = VGG19(weights='imagenet', include_top=False, input_shape=input_shape)
            elif model_name.lower() == "resnet50":
                base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape)
            elif model_name.lower() == "inceptionv3":
                base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=input_shape)
            elif model_name.lower() == "mobilenetv2":
                base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)
            elif model_name.lower() == "efficientnetb0":
                base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape)
            else:
                logger.warning(f"未知的模型名称: {model_name},使用VGG16")
                base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape)
            
            # 冻结基础模型层
            for layer in base_model.layers:
                layer.trainable = False
            
            # 添加分类头
            x = base_model.output
            x = GlobalAveragePooling2D()(x)
            x = Dense(1024, activation='relu')(x)
            x = Dropout(0.5)(x)
            predictions = Dense(num_classes, activation='softmax')(x)
            
            # 创建完整模型
            model = Model(inputs=base_model.input, outputs=predictions)
            
            # 编译模型
            learning_rate = self.config["deep_learning"]["learning_rate"]
            model.compile(
                optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                loss='categorical_crossentropy',
                metrics=['accuracy']
            )
            
            self.deep_learning_model = model
            
            logger.info(f"成功创建深度学习模型: {model_name}, 类别数: {num_classes}")
            
        except ImportError as e:
            logger.error(f"无法导入TensorFlow: {e}")
            logger.warning("请安装TensorFlow: pip install tensorflow")
            self.deep_learning_model = None
        except Exception as e:
            logger.error(f"创建深度学习模型时出错: {e}")
            self.deep_learning_model = None
    
    def train_deep_learning_model(self, train_data: Tuple[np.ndarray, np.ndarray], 
                                validation_data: Optional[Tuple[np.ndarray, np.ndarray]] = None,
                                model_name: str = None, num_classes: int = None,
                                batch_size: int = None, epochs: int = None) -> Any:
        """
        训练深度学习模型
        
        参数:
            train_data: 训练数据,形式为 (images, labels)
            validation_data: 验证数据,形式为 (images, labels)
            model_name: 模型名称
            num_classes: 类别数量
            batch_size: 批量大小
            epochs: 训练轮数
            
        返回:
            训练历史
        """
        # 如果模型未创建,创建模型
        if self.deep_learning_model is None:
            self._create_deep_learning_model(model_name, num_classes)
        
        # 如果模型创建失败,返回None
        if self.deep_learning_model is None:
            logger.error("深度学习模型未创建,无法训练")
            return None
        
        try:
            # 导入TensorFlow
            import tensorflow as tf
            from tensorflow.keras.utils import to_categorical
            
            # 获取参数
            if batch_size is None:
                batch_size = self.config["deep_learning"]["batch_size"]
            
            if epochs is None:
                epochs = self.config["deep_learning"]["epochs"]
            
            # 解包训练数据
            train_images, train_labels = train_data
            
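            # Resolve the class count the same way model creation does.
            # With num_classes=None, to_categorical infers the width from the
            # largest label, which may not match the model's output layer.
            if num_classes is None:
                num_classes = self.config["deep_learning"]["num_classes"]
            
            # Convert labels to one-hot encoding
            train_labels_categorical = to_categorical(train_labels, num_classes)
            
            # Prepare validation data
            if validation_data is not None:
                val_images, val_labels = validation_data
                val_labels_categorical = to_categorical(val_labels, num_classes)
                validation_data = (val_images, val_labels_categorical)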
            
            # 训练模型
            history = self.deep_learning_model.fit(
                train_images, train_labels_categorical,
                batch_size=batch_size,
                epochs=epochs,
                validation_data=validation_data,
                verbose=1
            )
            
            logger.info(f"深度学习模型训练完成: epochs={epochs}, batch_size={batch_size}")
            return history
            
        except ImportError as e:
            logger.error(f"无法导入TensorFlow: {e}")
            return None
        except Exception as e:
            logger.error(f"训练深度学习模型时出错: {e}")
            return None
    
    def predict_with_deep_learning(self, images: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        使用深度学习模型进行预测
        
        参数:
            images: 图像数组,形状为 (n_samples, height, width, channels)
            
        返回:
            预测标签和预测概率
        """
        if self.deep_learning_model is None:
            logger.error("深度学习模型未创建,无法预测")
            return np.array([]), np.array([])
        
        try:
            # 预测
            probabilities = self.deep_learning_model.predict(images)
            
            # 获取预测标签
            predictions = np.argmax(probabilities, axis=1)
            
            logger.info(f"使用深度学习模型进行预测: 样本数 {len(predictions)}")
            return predictions, probabilities
            
        except Exception as e:
            logger.error(f"深度学习预测时出错: {e}")
            return np.array([]), np.array([])
    
    def save_deep_learning_model(self, model_path: str, weights_path: Optional[str] = None) -> bool:
        """
        保存深度学习模型
        
        参数:
            model_path: 模型保存路径
            weights_path: 权重保存路径,如果为None,则只保存整个模型
            
        返回:
            是否保存成功
        """
        if self.deep_learning_model is None:
            logger.error("深度学习模型未创建,无法保存")
            return False
        
        try:
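            # Create parent directories only when the path includes one
            # (os.makedirs("") raises when given a bare filename)
            model_dir = os.path.dirname(model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)
            
            # Save the model
            if weights_path is None:
                self.deep_learning_model.save(model_path)
                logger.info(f"Deep learning model saved to {model_path}")
            else:
                # Save the model architecture as JSON
                model_json = self.deep_learning_model.to_json()
                with open(model_path, "w") as json_file:
                    json_file.write(model_json)
                
                # Save the weights
                weights_dir = os.path.dirname(weights_path)
                if weights_dir:
                    os.makedirs(weights_dir, exist_ok=True)
                self.deep_learning_model.save_weights(weights_path)
                
                logger.info(f"Deep learning model architecture saved to {model_path}, weights saved to {weights_path}")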
            
            return True
            
        except Exception as e:
            logger.error(f"保存深度学习模型时出错: {e}")
            return False
    
    def load_deep_learning_model(self, model_path: str, weights_path: Optional[str] = None) -> bool:
        """
        加载深度学习模型
        
        参数:
            model_path: 模型加载路径
            weights_path: 权重加载路径,如果为None,则加载整个模型
            
        返回:
            是否加载成功
        """
        try:
            # 导入TensorFlow
            import tensorflow as tf
            
            # 加载模型
            if weights_path is None:
                self.deep_learning_model = tf.keras.models.load_model(model_path)
                logger.info(f"深度学习模型已从 {model_path} 加载")
            else:
                # 加载模型结构
                with open(model_path, "r") as json_file:
                    model_json = json_file.read()
                
                self.deep_learning_model = tf.keras.models.model_from_json(model_json)
                
                # 加载权重
                self.deep_learning_model.load_weights(weights_path)
                
                # 编译模型
                learning_rate = self.config["deep_learning"]["learning_rate"]
                self.deep_learning_model.compile(
                    optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                    loss='categorical_crossentropy',
                    metrics=['accuracy']
                )
                
                logger.info(f"深度学习模型结构已从 {model_path} 加载,权重已从 {weights_path} 加载")
            
            return True
            
        except ImportError as e:
            logger.error(f"无法导入TensorFlow: {e}")
            return False
        except Exception as e:
            logger.error(f"加载深度学习模型时出错: {e}")
            return False

    # 目标检测方法
    def _load_object_detection_model(self, model_name: str = None):
        """
        加载目标检测模型
        
        参数:
            model_name: 模型名称,可选 "yolov5s", "yolov5m", "yolov5l", "yolov5x"
        """
        if model_name is None:
            model_name = self.config["object_detection"]["model_name"]
        
        try:
            # 尝试导入torch和YOLOv5
            import torch
            
            # 设置日志级别,抑制警告
            import warnings
            warnings.filterwarnings("ignore")
            
            # 加载YOLOv5模型
            self.object_detection_model = torch.hub.load('ultralytics/yolov5', model_name)
            
            # 设置置信度阈值
            confidence_threshold = self.config["object_detection"]["confidence_threshold"]
            self.object_detection_model.conf = confidence_threshold
            
            # 设置IoU阈值
            iou_threshold = self.config["object_detection"]["iou_threshold"]
            self.object_detection_model.iou = iou_threshold
            
            logger.info(f"成功加载目标检测模型: {model_name}")
            
        except ImportError as e:
            logger.error(f"无法导入PyTorch或YOLOv5: {e}")
            logger.warning("请安装PyTorch和YOLOv5: pip install torch torchvision")
            self.object_detection_model = None
        except Exception as e:
            logger.error(f"加载目标检测模型时出错: {e}")
            self.object_detection_model = None
    
    def detect_objects(self, image: np.ndarray, model_name: str = None, 
                      confidence_threshold: float = None) -> Dict[str, Any]:
        """
        检测图像中的目标
        
        参数:
            image: 输入图像
            model_name: 模型名称
            confidence_threshold: 置信度阈值
            
        返回:
            检测结果字典,包含边界框、类别和置信度
        """
        # 如果模型未加载,加载模型
        if self.object_detection_model is None:
            self._load_object_detection_model(model_name)
        
        # 如果模型加载失败,返回空结果
        if self.object_detection_model is None:
            logger.error("目标检测模型未加载,无法检测")
            return {"boxes": [], "labels": [], "scores": []}
        
        try:
            # 设置置信度阈值(如果指定)
            if confidence_threshold is not None:
                self.object_detection_model.conf = confidence_threshold
            
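            # Run detection. The YOLOv5 hub wrapper treats NumPy input as RGB,
            # so convert 3-channel images first (this assumes `image` is BGR,
            # as loaded by OpenCV)
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = self.object_detection_model(image)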
            
            # 提取结果
            result_pandas = results.pandas().xyxy[0]  # 获取第一张图像的结果
            
            # 提取边界框、类别和置信度
            boxes = []
            labels = []
            scores = []
            
            for _, row in result_pandas.iterrows():
                box = [row['xmin'], row['ymin'], row['xmax'], row['ymax']]
                boxes.append(box)
                labels.append(row['name'])
                scores.append(row['confidence'])
            
            detection_results = {
                "boxes": boxes,
                "labels": labels,
                "scores": scores
            }
            
            logger.info(f"检测到 {len(boxes)} 个目标")
            return detection_results
            
        except Exception as e:
            logger.error(f"目标检测时出错: {e}")
            return {"boxes": [], "labels": [], "scores": []}
    
    def draw_detection_results(self, image: np.ndarray, detection_results: Dict[str, Any]) -> np.ndarray:
        """
        在图像上绘制检测结果
        
        参数:
            image: 输入图像
            detection_results: 检测结果字典
            
        返回:
            绘制结果后的图像
        """
        # 创建副本,避免修改原图像
        result = image.copy()
        
        # 提取检测结果
        boxes = detection_results["boxes"]
        labels = detection_results["labels"]
        scores = detection_results["scores"]
        
        # 生成随机颜色
        import random
        colors = {}
        
        # 绘制每个检测结果
        for i, box in enumerate(boxes):
            # 获取边界框坐标
            x_min, y_min, x_max, y_max = [int(coord) for coord in box]
            
            # 获取类别和置信度
            label = labels[i]
            score = scores[i]
            
            # 为每个类别分配一个固定颜色
            if label not in colors:
                colors[label] = (
                    random.randint(0, 255),
                    random.randint(0, 255),
                    random.randint(0, 255)
                )
            color = colors[label]
            
            # 绘制边界框
            cv2.rectangle(result, (x_min, y_min), (x_max, y_max), color, 2)
            
            # 准备标签文本
            text = f"{label}: {score:.2f}"
            
            # 计算文本大小
            (text_width, text_height), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            
            # 绘制文本背景
            cv2.rectangle(result, (x_min, y_min - text_height - 10), (x_min + text_width, y_min), color, -1)
            
            # 绘制文本
            cv2.putText(result, text, (x_min, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 2)
        
        logger.info(f"绘制 {len(boxes)} 个检测结果")
        return result
    
    def custom_object_detection(self, image: np.ndarray, custom_model_path: str, 
                              confidence_threshold: float = 0.5) -> Dict[str, Any]:
        """
        使用自定义YOLOv5模型进行目标检测
        
        参数:
            image: 输入图像
            custom_model_path: 自定义模型路径
            confidence_threshold: 置信度阈值
            
        返回:
            检测结果字典,包含边界框、类别和置信度
        """
        try:
            # 导入torch
            import torch
            
            # 加载自定义模型
            model = torch.hub.load('ultralytics/yolov5', 'custom', path=custom_model_path)
            
            # 设置置信度阈值
            model.conf = confidence_threshold
            
            # 设置IoU阈值
            model.iou = self.config["object_detection"]["iou_threshold"]
            
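            # Run detection. The YOLOv5 hub wrapper treats NumPy input as RGB,
            # so convert 3-channel images first (this assumes `image` is BGR,
            # as loaded by OpenCV)
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            results = model(image)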
            
            # 提取结果
            result_pandas = results.pandas().xyxy[0]  # 获取第一张图像的结果
            
            # 提取边界框、类别和置信度
            boxes = []
            labels = []
            scores = []
            
            for _, row in result_pandas.iterrows():
                box = [row['xmin'], row['ymin'], row['xmax'], row['ymax']]
                boxes.append(box)
                labels.append(row['name'])
                scores.append(row['confidence'])
            
            detection_results = {
                "boxes": boxes,
                "labels": labels,
                "scores": scores
            }
            
            logger.info(f"使用自定义模型检测到 {len(boxes)} 个目标")
            return detection_results
            
        except ImportError as e:
            logger.error(f"无法导入PyTorch或YOLOv5: {e}")
            return {"boxes": [], "labels": [], "scores": []}
        except Exception as e:
            logger.error(f"使用自定义模型进行目标检测时出错: {e}")
            return {"boxes": [], "labels": [], "scores": []}
    
    # 综合分类和识别方法
    def classify_image(self, image: np.ndarray, features: Optional[np.ndarray] = None, 
                      method: str = "svm", feature_extractor = None) -> Dict[str, Any]:
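        """
        Classify an image with the selected method
        
        Parameters:
            image: input image
            features: precomputed features; if None, they are extracted with feature_extractor
            method: classification method, one of "svm", "knn", "random_forest", "deep_learning"
            feature_extractor: feature extractor used to compute the features
            
        Returns:
            Result dictionary with the predicted labels, predicted probabilities and processing time
        """
        start_time = time.time()
        
        if method == "deep_learning":
            # Deep learning works on the image itself, so no feature extractor
            # is needed; add a batch dimension when a single
            # (height, width, channels) image is passed
            inputs = image if features is None else features
            if inputs.ndim == 3:
                inputs = np.expand_dims(inputs, axis=0)
            predictions, probabilities = self.predict_with_deep_learning(inputs)
        else:
            # Traditional methods need a feature vector
            if features is None and feature_extractor is not None:
                features = feature_extractor.extract_features(image)
                
                # If the extractor returns a dictionary, use its HOG features
                if isinstance(features, dict) and "hog" in features:
                    features = features["hog"]
            
            # Use the traditional machine learning model
            predictions, probabilities = self.predict(features, method)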
        
        # 计算处理时间
        processing_time = time.time() - start_time
        
        # 构建结果字典
        result = {
            "predictions": predictions,
            "probabilities": probabilities,
            "processing_time": processing_time
        }
        
        logger.info(f"使用 {method} 方法分类图像: 处理时间 {processing_time:.4f} 秒")
        return result
    
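    def recognize_objects(self, image: np.ndarray, method: str = "yolo",
                         custom_model_path: Optional[str] = None) -> Dict[str, Any]:
        """
        Recognize objects in an image
        
        Parameters:
            image: input image
            method: recognition method, one of "yolo", "custom"
            custom_model_path: path to the custom YOLOv5 weights; required when method is "custom"
            
        Returns:
            Result dictionary with the detection results and the processing time
        """
        start_time = time.time()
        
        # Run detection with the chosen method
        if method == "yolo":
            # Use the YOLOv5 model
            detection_results = self.detect_objects(image)
        elif method == "custom":
            # Use a custom model; the caller must supply the weights path
            if custom_model_path is None:
                logger.error("method 'custom' requires custom_model_path")
                detection_results = {"boxes": [], "labels": [], "scores": []}
            else:
                detection_results = self.custom_object_detection(image, custom_model_path)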
        else:
            logger.warning(f"未知的识别方法: {method},使用YOLOv5")
            detection_results = self.detect_objects(image)
        
        # 计算处理时间
        processing_time = time.time() - start_time
        
        # 构建结果字典
        result = {
            "detection_results": detection_results,
            "processing_time": processing_time
        }
        
        logger.info(f"使用 {method} 方法识别目标: 处理时间 {processing_time:.4f} 秒")
        return result
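
The detection methods above pass an `iou_threshold` from the configuration to the YOLOv5 model, where it controls how much two boxes may overlap before one is suppressed. As a minimal sketch of what that number measures, the helper below computes the intersection over union of two boxes in the same `[x_min, y_min, x_max, y_max]` layout that `detect_objects` returns. The name `box_iou` is hypothetical and not part of the project or of YOLOv5.

```python
from typing import Sequence


def box_iou(box_a: Sequence[float], box_b: Sequence[float]) -> float:
    """Intersection over union of two [x_min, y_min, x_max, y_max] boxes."""
    # Corners of the overlapping rectangle (if any)
    ix_min = max(box_a[0], box_b[0])
    iy_min = max(box_a[1], box_b[1])
    ix_max = min(box_a[2], box_b[2])
    iy_max = min(box_a[3], box_b[3])

    # Clamp at zero so disjoint boxes give an empty intersection
    inter = max(0.0, ix_max - ix_min) * max(0.0, iy_max - iy_min)

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter

    return inter / union if union > 0 else 0.0


# Two 10x10 boxes offset by 5 pixels in each direction
example_iou = box_iou([0, 0, 10, 10], [5, 5, 15, 15])
print(f"IoU = {example_iou:.3f}")
```

Here the intersection is 25 pixels and the union is 175, so the IoU is 25/175, roughly 0.143. With a typical threshold well above that value, both boxes would be kept as separate detections.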

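The feature extraction module listed next configures `cv2.HOGDescriptor` with a default window of 64x128, 16x16 blocks, an 8x8 block stride, 8x8 cells and 9 orientation bins. The sketch below works out how long the resulting feature vector is for one window, which helps when sizing the input of a downstream classifier. The function name `hog_descriptor_length` is hypothetical; it only reproduces the standard block-counting arithmetic and does not call OpenCV.

```python
from typing import Tuple


def hog_descriptor_length(win_size: Tuple[int, int],
                          block_size: Tuple[int, int],
                          block_stride: Tuple[int, int],
                          cell_size: Tuple[int, int],
                          nbins: int) -> int:
    """Number of values in the HOG descriptor of a single window."""
    # Number of block positions along each axis of the window
    blocks_x = (win_size[0] - block_size[0]) // block_stride[0] + 1
    blocks_y = (win_size[1] - block_size[1]) // block_stride[1] + 1

    # Each block holds a grid of cells, and each cell holds one histogram
    cells_per_block = (block_size[0] // cell_size[0]) * (block_size[1] // cell_size[1])

    return blocks_x * blocks_y * cells_per_block * nbins


# Default configuration used by the feature extraction module
default_length = hog_descriptor_length((64, 128), (16, 16), (8, 8), (8, 8), 9)
print(default_length)
```

With the defaults there are 7 x 15 = 105 block positions, each contributing 4 cells x 9 bins = 36 values, for 3780 values in total.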
src\feature_extraction.py

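"""
Feature Extraction Module

This module extracts features from images, including:
1. Traditional features (HOG, SIFT, ORB, etc.)
2. Color features (color histogram, color moments, etc.)
3. Texture features (LBP, Gabor, etc.)
4. Deep learning features (pretrained CNN models)
"""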

import cv2
import numpy as np
import logging
from typing import Union, Optional, Tuple, List, Dict, Any
import os
import sys

# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# 导入自定义模块
from utils.config_loader import ConfigLoader

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class FeatureExtraction:
    """特征提取类,提供各种特征提取方法"""
    
    def __init__(self, config_name: str = "feature_extraction_config"):
        """
        初始化特征提取模块
        
        参数:
            config_name: 配置文件名称,默认为"feature_extraction_config"
        """
        # 加载配置
        self.config_loader = ConfigLoader()
        self.config = self.config_loader.load_config(config_name)
        
        if self.config:
            logger.info(f"成功加载特征提取模块配置: {config_name}")
        else:
            logger.warning(f"无法加载配置: {config_name},使用默认配置")
            self.config = self._get_default_config()
        
        # 初始化特征提取器
        self._init_feature_extractors()
        
        logger.info("特征提取模块初始化完成")
    
    def _get_default_config(self) -> Dict[str, Any]:
        """获取默认配置"""
        return {
            "hog": {
                "win_size": (64, 128),
                "block_size": (16, 16),
                "block_stride": (8, 8),
                "cell_size": (8, 8),
                "nbins": 9
            },
            "sift": {
                "n_features": 0,
                "n_octave_layers": 3,
                "contrast_threshold": 0.04,
                "edge_threshold": 10,
                "sigma": 1.6
            },
            "orb": {
                "n_features": 500,
                "scale_factor": 1.2,
                "n_levels": 8
            },
            "color_histogram": {
                "n_bins": 256,
                "ranges": [0, 256]
            },
            "lbp": {
                "radius": 1,
                "n_points": 8
            },
            "deep_learning": {
                "model_name": "vgg16",
                "input_shape": (224, 224),
                "pooling": "avg"
            }
        }
    
    def _init_feature_extractors(self):
        """初始化特征提取器"""
        # 初始化HOG特征提取器
        win_size = tuple(self.config["hog"]["win_size"])
        block_size = tuple(self.config["hog"]["block_size"])
        block_stride = tuple(self.config["hog"]["block_stride"])
        cell_size = tuple(self.config["hog"]["cell_size"])
        nbins = self.config["hog"]["nbins"]
        
        self.hog = cv2.HOGDescriptor(win_size, block_size, block_stride, cell_size, nbins)
        
        # 初始化SIFT特征提取器
        self.sift = cv2.SIFT_create(
            nfeatures=self.config["sift"]["n_features"],
            nOctaveLayers=self.config["sift"]["n_octave_layers"],
            contrastThreshold=self.config["sift"]["contrast_threshold"],
            edgeThreshold=self.config["sift"]["edge_threshold"],
            sigma=self.config["sift"]["sigma"]
        )
        
        # 初始化ORB特征提取器
        self.orb = cv2.ORB_create(
            nfeatures=self.config["orb"]["n_features"],
            scaleFactor=self.config["orb"]["scale_factor"],
            nlevels=self.config["orb"]["n_levels"]
        )
        
        # 深度学习模型将在需要时加载,以节省内存
        self.deep_learning_model = None
        
    # 传统特征提取方法
    def extract_hog_features(self, image: np.ndarray, win_size: Tuple[int, int] = None) -> np.ndarray:
        """
        提取HOG特征
        
        参数:
            image: 输入图像
            win_size: 窗口大小,如果指定则调整图像大小
            
        返回:
            HOG特征向量
        """
        # 转换为灰度图像
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
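        # Resize to the HOG window size. cv2.resize needs a tuple, and a
        # configuration loaded from file may supply a list instead.
        if win_size is None:
            win_size = self.config["hog"]["win_size"]
        
        resized = cv2.resize(gray, tuple(win_size))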
        
        # 提取HOG特征
        features = self.hog.compute(resized)
        
        logger.info(f"提取HOG特征: 特征维度 {features.shape}")
        return features
    
    def extract_sift_features(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
        """
        提取SIFT特征
        
        参数:
            image: 输入图像
            
        返回:
            关键点列表和特征描述符
        """
        # 转换为灰度图像
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        # 提取SIFT特征
        keypoints, descriptors = self.sift.detectAndCompute(gray, None)
        
        logger.info(f"提取SIFT特征: 检测到 {len(keypoints)} 个关键点")
        return keypoints, descriptors
    
    def extract_orb_features(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
        """
        提取ORB特征
        
        参数:
            image: 输入图像
            
        返回:
            关键点列表和特征描述符
        """
        # 转换为灰度图像
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        # 提取ORB特征
        keypoints, descriptors = self.orb.detectAndCompute(gray, None)
        
        logger.info(f"提取ORB特征: 检测到 {len(keypoints)} 个关键点")
        return keypoints, descriptors
    
    def extract_keypoints(self, image: np.ndarray, method: str = "sift") -> Tuple[List[cv2.KeyPoint], np.ndarray]:
        """
        提取图像关键点
        
        参数:
            image: 输入图像
            method: 关键点提取方法,可选 "sift", "orb", "brisk", "akaze"
            
        返回:
            关键点列表和特征描述符
        """
        if method == "sift":
            return self.extract_sift_features(image)
        elif method == "orb":
            return self.extract_orb_features(image)
        elif method == "brisk":
            # 创建BRISK特征提取器
            brisk = cv2.BRISK_create()
            
            # 转换为灰度图像
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image.copy()
            
            # 提取BRISK特征
            keypoints, descriptors = brisk.detectAndCompute(gray, None)
            
            logger.info(f"提取BRISK特征: 检测到 {len(keypoints)} 个关键点")
            return keypoints, descriptors
        
        elif method == "akaze":
            # 创建AKAZE特征提取器
            akaze = cv2.AKAZE_create()
            
            # 转换为灰度图像
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image.copy()
            
            # 提取AKAZE特征
            keypoints, descriptors = akaze.detectAndCompute(gray, None)
            
            logger.info(f"提取AKAZE特征: 检测到 {len(keypoints)} 个关键点")
            return keypoints, descriptors
        
        else:
            logger.warning(f"未知的关键点提取方法: {method},使用SIFT")
            return self.extract_sift_features(image)
    
    def draw_keypoints(self, image: np.ndarray, keypoints: List[cv2.KeyPoint], 
                      color: Tuple[int, int, int] = (0, 255, 0)) -> np.ndarray:
        """
        在图像上绘制关键点
        
        参数:
            image: 输入图像
            keypoints: 关键点列表
            color: 关键点颜色
            
        返回:
            绘制关键点后的图像
        """
        # 创建副本,避免修改原图像
        result = image.copy()
        
        # 如果图像是灰度图,转换为彩色图像
        if len(result.shape) == 2:
            result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)
        
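        # Draw the keypoints; use the returned image rather than relying on
        # the output argument being modified in place
        result = cv2.drawKeypoints(result, keypoints, None, color=color)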
        
        logger.info(f"绘制 {len(keypoints)} 个关键点")
        return result
    
    # 颜色特征提取方法
    def extract_color_histogram(self, image: np.ndarray, n_bins: int = None, 
                               ranges: List[int] = None, channels: List[int] = None) -> np.ndarray:
        """
        提取颜色直方图特征
        
        参数:
            image: 输入图像
            n_bins: 每个通道的直方图柱数
            ranges: 每个通道的值范围
            channels: 要计算直方图的通道,默认为所有通道
            
        返回:
            颜色直方图特征
        """
        # 确保图像是彩色的
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        
        # 获取参数
        if n_bins is None:
            n_bins = self.config["color_histogram"]["n_bins"]
        
        if ranges is None:
            ranges = self.config["color_histogram"]["ranges"]
        
        # 如果没有指定通道,使用所有通道
        if channels is None:
            channels = list(range(image.shape[2]))
        
        # 计算每个通道的直方图
        histograms = []
        for i in channels:
            hist = cv2.calcHist([image], [i], None, [n_bins], ranges)
            # 归一化直方图
            hist = cv2.normalize(hist, hist).flatten()
            histograms.append(hist)
        
        # 合并所有通道的直方图
        features = np.concatenate(histograms)
        
        logger.info(f"提取颜色直方图特征: 特征维度 {features.shape}")
        return features
    
    def extract_color_moments(self, image: np.ndarray) -> np.ndarray:
        """
        提取颜色矩特征(均值、标准差、偏度)
        
        参数:
            image: 输入图像
            
        返回:
            颜色矩特征
        """
        # 确保图像是彩色的
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        
        # 转换为浮点型
        img_float = image.astype(np.float32)
        
        # 计算每个通道的颜色矩
        features = []
        for i in range(3):
            # 一阶矩(均值)
            mean = np.mean(img_float[:,:,i])
            features.append(mean)
            
            # 二阶矩(标准差)
            std = np.std(img_float[:,:,i])
            features.append(std)
            
            # 三阶矩(偏度)
            # 计算中心化的三次方
            channel = img_float[:,:,i]
            skewness = np.mean(((channel - mean) / (std + 1e-10)) ** 3)
            features.append(skewness)
        
        features = np.array(features)
        
        logger.info(f"提取颜色矩特征: 特征维度 {features.shape}")
        return features
    
    def extract_dominant_colors(self, image: np.ndarray, k: int = 5) -> Tuple[np.ndarray, np.ndarray]:
        """
        提取图像中的主要颜色
        
        参数:
            image: 输入图像
            k: 主要颜色的数量
            
        返回:
            主要颜色及其比例
        """
        # 确保图像是彩色的
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
        
        # 将图像重塑为二维数组
        pixels = image.reshape(-1, 3).astype(np.float32)
        
        # 设置K-means参数
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
        
        # 应用K-means聚类
        _, labels, centers = cv2.kmeans(pixels, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
        
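        # Compute the share of pixels in each cluster; minlength keeps the
        # result at length k even if a cluster ends up with no pixels
        counts = np.bincount(labels.flatten(), minlength=k)
        percentages = counts / len(labels)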
        
        # 按比例排序
        indices = np.argsort(percentages)[::-1]
        centers = centers[indices]
        percentages = percentages[indices]
        
        logger.info(f"提取 {k} 个主要颜色")
        return centers, percentages
    
    # 纹理特征提取方法
    def extract_lbp_features(self, image: np.ndarray, radius: int = None, 
                            n_points: int = None) -> np.ndarray:
        """
        提取局部二值模式(LBP)特征
        
        参数:
            image: 输入图像
            radius: LBP半径
            n_points: 采样点数量
            
        返回:
            LBP特征直方图
        """
        # 导入skimage库用于LBP计算
        from skimage.feature import local_binary_pattern
        
        # 转换为灰度图像
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        # 获取参数
        if radius is None:
            radius = self.config["lbp"]["radius"]
        
        if n_points is None:
            n_points = self.config["lbp"]["n_points"]
        
        # 计算LBP
        lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
        
        # 计算LBP直方图
        n_bins = n_points + 2  # uniform LBP的柱数
        hist, _ = np.histogram(lbp.ravel(), bins=n_bins, range=(0, n_bins))
        
        # 归一化直方图
        hist = hist.astype(np.float32) / (hist.sum() + 1e-10)
        
        logger.info(f"提取LBP特征: 特征维度 {hist.shape}")
        return hist
    
    def extract_gabor_features(self, image: np.ndarray, 
                              orientations: int = 8, 
                              scales: int = 5) -> np.ndarray:
        """
        提取Gabor纹理特征
        
        参数:
            image: 输入图像
            orientations: Gabor滤波器的方向数
            scales: Gabor滤波器的尺度数
            
        返回:
            Gabor特征
        """
        # 转换为灰度图像
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        # 将图像转换为浮点型
        gray = gray.astype(np.float32)
        
        # 创建Gabor滤波器组
        gabor_features = []
        
        for scale in range(scales):
            for orientation in range(orientations):
                # 创建Gabor滤波器
                wavelength = 2 ** scale
                angle = orientation * np.pi / orientations
                
                kernel = cv2.getGaborKernel(
                    (31, 31), sigma=1.0, theta=angle, lambd=wavelength,
                    gamma=0.5, psi=0, ktype=cv2.CV_32F
                )
                
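                # Apply the Gabor filter. ddepth must be a depth constant (not a channel
                # type such as CV_8UC3); float32 output also keeps negative responses
                filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)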
                
                # 计算滤波结果的均值和标准差
                mean = np.mean(filtered)
                std = np.std(filtered)
                
                # 添加到特征向量
                gabor_features.extend([mean, std])
        
        features = np.array(gabor_features)
        
        logger.info(f"提取Gabor特征: 特征维度 {features.shape}")
        return features
    
    # 深度学习特征提取方法
    def _load_deep_learning_model(self, model_name: str = None):
        """
        加载深度学习模型
        
        参数:
            model_name: 模型名称
        """
        if model_name is None:
            model_name = self.config["deep_learning"]["model_name"]
        
        try:
            # 导入TensorFlow
            import tensorflow as tf
            from tensorflow.keras.applications import (
                VGG16, VGG19, ResNet50, InceptionV3, MobileNetV2, EfficientNetB0
            )
            from tensorflow.keras.models import Model
            
            # 设置日志级别,抑制TensorFlow警告
            tf.get_logger().setLevel('ERROR')
            
            # 获取输入形状
            input_shape = tuple(self.config["deep_learning"]["input_shape"]) + (3,)
            
            # 获取池化方法
            pooling = self.config["deep_learning"]["pooling"]
            
            # 根据模型名称加载预训练模型
            if model_name.lower() == "vgg16":
                base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            elif model_name.lower() == "vgg19":
                base_model = VGG19(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            elif model_name.lower() == "resnet50":
                base_model = ResNet50(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            elif model_name.lower() == "inceptionv3":
                base_model = InceptionV3(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            elif model_name.lower() == "mobilenetv2":
                base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            elif model_name.lower() == "efficientnetb0":
                base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            else:
                logger.warning(f"未知的模型名称: {model_name},使用VGG16")
                base_model = VGG16(weights='imagenet', include_top=False, input_shape=input_shape, pooling=pooling)
            
            # 创建特征提取模型
            self.deep_learning_model = Model(inputs=base_model.input, outputs=base_model.output)
            
            logger.info(f"成功加载深度学习模型: {model_name}")
            
        except ImportError as e:
            logger.error(f"无法导入TensorFlow: {e}")
            logger.warning("请安装TensorFlow: pip install tensorflow")
            self.deep_learning_model = None
        except Exception as e:
            logger.error(f"加载深度学习模型时出错: {e}")
            self.deep_learning_model = None
    
    def extract_deep_features(self, image: np.ndarray, model_name: str = None) -> np.ndarray:
        """
        使用预训练的深度学习模型提取特征
        
        参数:
            image: 输入图像
            model_name: 模型名称,可选 "vgg16", "vgg19", "resnet50", "inceptionv3", "mobilenetv2", "efficientnetb0"
            
        返回:
            深度特征
        """
        # 如果模型未加载,加载模型
        if self.deep_learning_model is None:
            self._load_deep_learning_model(model_name)
        
        # 如果模型加载失败,返回None
        if self.deep_learning_model is None:
            logger.error("深度学习模型未加载,无法提取特征")
            return None
        
        try:
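            # Import the preprocessing function that matches the backbone. The Keras
            # application families expect different input scaling, so the VGG16
            # function is not valid for every model supported by this class
            from tensorflow.keras.applications import (
                vgg16, vgg19, resnet50, inception_v3, mobilenet_v2, efficientnet
            )
            preprocess_by_model = {
                "vgg16": vgg16.preprocess_input,
                "vgg19": vgg19.preprocess_input,
                "resnet50": resnet50.preprocess_input,
                "inceptionv3": inception_v3.preprocess_input,
                "mobilenetv2": mobilenet_v2.preprocess_input,
                "efficientnetb0": efficientnet.preprocess_input,
            }
            name = (model_name or self.config["deep_learning"]["model_name"]).lower()
            # Unknown names fall back to VGG16, mirroring _load_deep_learning_model
            preprocess_input = preprocess_by_model.get(name, vgg16.preprocess_input)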
            from tensorflow.keras.preprocessing import image as keras_image
            
            # 确保图像是彩色的
            if len(image.shape) == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            
            # 转换BGR到RGB(OpenCV使用BGR,TensorFlow使用RGB)
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            
            # 调整图像大小
            input_shape = self.config["deep_learning"]["input_shape"]
            resized = cv2.resize(image_rgb, tuple(input_shape))
            
            # 转换为TensorFlow格式
            x = keras_image.img_to_array(resized)
            x = np.expand_dims(x, axis=0)
            x = preprocess_input(x)
            
            # Extract features (verbose=0 suppresses the per-call progress bar)
            features = self.deep_learning_model.predict(x, verbose=0)
            
            # 如果特征是多维的,将其展平
            if len(features.shape) > 2:
                features = features.reshape(features.shape[0], -1)
            
            logger.info(f"提取深度学习特征: 特征维度 {features.shape}")
            return features[0]  # 返回第一个样本的特征
            
        except Exception as e:
            logger.error(f"提取深度学习特征时出错: {e}")
            return None
    
    # 综合特征提取方法
    def extract_features(self, image: np.ndarray, methods: List[str] = None) -> Dict[str, np.ndarray]:
        """
        综合提取多种特征
        
        参数:
            image: 输入图像
            methods: 特征提取方法列表,可选 "hog", "sift", "orb", "color_histogram", 
                    "color_moments", "lbp", "gabor", "deep"
            
        返回:
            特征字典,键为方法名,值为特征向量
        """
        if methods is None:
            methods = ["hog", "color_histogram", "lbp"]
        
        features = {}
        
        for method in methods:
            try:
                if method == "hog":
                    features[method] = self.extract_hog_features(image)
                elif method == "sift":
                    keypoints, descriptors = self.extract_sift_features(image)
                    if descriptors is not None:
                        # 如果有多个描述符,计算均值
                        features[method] = np.mean(descriptors, axis=0) if descriptors.shape[0] > 0 else np.array([])
                    else:
                        features[method] = np.array([])
                elif method == "orb":
                    keypoints, descriptors = self.extract_orb_features(image)
                    if descriptors is not None:
                        # 如果有多个描述符,计算均值
                        features[method] = np.mean(descriptors, axis=0) if descriptors.shape[0] > 0 else np.array([])
                    else:
                        features[method] = np.array([])
                elif method == "color_histogram":
                    features[method] = self.extract_color_histogram(image)
                elif method == "color_moments":
                    features[method] = self.extract_color_moments(image)
                elif method == "lbp":
                    features[method] = self.extract_lbp_features(image)
                elif method == "gabor":
                    features[method] = self.extract_gabor_features(image)
                elif method == "deep":
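                    deep_features = self.extract_deep_features(image)
                    # Store an empty vector on failure so every requested method
                    # has a key, consistent with the other branches
                    if deep_features is not None:
                        features[method] = deep_features
                    else:
                        features[method] = np.array([])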
                else:
                    logger.warning(f"未知的特征提取方法: {method}")
            except Exception as e:
                logger.error(f"提取 {method} 特征时出错: {e}")
                features[method] = np.array([])
        
        return features
    
    def feature_fusion(self, features: Dict[str, np.ndarray], weights: Dict[str, float] = None) -> np.ndarray:
        """
        特征融合
        
        参数:
            features: 特征字典,键为方法名,值为特征向量
            weights: 权重字典,键为方法名,值为权重
            
        返回:
            融合后的特征向量
        """
        if weights is None:
            # 如果没有指定权重,使用相等权重
            weights = {method: 1.0 / len(features) for method in features}
        
        # 标准化每个特征向量
        normalized_features = {}
        for method, feature in features.items():
            if feature.size > 0:
                # 标准化特征向量
                norm = np.linalg.norm(feature)
                if norm > 0:
                    normalized_features[method] = feature / norm
                else:
                    normalized_features[method] = feature
            else:
                normalized_features[method] = feature
        
        # 计算加权和
        fused_feature = np.array([])
        for method, feature in normalized_features.items():
            if method in weights and feature.size > 0:
                if fused_feature.size == 0:
                    fused_feature = weights[method] * feature
                else:
                    # 如果特征维度不同,使用零填充
                    if fused_feature.shape[0] < feature.shape[0]:
                        fused_feature = np.pad(fused_feature, (0, feature.shape[0] - fused_feature.shape[0]))
                    elif fused_feature.shape[0] > feature.shape[0]:
                        feature = np.pad(feature, (0, fused_feature.shape[0] - feature.shape[0]))
                    
                    fused_feature += weights[method] * feature
        
        logger.info(f"特征融合: 融合后特征维度 {fused_feature.shape}")
        return fused_feature
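
The `feature_fusion` method above zero-pads vectors to a common length and sums them, which adds together dimensions that come from unrelated descriptors. A common alternative is to concatenate the weighted, L2-normalized vectors instead. The sketch below illustrates that alternative under stated assumptions: `fuse_by_concatenation` is a hypothetical helper, not part of the original class, and the demo vectors are made up for illustration.

```python
import numpy as np


def fuse_by_concatenation(features, weights=None):
    """L2-normalize each non-empty feature vector, scale it by its weight,
    and concatenate the results (hypothetical helper, not from the original code)."""
    if weights is None:
        # Assumption: equal weight of 1.0 for every method
        weights = {name: 1.0 for name in features}

    parts = []
    # Sort the method names so the layout of the fused vector is deterministic
    for name in sorted(features):
        vec = np.asarray(features[name], dtype=np.float64).ravel()
        if vec.size == 0 or name not in weights:
            continue  # skip empty or unweighted features
        norm = np.linalg.norm(vec)
        if norm > 0:
            vec = vec / norm
        parts.append(weights[name] * vec)

    return np.concatenate(parts) if parts else np.array([])


# Made-up feature vectors of different lengths, including an empty one
demo = {
    "lbp": np.array([3.0, 4.0]),
    "hog": np.array([0.0, 5.0, 0.0]),
    "sift": np.array([]),
}
fused = fuse_by_concatenation(demo)
print(fused.shape)  # (5,)
```

Concatenation keeps each descriptor in its own slice of the fused vector, at the cost of a longer vector; whether that or the summed form works better depends on the downstream classifier.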

src\image_acquisition.py

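"""
Image Acquisition Module

This module loads image data from several sources:
1. The local file system (single files)
2. Live capture from a camera
3. A network URL
4. A local directory (batch loading)

It provides a uniform interface over these sources, with basic validation and optional resizing.
"""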

import os
import cv2
import numpy as np
import requests
from PIL import Image
import io
import logging
from typing import Union, Optional, Tuple, List

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ImageAcquisition:
    """图像获取类,提供从不同来源获取图像的方法"""
    
    def __init__(self, default_image_size: Tuple[int, int] = (640, 480)):
        """
        初始化图像获取模块
        
        参数:
            default_image_size: 默认图像尺寸,用于调整大小 (宽, 高)
        """
        self.default_image_size = default_image_size
        self.camera = None
        logger.info(f"图像获取模块初始化完成,默认图像尺寸: {default_image_size}")
    
    def load_from_file(self, file_path: str, resize: bool = False) -> Optional[np.ndarray]:
        """
        从本地文件加载图像
        
        参数:
            file_path: 图像文件路径
            resize: 是否调整图像大小到默认尺寸
            
        返回:
            numpy数组格式的图像,如果加载失败则返回None
        """
        try:
            if not os.path.exists(file_path):
                logger.error(f"文件不存在: {file_path}")
                return None
            
            # 使用OpenCV加载图像
            image = cv2.imread(file_path)
            
            if image is None:
                logger.error(f"无法读取图像文件: {file_path}")
                return None
            
            logger.info(f"成功从文件加载图像: {file_path}, 尺寸: {image.shape}")
            
            # 如果需要调整大小
            if resize:
                image = cv2.resize(image, self.default_image_size)
                logger.info(f"图像已调整为默认尺寸: {self.default_image_size}")
            
            return image
        
        except Exception as e:
            logger.error(f"加载图像时出错: {str(e)}")
            return None
    
    def load_from_camera(self, camera_id: int = 0, num_frames: int = 1) -> Optional[np.ndarray]:
        """
        从摄像头捕获图像
        
        参数:
            camera_id: 摄像头ID,默认为0(通常是内置摄像头)
            num_frames: 捕获的帧数,用于跳过前几帧以等待摄像头稳定
            
        返回:
            捕获的图像,如果失败则返回None
        """
        try:
            # 初始化摄像头
            if self.camera is None:
                self.camera = cv2.VideoCapture(camera_id)
                
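                if not self.camera.isOpened():
                    logger.error(f"Unable to open camera ID: {camera_id}")
                    # Release and reset the handle so a later call retries the open
                    # instead of reusing a capture object that never opened
                    self.camera.release()
                    self.camera = None
                    return None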
                
                logger.info(f"成功打开摄像头 ID: {camera_id}")
            
            # 捕获多帧以确保摄像头稳定(丢弃前几帧)
            for _ in range(num_frames - 1):
                self.camera.read()
            
            # 读取最后一帧
            ret, frame = self.camera.read()
            
            if not ret:
                logger.error("无法从摄像头捕获图像")
                return None
            
            logger.info(f"成功从摄像头捕获图像,尺寸: {frame.shape}")
            
            # 调整图像大小
            frame = cv2.resize(frame, self.default_image_size)
            
            return frame
        
        except Exception as e:
            logger.error(f"从摄像头捕获图像时出错: {str(e)}")
            return None
    
    def release_camera(self):
        """释放摄像头资源"""
        if self.camera is not None:
            self.camera.release()
            self.camera = None
            logger.info("摄像头资源已释放")
    
    def load_from_url(self, url: str, resize: bool = False) -> Optional[np.ndarray]:
        """
        从网络URL加载图像
        
        参数:
            url: 图像URL
            resize: 是否调整图像大小到默认尺寸
            
        返回:
            图像数据,如果失败则返回None
        """
        try:
            # 发送HTTP请求获取图像
            response = requests.get(url, timeout=10)
            
            if response.status_code != 200:
                logger.error(f"HTTP请求失败,状态码: {response.status_code}, URL: {url}")
                return None
            
            # 将二进制数据转换为图像
            image_array = np.asarray(bytearray(response.content), dtype=np.uint8)
            image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
            
            if image is None:
                logger.error(f"无法解码图像数据,URL: {url}")
                return None
            
            logger.info(f"成功从URL加载图像: {url}, 尺寸: {image.shape}")
            
            # 如果需要调整大小
            if resize:
                image = cv2.resize(image, self.default_image_size)
                logger.info(f"图像已调整为默认尺寸: {self.default_image_size}")
            
            return image
        
        except Exception as e:
            logger.error(f"从URL加载图像时出错: {str(e)}")
            return None
    
    def load_batch_from_directory(self, directory_path: str,
                                 file_extensions: Tuple[str, ...] = ('jpg', 'jpeg', 'png', 'bmp'),
                                 max_images: int = 100, resize: bool = False) -> List[np.ndarray]:
        """
        从目录批量加载图像
        
        参数:
            directory_path: 目录路径
            file_extensions: 要加载的文件扩展名列表
            max_images: 最大加载图像数量
            resize: 是否调整图像大小
            
        返回:
            图像列表
        """
        images = []
        count = 0
        
        try:
            if not os.path.isdir(directory_path):
                logger.error(f"目录不存在: {directory_path}")
                return images
            
            # 获取目录中所有文件
            for filename in os.listdir(directory_path):
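                # Check the file extension, including the dot, so that a name such as
                # "notajpg" is not mistaken for a JPEG file
                if any(filename.lower().endswith('.' + ext.lower().lstrip('.')) for ext in file_extensions):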
                    file_path = os.path.join(directory_path, filename)
                    image = self.load_from_file(file_path, resize)
                    
                    if image is not None:
                        images.append(image)
                        count += 1
                        
                        if count >= max_images:
                            logger.info(f"已达到最大图像数量限制: {max_images}")
                            break
            
            logger.info(f"从目录 {directory_path} 成功加载 {len(images)} 张图像")
            return images
            
        except Exception as e:
            logger.error(f"批量加载图像时出错: {str(e)}")
            return images
    
    def save_image(self, image: np.ndarray, file_path: str, quality: int = 95) -> bool:
        """
        保存图像到文件
        
        参数:
            image: 要保存的图像
            file_path: 保存路径
            quality: JPEG图像质量 (1-100)
            
        返回:
            是否保存成功
        """
        try:
            # 确保目录存在
            directory = os.path.dirname(file_path)
            if directory and not os.path.exists(directory):
                os.makedirs(directory, exist_ok=True)
                logger.info(f"创建目录: {directory}")
            
            # 根据文件扩展名确定保存参数
            _, ext = os.path.splitext(file_path)
            ext = ext.lower()
            
            if ext == '.jpg' or ext == '.jpeg':
                params = [cv2.IMWRITE_JPEG_QUALITY, quality]
            elif ext == '.png':
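                # Map quality (1-100) to a PNG compression level clamped to 0-9;
                # PNG is lossless, so this only trades file size against speed
                params = [cv2.IMWRITE_PNG_COMPRESSION, max(0, min(9, 10 - quality // 10))]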
            else:
                params = []
            
            # 保存图像
            result = cv2.imwrite(file_path, image, params)
            
            if result:
                logger.info(f"图像成功保存到: {file_path}")
            else:
                logger.error(f"无法保存图像到: {file_path}")
            
            return result
            
        except Exception as e:
            logger.error(f"保存图像时出错: {str(e)}")
            return False
    
    @staticmethod
    def convert_color_space(image: np.ndarray, conversion_code: int) -> np.ndarray:
        """
        转换图像颜色空间
        
        参数:
            image: 输入图像
            conversion_code: OpenCV颜色空间转换代码,如cv2.COLOR_BGR2RGB
            
        返回:
            转换后的图像
        """
        return cv2.cvtColor(image, conversion_code)
    
    @staticmethod
    def get_image_info(image: np.ndarray) -> dict:
        """
        获取图像信息
        
        参数:
            image: 输入图像
            
        返回:
            包含图像信息的字典
        """
        if image is None:
            return {"error": "无效图像"}
        
        height, width = image.shape[:2]
        channels = 1 if len(image.shape) == 2 else image.shape[2]
        
        return {
            "width": width,
            "height": height,
            "channels": channels,
            "dtype": str(image.dtype),
            "size_bytes": image.nbytes
        }


# 示例用法
if __name__ == "__main__":
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
    # 从文件加载图像
    image = image_acq.load_from_file("../data/sample.jpg")
    if image is not None:
        print("成功加载图像文件")
        print(f"图像信息: {image_acq.get_image_info(image)}")
        
        # 保存图像
        image_acq.save_image(image, "../data/output.jpg")
    
    # 从摄像头捕获图像
    camera_image = image_acq.load_from_camera(num_frames=5)
    if camera_image is not None:
        print("成功从摄像头捕获图像")
        image_acq.save_image(camera_image, "../data/camera_capture.jpg")
    
    # 释放摄像头
    image_acq.release_camera()
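
The batch loader above selects files by extension. A minimal sketch of stricter extension matching follows, assuming the same default extension list. `has_image_extension` and `list_image_files` are hypothetical helpers and not part of the original module. The sketch compares the real extension returned by `os.path.splitext` and sorts the names, because `os.listdir` returns entries in arbitrary order.

```python
import os
from typing import List, Tuple

DEFAULT_EXTENSIONS: Tuple[str, ...] = ("jpg", "jpeg", "png", "bmp")


def has_image_extension(filename: str,
                        extensions: Tuple[str, ...] = DEFAULT_EXTENSIONS) -> bool:
    """Return True if the file's actual extension is in `extensions` (case-insensitive)."""
    allowed = {ext.lower().lstrip(".") for ext in extensions}
    actual = os.path.splitext(filename)[1].lower().lstrip(".")
    return actual in allowed


def list_image_files(directory: str,
                     extensions: Tuple[str, ...] = DEFAULT_EXTENSIONS,
                     max_images: int = 100) -> List[str]:
    """Return up to `max_images` image paths from `directory`, sorted by file name."""
    names = sorted(name for name in os.listdir(directory)
                   if has_image_extension(name, extensions))
    return [os.path.join(directory, name) for name in names[:max_images]]


print(has_image_extension("photo.JPG"))  # True
print(has_image_extension("notajpg"))    # False
```

The returned paths could then be passed one by one to `load_from_file`; sorting makes the `max_images` cut-off reproducible between runs.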

src\image_preprocessing.py

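"""
Image Preprocessing Module

This module implements image preprocessing operations:
1. Enhancement (brightness, contrast, sharpening)
2. Noise removal (Gaussian, median, and bilateral filtering)
3. Geometric transforms (resize, rotate, flip)
4. Standardization and normalization
5. Edge detection and contour extraction
6. Morphological operations (erosion, dilation, opening, closing)
"""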

import cv2
import numpy as np
import logging
from typing import Union, Optional, Tuple, List, Dict, Any
import os
import sys

# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# 导入自定义模块
from utils.config_loader import ConfigLoader

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ImagePreprocessing:
    """图像预处理类,提供各种图像预处理操作"""
    
    def __init__(self, config_name: str = "preprocessing_config"):
        """
        初始化图像预处理模块
        
        参数:
            config_name: 配置文件名称,默认为"preprocessing_config"
        """
        # 加载配置
        self.config_loader = ConfigLoader()
        self.config = self.config_loader.load_config(config_name)
        
        if self.config:
            logger.info(f"成功加载预处理模块配置: {config_name}")
        else:
            logger.warning(f"无法加载配置: {config_name},使用默认配置")
            self.config = self._get_default_config()
        
        logger.info("图像预处理模块初始化完成")
    
    def _get_default_config(self) -> Dict[str, Any]:
        """获取默认配置"""
        return {
            "enhancement": {
                "default_brightness": 0,
                "default_contrast": 1.0,
                "default_gamma": 1.0,
                "sharpening_kernel_size": 3
            },
            "noise_removal": {
                "gaussian_kernel_size": 5,
                "gaussian_sigma": 0,
                "median_kernel_size": 5,
                "bilateral_diameter": 9,
                "bilateral_sigma_color": 75,
                "bilateral_sigma_space": 75
            },
            "geometric": {
                "default_interpolation": "linear"
            },
            "normalization": {
                "default_method": "minmax"
            },
            "edge_detection": {
                "canny_low_threshold": 100,
                "canny_high_threshold": 200,
                "canny_aperture_size": 3
            },
            "morphological": {
                "default_kernel_size": 5,
                "default_iterations": 1
            }
        }
    
    # 图像增强方法
    def adjust_brightness_contrast(self, image: np.ndarray, 
                                  brightness: float = None, 
                                  contrast: float = None) -> np.ndarray:
        """
        调整图像亮度和对比度
        
        参数:
            image: 输入图像
            brightness: 亮度调整值,正值增加亮度,负值降低亮度
            contrast: 对比度调整因子,大于1增加对比度,小于1降低对比度
            
        返回:
            调整后的图像
        """
        if brightness is None:
            brightness = self.config["enhancement"]["default_brightness"]
        
        if contrast is None:
            contrast = self.config["enhancement"]["default_contrast"]
        
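        # Apply output = contrast * image + brightness in float, then clip to [0, 255].
        # cv2.convertScaleAbs is avoided because it takes the absolute value, so a
        # negative result would come out positive instead of being clamped to 0.
        # The result is returned as an 8-bit image.
        scaled = contrast * image.astype(np.float32) + brightness
        output = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)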
        
        logger.info(f"调整图像亮度和对比度: 亮度={brightness}, 对比度={contrast}")
        return output
    
    def adjust_gamma(self, image: np.ndarray, gamma: float = None) -> np.ndarray:
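        """
        Apply gamma correction
        
        Parameters:
            image: input image (8-bit, as required by cv2.LUT with a 256-entry table)
            gamma: gamma value; the lookup table uses the exponent 1 / gamma, so values
                   greater than 1 brighten the image and values less than 1 darken it
            
        Returns:
            Gamma-corrected image
        """
        if gamma is None:
            gamma = self.config["enhancement"]["default_gamma"]
        
        # gamma must be positive because the exponent is 1 / gamma
        if gamma <= 0:
            raise ValueError(f"gamma must be positive, got {gamma}")
        
        # Build the lookup table
        inv_gamma = 1.0 / gamma
        table = np.array([((i / 255.0) ** inv_gamma) * 255 for i in range(256)]).astype("uint8")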
        
        # 应用查找表
        output = cv2.LUT(image, table)
        
        logger.info(f"应用伽马校正: gamma={gamma}")
        return output
    
    def sharpen_image(self, image: np.ndarray, kernel_size: int = None, 
                     amount: float = 1.0) -> np.ndarray:
        """
        锐化图像
        
        参数:
            image: 输入图像
            kernel_size: 锐化核大小,必须是奇数
            amount: 锐化强度
            
        返回:
            锐化后的图像
        """
        if kernel_size is None:
            kernel_size = self.config["enhancement"]["sharpening_kernel_size"]
        
        # 确保kernel_size是奇数
        if kernel_size % 2 == 0:
            kernel_size += 1
        
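        # Build the sharpening kernel
        center = kernel_size // 2
        identity = np.zeros((kernel_size, kernel_size), dtype=np.float32)
        identity[center, center] = 1.0
        
        if kernel_size == 3:
            kernel = np.array([[-1, -1, -1],
                               [-1,  9, -1],
                               [-1, -1, -1]], dtype=np.float32)
        else:
            # Unsharp-mask kernel: 2 * identity - Gaussian. The Gaussian sums to 1,
            # so the kernel also sums to 1 and overall brightness is preserved
            gaussian = cv2.getGaussianKernel(kernel_size, 0)
            gaussian = gaussian * gaussian.transpose()
            kernel = 2.0 * identity - gaussian
        
        # Blend with the identity so that `amount` scales only the sharpening
        # component (amount=0 returns the input, amount=1 applies the full kernel)
        kernel = identity + amount * (kernel - identity)
        
        # Apply sharpening
        sharpened = cv2.filter2D(image, -1, kernel)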
        
        # 确保值在有效范围内
        sharpened = np.clip(sharpened, 0, 255).astype(np.uint8)
        
        logger.info(f"锐化图像: kernel_size={kernel_size}, amount={amount}")
        return sharpened
    
    def enhance_details(self, image: np.ndarray, method: str = "unsharp_mask", 
                       strength: float = 1.5) -> np.ndarray:
        """
        增强图像细节
        
        参数:
            image: 输入图像
            method: 增强方法,可选 "unsharp_mask" 或 "laplacian"
            strength: 增强强度
            
        返回:
            增强后的图像
        """
        if method == "unsharp_mask":
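            # Gaussian blur
            blurred = cv2.GaussianBlur(image, (5, 5), 0)
            
            # Unsharp mask: image + strength * (image - blurred), computed in one
            # saturating step so negative differences are not lost to uint8 clipping
            enhanced = cv2.addWeighted(image, 1.0 + strength, blurred, -strength, 0)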
            
            logger.info(f"使用Unsharp Mask增强细节: strength={strength}")
            return enhanced
            
        elif method == "laplacian":
            # 转换为灰度图像(如果是彩色图像)
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image.copy()
            
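            # Apply the Laplacian operator in float to keep negative responses
            laplacian = cv2.Laplacian(gray, cv2.CV_64F)
            
            # Scale the magnitude and saturate to 8 bits with cv2.convertScaleAbs;
            # np.uint8 on values above 255 would wrap around instead of clamping
            if len(image.shape) == 3:
                # Color image: add the scaled detail layer to each channel
                detail = cv2.convertScaleAbs(laplacian, alpha=strength * 0.3)
                enhanced = image.copy()
                for i in range(3):
                    enhanced[:, :, i] = cv2.add(image[:, :, i], detail)
            else:
                detail = cv2.convertScaleAbs(laplacian, alpha=strength)
                enhanced = cv2.add(image, detail)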
            
            logger.info(f"使用Laplacian增强细节: strength={strength}")
            return enhanced
        
        else:
            logger.warning(f"未知的细节增强方法: {method},返回原图像")
            return image
    
    # 噪声去除方法
    def apply_gaussian_blur(self, image: np.ndarray, kernel_size: int = None, 
                           sigma: float = None) -> np.ndarray:
        """
        应用高斯模糊
        
        参数:
            image: 输入图像
            kernel_size: 高斯核大小,必须是奇数
            sigma: 高斯核标准差
            
        返回:
            模糊后的图像
        """
        if kernel_size is None:
            kernel_size = self.config["noise_removal"]["gaussian_kernel_size"]
        
        if sigma is None:
            sigma = self.config["noise_removal"]["gaussian_sigma"]
        
        # 确保kernel_size是奇数
        if kernel_size % 2 == 0:
            kernel_size += 1
        
        # 应用高斯模糊
        blurred = cv2.GaussianBlur(image, (kernel_size, kernel_size), sigma)
        
        logger.info(f"应用高斯模糊: kernel_size={kernel_size}, sigma={sigma}")
        return blurred
    
    def apply_median_blur(self, image: np.ndarray, kernel_size: int = None) -> np.ndarray:
        """
        应用中值滤波
        
        参数:
            image: 输入图像
            kernel_size: 核大小,必须是奇数
            
        返回:
            滤波后的图像
        """
        if kernel_size is None:
            kernel_size = self.config["noise_removal"]["median_kernel_size"]
        
        # 确保kernel_size是奇数
        if kernel_size % 2 == 0:
            kernel_size += 1
        
        # 应用中值滤波
        blurred = cv2.medianBlur(image, kernel_size)
        
        logger.info(f"应用中值滤波: kernel_size={kernel_size}")
        return blurred
    
    def apply_bilateral_filter(self, image: np.ndarray, diameter: int = None, 
                              sigma_color: float = None, 
                              sigma_space: float = None) -> np.ndarray:
        """
        应用双边滤波
        
        参数:
            image: 输入图像
            diameter: 滤波直径
            sigma_color: 颜色空间的标准差
            sigma_space: 坐标空间的标准差
            
        返回:
            滤波后的图像
        """
        if diameter is None:
            diameter = self.config["noise_removal"]["bilateral_diameter"]
        
        if sigma_color is None:
            sigma_color = self.config["noise_removal"]["bilateral_sigma_color"]
        
        if sigma_space is None:
            sigma_space = self.config["noise_removal"]["bilateral_sigma_space"]
        
        # 应用双边滤波
        filtered = cv2.bilateralFilter(image, diameter, sigma_color, sigma_space)
        
        logger.info(f"应用双边滤波: diameter={diameter}, sigma_color={sigma_color}, sigma_space={sigma_space}")
        return filtered
    
    def remove_noise(self, image: np.ndarray, method: str = "gaussian", 
                    **kwargs) -> np.ndarray:
        """
        去除图像噪声
        
        参数:
            image: 输入图像
            method: 去噪方法,可选 "gaussian", "median", "bilateral"
            **kwargs: 传递给具体去噪方法的参数
            
        返回:
            去噪后的图像
        """
        if method == "gaussian":
            return self.apply_gaussian_blur(image, **kwargs)
        elif method == "median":
            return self.apply_median_blur(image, **kwargs)
        elif method == "bilateral":
            return self.apply_bilateral_filter(image, **kwargs)
        else:
            logger.warning(f"未知的去噪方法: {method},返回原图像")
            return image
    
    # 几何变换方法
    def resize_image(self, image: np.ndarray, width: int = None, height: int = None, 
                    scale: float = None, interpolation: str = None) -> np.ndarray:
        """
        调整图像大小
        
        参数:
            image: 输入图像
            width: 目标宽度
            height: 目标高度
            scale: 缩放比例,如果指定了scale,则忽略width和height
            interpolation: 插值方法,可选 "nearest", "linear", "cubic", "area"
            
        返回:
            调整大小后的图像
        """
        # 获取插值方法
        if interpolation is None:
            interpolation = self.config["geometric"]["default_interpolation"]
        
        # 映射插值方法
        interp_methods = {
            "nearest": cv2.INTER_NEAREST,
            "linear": cv2.INTER_LINEAR,
            "cubic": cv2.INTER_CUBIC,
            "area": cv2.INTER_AREA
        }
        
        interp = interp_methods.get(interpolation, cv2.INTER_LINEAR)
        
        # 计算目标尺寸
        if scale is not None:
            h, w = image.shape[:2]
            # Keep each dimension at least 1 pixel so cv2.resize does not fail on tiny scales
            width = max(1, int(w * scale))
            height = max(1, int(h * scale))
            logger.info(f"按比例调整图像大小: scale={scale}")
        elif width is not None and height is not None:
            logger.info(f"调整图像大小为: width={width}, height={height}")
        else:
            # 如果没有指定宽度和高度,保持原始尺寸
            h, w = image.shape[:2]
            width, height = w, h
            logger.warning("未指定目标尺寸,保持原始尺寸")
        
        # 调整大小
        resized = cv2.resize(image, (width, height), interpolation=interp)
        
        return resized
    
    def rotate_image(self, image: np.ndarray, angle: float, 
                    center: Tuple[int, int] = None, scale: float = 1.0) -> np.ndarray:
        """
        旋转图像
        
        参数:
            image: 输入图像
            angle: 旋转角度(度),正值表示逆时针旋转
            center: 旋转中心点,默认为图像中心
            scale: 缩放因子
            
        返回:
            旋转后的图像
        """
        h, w = image.shape[:2]
        
        if center is None:
            center = (w // 2, h // 2)
        
        # 获取旋转矩阵
        M = cv2.getRotationMatrix2D(center, angle, scale)
        
        # 执行仿射变换
        rotated = cv2.warpAffine(image, M, (w, h))
        
        logger.info(f"旋转图像: angle={angle}, center={center}, scale={scale}")
        return rotated
    
    def flip_image(self, image: np.ndarray, flip_code: int) -> np.ndarray:
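        """
        Flip an image
        
        Parameters:
            image: input image
            flip_code: flip code, as defined by cv2.flip
                0 = vertical flip (around the x-axis)
                1 = horizontal flip (around the y-axis)
                -1 = both vertical and horizontal flip
            
        Returns:
            Flipped image
        """
        flipped = cv2.flip(image, flip_code)
        
        flip_type = {0: "vertical", 1: "horizontal", -1: "vertical and horizontal"}
        logger.info(f"Flipped image: {flip_type.get(flip_code, 'unknown')}")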
        
        return flipped
    
    def crop_image(self, image: np.ndarray, x: int, y: int, width: int, height: int) -> np.ndarray:
        """
        裁剪图像
        
        参数:
            image: 输入图像
            x, y: 左上角坐标
            width, height: 裁剪区域的宽度和高度
            
        返回:
            裁剪后的图像
        """
        # 确保坐标在有效范围内
        img_height, img_width = image.shape[:2]
        
        x = max(0, min(x, img_width - 1))
        y = max(0, min(y, img_height - 1))
        width = max(1, min(width, img_width - x))
        height = max(1, min(height, img_height - y))
        
        # 裁剪图像
        cropped = image[y:y+height, x:x+width]
        
        logger.info(f"裁剪图像: x={x}, y={y}, width={width}, height={height}")
        return cropped
    
    # 标准化和归一化方法
    def normalize_image(self, image: np.ndarray, method: str = None, 
                       min_value: float = 0, max_value: float = 255) -> np.ndarray:
        """
        标准化图像
        
        参数:
            image: 输入图像
            method: 标准化方法,可选 "minmax", "mean", "z-score"
            min_value: 最小值(用于minmax方法)
            max_value: 最大值(用于minmax方法)
            
        返回:
            标准化后的图像
        """
        if method is None:
            method = self.config["normalization"]["default_method"]
        
        # 转换为浮点型
        img_float = image.astype(np.float32)
        
        if method == "minmax":
            # Min-Max标准化
            img_min = np.min(img_float)
            img_max = np.max(img_float)
            
            if img_max > img_min:
                normalized = (img_float - img_min) / (img_max - img_min) * (max_value - min_value) + min_value
            else:
                normalized = np.zeros_like(img_float) + min_value
            
            logger.info(f"应用Min-Max标准化: min={min_value}, max={max_value}")
            
        elif method == "mean":
            # 减去均值
            mean = np.mean(img_float)
            normalized = img_float - mean
            
            logger.info(f"应用均值标准化: mean={mean}")
            
        elif method == "z-score":
            # Z-score标准化
            mean = np.mean(img_float)
            std = np.std(img_float)
            
            if std > 0:
                normalized = (img_float - mean) / std
            else:
                normalized = np.zeros_like(img_float)
            
            logger.info(f"应用Z-score标准化: mean={mean}, std={std}")
            
        else:
            logger.warning(f"未知的标准化方法: {method},返回原图像")
            return image
        
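        # Cast back to uint8 only when the result is still a displayable 8-bit image,
        # i.e. min-max scaling to a range wider than [0, 1]. Results in [0, 1], as well
        # as mean-centred and z-score results (which contain negative values), must stay
        # float, otherwise clipping and truncation would destroy them.
        if image.dtype == np.uint8 and method == "minmax" and max_value > 1:
            normalized = np.clip(normalized, 0, 255).astype(np.uint8)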
        
        return normalized
    
    # 边缘检测和轮廓提取方法
    def detect_edges(self, image: np.ndarray, method: str = "canny", 
                    low_threshold: int = None, high_threshold: int = None) -> np.ndarray:
        """
        检测图像边缘
        
        参数:
            image: 输入图像
            method: 边缘检测方法,可选 "canny", "sobel", "laplacian"
            low_threshold: Canny边缘检测的低阈值
            high_threshold: Canny边缘检测的高阈值
            
        返回:
            边缘图像
        """
        # 如果是彩色图像,转换为灰度图
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        if method == "canny":
            if low_threshold is None:
                low_threshold = self.config["edge_detection"]["canny_low_threshold"]
            
            if high_threshold is None:
                high_threshold = self.config["edge_detection"]["canny_high_threshold"]
            
            aperture_size = self.config["edge_detection"]["canny_aperture_size"]
            
            # 应用Canny边缘检测
            edges = cv2.Canny(gray, low_threshold, high_threshold, apertureSize=aperture_size)
            
            logger.info(f"应用Canny边缘检测: low_threshold={low_threshold}, high_threshold={high_threshold}")
            
        elif method == "sobel":
            # 计算x和y方向的梯度
            sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
            sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
            
            # 计算梯度幅值
            magnitude = np.sqrt(sobelx**2 + sobely**2)
            
            # 归一化到0-255
            edges = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            
            logger.info("应用Sobel边缘检测")
            
        elif method == "laplacian":
            # 应用拉普拉斯算子
            laplacian = cv2.Laplacian(gray, cv2.CV_64F)
            
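            # Take the absolute value and saturate to uint8; a plain np.uint8() cast
            # would wrap around for magnitudes above 255, and nothing is normalized here
            edges = cv2.convertScaleAbs(laplacian)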
            
            logger.info("应用Laplacian边缘检测")
            
        else:
            logger.warning(f"未知的边缘检测方法: {method},返回原图像")
            return image
        
        return edges
    
    def find_contours(self, image: np.ndarray, mode: int = cv2.RETR_EXTERNAL, 
                     method: int = cv2.CHAIN_APPROX_SIMPLE) -> Tuple[List, np.ndarray]:
        """
        查找图像中的轮廓
        
        参数:
            image: 输入图像(最好是二值图像或边缘图像)
            mode: 轮廓检索模式
            method: 轮廓近似方法
            
        返回:
            轮廓列表和层次结构
        """
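        # Make sure the image is an 8-bit binary image, as cv2.findContours requires
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        elif np.max(image) > 1:
            _, binary = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)
        else:
            # The image already holds 0/1 values (possibly bool or float): map to 0/255 uint8
            binary = (image > 0).astype(np.uint8) * 255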
        
        # 查找轮廓
        contours, hierarchy = cv2.findContours(binary, mode, method)
        
        logger.info(f"查找轮廓: 找到 {len(contours)} 个轮廓")
        return contours, hierarchy
    
    def draw_contours(self, image: np.ndarray, contours: List, 
                     color: Tuple[int, int, int] = (0, 255, 0), 
                     thickness: int = 2) -> np.ndarray:
        """
        在图像上绘制轮廓
        
        参数:
            image: 输入图像
            contours: 轮廓列表
            color: 轮廓颜色
            thickness: 轮廓线条粗细
            
        返回:
            绘制轮廓后的图像
        """
        # 创建副本,避免修改原图像
        result = image.copy()
        
        # 如果图像是灰度图,转换为彩色图像
        if len(result.shape) == 2:
            result = cv2.cvtColor(result, cv2.COLOR_GRAY2BGR)
        
        # 绘制所有轮廓
        cv2.drawContours(result, contours, -1, color, thickness)
        
        logger.info(f"绘制 {len(contours)} 个轮廓: color={color}, thickness={thickness}")
        return result
    
    # 形态学操作方法
    def apply_morphological_operation(self, image: np.ndarray, operation: str, 
                                     kernel_size: int = None, 
                                     iterations: int = None) -> np.ndarray:
        """
        应用形态学操作
        
        参数:
            image: 输入图像
            operation: 操作类型,可选 "erode", "dilate", "open", "close", "gradient", "tophat", "blackhat"
            kernel_size: 结构元素大小
            iterations: 操作重复次数
            
        返回:
            处理后的图像
        """
        if kernel_size is None:
            kernel_size = self.config["morphological"]["default_kernel_size"]
        
        if iterations is None:
            iterations = self.config["morphological"]["default_iterations"]
        
        # 创建结构元素
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        
        # 应用形态学操作
        if operation == "erode":
            result = cv2.erode(image, kernel, iterations=iterations)
            op_name = "腐蚀"
        elif operation == "dilate":
            result = cv2.dilate(image, kernel, iterations=iterations)
            op_name = "膨胀"
        elif operation == "open":
            result = cv2.morphologyEx(image, cv2.MORPH_OPEN, kernel, iterations=iterations)
            op_name = "开运算"
        elif operation == "close":
            result = cv2.morphologyEx(image, cv2.MORPH_CLOSE, kernel, iterations=iterations)
            op_name = "闭运算"
        elif operation == "gradient":
            result = cv2.morphologyEx(image, cv2.MORPH_GRADIENT, kernel, iterations=iterations)
            op_name = "形态学梯度"
        elif operation == "tophat":
            result = cv2.morphologyEx(image, cv2.MORPH_TOPHAT, kernel, iterations=iterations)
            op_name = "顶帽"
        elif operation == "blackhat":
            result = cv2.morphologyEx(image, cv2.MORPH_BLACKHAT, kernel, iterations=iterations)
            op_name = "黑帽"
        else:
            logger.warning(f"未知的形态学操作: {operation},返回原图像")
            return image
        
        logger.info(f"应用{op_name}操作: kernel_size={kernel_size}, iterations={iterations}")
        return result
    
    # 综合处理方法
    def preprocess_for_ocr(self, image: np.ndarray) -> np.ndarray:
        """
        为OCR预处理图像
        
        参数:
            image: 输入图像
            
        返回:
            预处理后的图像
        """
        # 转换为灰度图
        if len(image.shape) == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image.copy()
        
        # 应用高斯模糊去噪
        blurred = self.apply_gaussian_blur(gray, kernel_size=5)
        
        # 自适应阈值二值化
        binary = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                      cv2.THRESH_BINARY, 11, 2)
        
        # 应用开运算去除小噪点
        processed = self.apply_morphological_operation(binary, "open", kernel_size=3)
        
        logger.info("完成OCR图像预处理")
        return processed
    
    def preprocess_for_face_detection(self, image: np.ndarray) -> np.ndarray:
        """
        为人脸检测预处理图像
        
        参数:
            image: 输入图像
            
        返回:
            预处理后的图像
        """
        # 调整图像大小
        resized = self.resize_image(image, width=640, height=480)
        
        # 应用直方图均衡化增强对比度
        if len(resized.shape) == 3:
            # 转换到YUV色彩空间
            yuv = cv2.cvtColor(resized, cv2.COLOR_BGR2YUV)
            # 对Y通道进行直方图均衡化
            yuv[:,:,0] = cv2.equalizeHist(yuv[:,:,0])
            # 转换回BGR
            processed = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR)
        else:
            # 灰度图像直接均衡化
            processed = cv2.equalizeHist(resized)
        
        logger.info("完成人脸检测图像预处理")
        return processed
    
    def preprocess_for_object_detection(self, image: np.ndarray) -> np.ndarray:
        """
        为物体检测预处理图像
        
        参数:
            image: 输入图像
            
        返回:
            预处理后的图像
        """
        # 调整图像大小
        resized = self.resize_image(image, width=416, height=416)
        
        # 标准化图像
        normalized = self.normalize_image(resized, method="minmax", min_value=0, max_value=1)
        
        logger.info("完成物体检测图像预处理")
        return normalized
    
    def preprocess_pipeline(self, image: np.ndarray, pipeline: List[Dict]) -> np.ndarray:
        """
        应用预处理流水线
        
        参数:
            image: 输入图像
            pipeline: 预处理步骤列表,每个步骤是一个字典,包含操作类型和参数
                例如: [
                    {"operation": "resize", "params": {"width": 300, "height": 300}},
                    {"operation": "gaussian_blur", "params": {"kernel_size": 5}},
                    {"operation": "normalize", "params": {"method": "minmax"}}
                ]
            
        返回:
            处理后的图像
        """
        processed = image.copy()
        
        for step in pipeline:
            operation = step.get("operation", "")
            params = step.get("params", {})
            
            if operation == "resize":
                processed = self.resize_image(processed, **params)
            elif operation == "rotate":
                processed = self.rotate_image(processed, **params)
            elif operation == "flip":
                processed = self.flip_image(processed, **params)
            elif operation == "crop":
                processed = self.crop_image(processed, **params)
            elif operation == "brightness_contrast":
                processed = self.adjust_brightness_contrast(processed, **params)
            elif operation == "gamma":
                processed = self.adjust_gamma(processed, **params)
            elif operation == "sharpen":
                processed = self.sharpen_image(processed, **params)
            elif operation == "enhance_details":
                processed = self.enhance_details(processed, **params)
            elif operation == "gaussian_blur":
                processed = self.apply_gaussian_blur(processed, **params)
            elif operation == "median_blur":
                processed = self.apply_median_blur(processed, **params)
            elif operation == "bilateral_filter":
                processed = self.apply_bilateral_filter(processed, **params)
            elif operation == "normalize":
                processed = self.normalize_image(processed, **params)
            elif operation == "edge_detection":
                processed = self.detect_edges(processed, **params)
            elif operation == "morphological":
                processed = self.apply_morphological_operation(processed, **params)
            else:
                logger.warning(f"未知的操作: {operation},跳过")
        
        logger.info(f"完成预处理流水线,应用了 {len(pipeline)} 个操作")
        return processed
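
A note on `normalize_image` when the target range is [0, 1], as in `preprocess_for_object_detection`: the scaled result has to stay in a float dtype. The sketch below uses plain Python lists as a stand-in for image arrays (the helper names `minmax_scale` and `to_uint8` are hypothetical and not part of the project) to show what a clip-and-cast to 8-bit integers would do to such a result.

```python
def minmax_scale(pixels, min_value=0.0, max_value=255.0):
    """Min-max scale a flat list of pixel values into [min_value, max_value]."""
    lo, hi = min(pixels), max(pixels)
    if hi == lo:
        # Constant input: every value maps to the lower bound
        return [float(min_value)] * len(pixels)
    span = max_value - min_value
    return [(p - lo) / (hi - lo) * span + min_value for p in pixels]


def to_uint8(values):
    """Pure-Python stand-in for np.clip(values, 0, 255).astype(np.uint8)."""
    return [int(min(max(v, 0), 255)) for v in values]


pixels = [0, 64, 128, 255]
scaled = minmax_scale(pixels, 0.0, 1.0)

print(scaled)            # floats between 0.0 and 1.0
print(to_uint8(scaled))  # [0, 0, 0, 1] -> almost all information is lost
```

Truncation maps every value below 1.0 to 0, so only the brightest pixel survives. The same reasoning applies to mean-centred and z-score results, whose negative values would be clipped to 0.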

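The `if/elif` chain in `preprocess_pipeline` could also be written as a dispatch table that maps operation names to handlers. The following is only a sketch of that idea, not the project's implementation: the operations `scale` and `offset` are hypothetical stand-ins that work on lists of numbers so the example runs without OpenCV.

```python
from typing import Callable, Dict, List, Tuple


def scale(values: List[float], factor: float = 1.0) -> List[float]:
    """Hypothetical stand-in for an image operation: multiply every value."""
    return [v * factor for v in values]


def offset(values: List[float], amount: float = 0.0) -> List[float]:
    """Hypothetical stand-in for an image operation: add a constant."""
    return [v + amount for v in values]


# Operation name -> handler; adding a step means adding one entry here
OPERATIONS: Dict[str, Callable[..., List[float]]] = {
    "scale": scale,
    "offset": offset,
}


def run_pipeline(values: List[float], pipeline: List[dict]) -> Tuple[List[float], int]:
    """Apply each known step in order; return the result and the number of steps applied."""
    result = list(values)
    applied = 0
    for step in pipeline:
        name = step.get("operation", "")
        handler = OPERATIONS.get(name)
        if handler is None:
            print(f"Unknown operation: {name}, skipping")
            continue
        result = handler(result, **step.get("params", {}))
        applied += 1
    return result, applied


steps = [
    {"operation": "scale", "params": {"factor": 2}},
    {"operation": "blur"},  # not registered, so it is skipped
    {"operation": "offset", "params": {"amount": 1}},
]
result, applied = run_pipeline([1.0, 2.0], steps)
print(result, applied)  # [3.0, 5.0] 2
```

Unlike the final log message of `preprocess_pipeline`, which reports `len(pipeline)`, this sketch counts only the steps that were actually applied.
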
tests\test_classification.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
测试分类/识别模块的功能
"""

import os
import sys
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits

# 添加项目根目录到系统路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# 导入项目模块
from src.classification import Classification
from src.feature_extraction import FeatureExtraction
from src.image_preprocessing import ImagePreprocessing
# config_loader.py lives in the top-level utils package (utils\config_loader.py)
from utils.config_loader import ConfigLoader

def test_traditional_ml_classifiers():
    """测试传统机器学习分类器"""
    print("测试传统机器学习分类器...")
    
    # 加载示例数据集(手写数字)
    digits = load_digits()
    X, y = digits.data, digits.target
    
    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
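    # Initialize the classification module. ConfigLoader takes a config directory
    # (default: <project root>/config); load_config takes the file name without extension.
    config_loader = ConfigLoader()
    classification = Classification(config_loader.load_config("classification_config"))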
    
    # 测试SVM分类器
    print("\n测试SVM分类器:")
    svm_classifier = classification.train_svm_classifier(X_train, y_train)
    if svm_classifier:
        metrics = classification.evaluate_classifier(X_test, y_test, "svm")
        print(f"SVM分类器准确率: {metrics['accuracy']:.4f}")
        print(f"SVM分类器F1分数: {metrics['f1_score']:.4f}")
    
    # 测试KNN分类器
    print("\n测试KNN分类器:")
    knn_classifier = classification.train_knn_classifier(X_train, y_train)
    if knn_classifier:
        metrics = classification.evaluate_classifier(X_test, y_test, "knn")
        print(f"KNN分类器准确率: {metrics['accuracy']:.4f}")
        print(f"KNN分类器F1分数: {metrics['f1_score']:.4f}")
    
    # 测试随机森林分类器
    print("\n测试随机森林分类器:")
    rf_classifier = classification.train_random_forest_classifier(X_train, y_train)
    if rf_classifier:
        metrics = classification.evaluate_classifier(X_test, y_test, "random_forest")
        print(f"随机森林分类器准确率: {metrics['accuracy']:.4f}")
        print(f"随机森林分类器F1分数: {metrics['f1_score']:.4f}")
    
    # 保存和加载分类器
    print("\n测试保存和加载分类器:")
    save_path = "models/svm_classifier.pkl"
    if classification.save_classifier("svm", save_path):
        print(f"SVM分类器已保存到 {save_path}")
    
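    # Create a new classification instance (ConfigLoader exposes load_config, not a .config attribute)
    new_classification = Classification(config_loader.load_config("classification_config"))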
    if new_classification.load_classifier("svm", save_path):
        print(f"SVM分类器已从 {save_path} 加载")
        metrics = new_classification.evaluate_classifier(X_test, y_test, "svm")
        print(f"加载后的SVM分类器准确率: {metrics['accuracy']:.4f}")

def test_image_classification():
    """测试图像分类功能"""
    print("\n测试图像分类功能...")
    
    # 创建目录(如果不存在)
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    
    # Load the configuration (load_config takes the file name without extension)
    config_loader = ConfigLoader()
    classification_config = config_loader.load_config("classification_config")
    feature_config = config_loader.load_config("feature_extraction_config")
    
    # Initialize the modules
    classification = Classification(classification_config)
    feature_extraction = FeatureExtraction(feature_config)
    preprocessing = ImagePreprocessing()
    
    # 加载示例图像
    try:
        # 尝试加载一些示例图像(这里假设有一个images目录)
        image_paths = []
        for root, _, files in os.walk("images"):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    image_paths.append(os.path.join(root, file))
        
        if not image_paths:
            print("未找到示例图像,跳过图像分类测试")
            return
        
        # 加载第一张图像
        image = cv2.imread(image_paths[0])
        if image is None:
            print(f"无法加载图像: {image_paths[0]}")
            return
        
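        # Preprocess the image; resize_image expects width and height as separate
        # arguments, so a (224, 224) tuple would leave the image at its original size
        processed_image = preprocessing.resize_image(image, width=224, height=224)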
        
        # 提取特征
        features = feature_extraction.extract_hog_features(processed_image)
        
        # 使用SVM分类器进行分类
        print("\n使用SVM分类器分类图像:")
        # 注意:这里假设已经训练了分类器,实际应用中需要先训练
        if "svm" in classification.classifiers:
            result = classification.classify_image(processed_image, features, "svm")
            print(f"分类结果: {result['predictions']}")
            print(f"处理时间: {result['processing_time']:.4f} 秒")
        else:
            print("SVM分类器未训练,跳过分类")
        
    except Exception as e:
        print(f"图像分类测试出错: {e}")

def test_object_detection():
    """测试目标检测功能"""
    print("\n测试目标检测功能...")
    
    # Load the configuration (load_config takes the file name without extension)
    config_loader = ConfigLoader()
    
    # Initialize the classification module
    classification = Classification(config_loader.load_config("classification_config"))
    
    # 加载示例图像
    try:
        # 尝试加载一些示例图像
        image_paths = []
        for root, _, files in os.walk("images"):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    image_paths.append(os.path.join(root, file))
        
        if not image_paths:
            print("未找到示例图像,跳过目标检测测试")
            return
        
        # 加载第一张图像
        image = cv2.imread(image_paths[0])
        if image is None:
            print(f"无法加载图像: {image_paths[0]}")
            return
        
        # 执行目标检测
        print("\n使用YOLOv5进行目标检测:")
        try:
            result = classification.recognize_objects(image, "yolo")
            detection_results = result["detection_results"]
            
            print(f"检测到 {len(detection_results['boxes'])} 个目标")
            print(f"处理时间: {result['processing_time']:.4f} 秒")
            
            # 绘制检测结果
            if len(detection_results['boxes']) > 0:
                result_image = classification.draw_detection_results(image, detection_results)
                
                # 保存结果图像
                result_path = "results/detection_result.jpg"
                cv2.imwrite(result_path, result_image)
                print(f"检测结果已保存到 {result_path}")
                
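                # Also save a matplotlib rendering (no GUI window is opened)
                try:
                    plt.figure(figsize=(10, 8))
                    plt.imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
                    plt.axis('off')
                    plt.title("Object detection result")
                    plt.savefig("results/detection_result_plt.jpg")
                    plt.close()  # release the figure so repeated runs do not accumulate figures
                    print("Result image saved with matplotlib")
                except Exception as e:
                    print(f"Could not save the result image with matplotlib: {e}")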
        except ImportError:
            print("未安装PyTorch或YOLOv5,跳过目标检测测试")
            print("请安装所需依赖: pip install torch torchvision")
        
    except Exception as e:
        print(f"目标检测测试出错: {e}")

def main():
    """主函数"""
    # 创建必要的目录
    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)
    os.makedirs("logs", exist_ok=True)
    
    # 测试传统机器学习分类器
    test_traditional_ml_classifiers()
    
    # 测试图像分类
    test_image_classification()
    
    # 测试目标检测
    test_object_detection()

if __name__ == "__main__":
    main()
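
The tests above print `metrics['accuracy']` and `metrics['f1_score']` as returned by `evaluate_classifier`. How that method averages F1 over classes is not shown in this listing. The sketch below assumes macro averaging and computes both metrics by hand on a small made-up label set, which can help when checking the printed numbers.

```python
from typing import Sequence


def accuracy(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Fraction of samples whose predicted label equals the true label."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of the per-class F1 scores (macro averaging, an assumption here)."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = []
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        # F1 = 2*TP / (2*TP + FP + FN); defined as 0 when the class never occurs
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


# Made-up labels for illustration only
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

print(f"accuracy: {accuracy(y_true, y_pred):.4f}")  # accuracy: 0.6667
print(f"macro F1: {macro_f1(y_true, y_pred):.4f}")  # macro F1: 0.6556
```

On imbalanced data, macro and weighted averaging can differ noticeably, so it is worth confirming which one `evaluate_classifier` uses before comparing classifiers.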

tests\test_image_acquisition.py

"""
图像获取模块测试脚本

该脚本用于测试图像获取模块的各项功能,包括:
1. 从文件加载图像
2. 从摄像头捕获图像
3. 从URL加载图像
4. 批量加载图像
5. 保存图像
"""

import os
import sys
import cv2
import matplotlib.pyplot as plt
import numpy as np
import time

# 添加项目根目录到Python路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# 导入自定义模块
from src.image_acquisition import ImageAcquisition
from utils.config_loader import ConfigLoader

def display_image(image, title="Image"):
    """显示图像"""
    plt.figure(figsize=(10, 8))
    
    # 如果是BGR格式(OpenCV默认),转换为RGB
    if len(image.shape) == 3 and image.shape[2] == 3:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    else:
        image_rgb = image
        
    plt.imshow(image_rgb)
    plt.title(title)
    plt.axis('off')
    plt.show()

def test_file_loading():
    """测试从文件加载图像"""
    print("\n=== 测试从文件加载图像 ===")
    
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
    # 测试图像路径
    test_image_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 
                                  "data", "sample.jpg")
    
    # 检查测试图像是否存在,如果不存在则创建一个测试图像
    if not os.path.exists(test_image_path):
        print(f"测试图像不存在,创建测试图像: {test_image_path}")
        
        # 确保目录存在
        os.makedirs(os.path.dirname(test_image_path), exist_ok=True)
        
        # 创建一个简单的测试图像 (彩色渐变)
        width, height = 640, 480
        image = np.zeros((height, width, 3), dtype=np.uint8)
        
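        # Build the gradient with NumPy broadcasting instead of a per-pixel Python loop
        xs = np.arange(width)            # shape (width,), varies along columns
        ys = np.arange(height)[:, None]  # shape (height, 1), varies along rows
        image[:, :, 0] = (255 * xs / width).astype(np.uint8)  # blue channel
        image[:, :, 1] = (255 * ys / height).astype(np.uint8)  # green channel
        image[:, :, 2] = (255 * (xs + ys) / (width + height)).astype(np.uint8)  # red channel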
        
        # 添加一些文本
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(image, 'Test Image', (50, 50), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
        
        # 保存测试图像
        cv2.imwrite(test_image_path, image)
        print(f"测试图像已创建")
    
    # 加载图像
    print(f"从文件加载图像: {test_image_path}")
    image = image_acq.load_from_file(test_image_path)
    
    if image is not None:
        print("图像加载成功")
        print(f"图像信息: {image_acq.get_image_info(image)}")
        
        # 显示图像 (如果在支持GUI的环境中)
        try:
            display_image(image, "从文件加载的图像")
        except Exception as e:
            print(f"无法显示图像: {str(e)}")
        
        # 测试调整大小
        resized_image = image_acq.load_from_file(test_image_path, resize=True)
        print(f"调整大小后的图像信息: {image_acq.get_image_info(resized_image)}")
        
        # 保存调整大小后的图像
        output_path = os.path.join(os.path.dirname(test_image_path), "resized_sample.jpg")
        image_acq.save_image(resized_image, output_path)
        print(f"调整大小后的图像已保存到: {output_path}")
    else:
        print("图像加载失败")

def test_camera_capture():
    """测试从摄像头捕获图像"""
    print("\n=== 测试从摄像头捕获图像 ===")
    
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
    try:
        print("尝试从摄像头捕获图像...")
        image = image_acq.load_from_camera(num_frames=5)
        
        if image is not None:
            print("成功从摄像头捕获图像")
            print(f"图像信息: {image_acq.get_image_info(image)}")
            
            # 保存捕获的图像
            output_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 
                                     "data", "camera_capture.jpg")
            image_acq.save_image(image, output_path)
            print(f"摄像头捕获的图像已保存到: {output_path}")
            
            # 显示图像
            try:
                display_image(image, "摄像头捕获的图像")
            except Exception as e:
                print(f"无法显示图像: {str(e)}")
        else:
            print("无法从摄像头捕获图像")
    
    except Exception as e:
        print(f"测试摄像头捕获时出错: {str(e)}")
    
    finally:
        # 释放摄像头资源
        image_acq.release_camera()
        print("摄像头资源已释放")

def test_url_loading():
    """测试从URL加载图像"""
    print("\n=== 测试从URL加载图像 ===")
    
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
    # 测试URL (使用一个公共图像URL)
    test_url = "https://raw.githubusercontent.com/opencv/opencv/master/samples/data/lena.jpg"
    
    try:
        print(f"从URL加载图像: {test_url}")
        image = image_acq.load_from_url(test_url)
        
        if image is not None:
            print("图像加载成功")
            print(f"图像信息: {image_acq.get_image_info(image)}")
            
            # 保存图像
            output_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 
                                     "data", "url_image.jpg")
            image_acq.save_image(image, output_path)
            print(f"从URL加载的图像已保存到: {output_path}")
            
            # 显示图像
            try:
                display_image(image, "从URL加载的图像")
            except Exception as e:
                print(f"无法显示图像: {str(e)}")
        else:
            print("从URL加载图像失败")
    
    except Exception as e:
        print(f"测试URL加载时出错: {str(e)}")

def test_batch_loading():
    """测试批量加载图像"""
    print("\n=== 测试批量加载图像 ===")
    
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
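    # Test directory
    test_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
    
    # Create the directory if needed; cv2.imwrite does not create missing directories
    os.makedirs(test_dir, exist_ok=True)
    
    # Make sure the directory contains several test images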
    for i in range(3):
        # 创建一个简单的测试图像
        width, height = 320, 240
        image = np.zeros((height, width, 3), dtype=np.uint8)
        
        # 使用不同颜色
        color = [(0, 0, 255), (0, 255, 0), (255, 0, 0)][i % 3]
        
        # 填充颜色
        image[:, :] = color
        
        # 添加文本
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(image, f'Test {i+1}', (50, 120), font, 1, (255, 255, 255), 2, cv2.LINE_AA)
        
        # 保存图像
        output_path = os.path.join(test_dir, f"test_image_{i+1}.jpg")
        cv2.imwrite(output_path, image)
        print(f"创建测试图像: {output_path}")
    
    # 批量加载图像
    print(f"从目录批量加载图像: {test_dir}")
    images = image_acq.load_batch_from_directory(test_dir, resize=True)
    
    print(f"成功加载 {len(images)} 张图像")
    
    # 显示加载的图像信息
    for i, image in enumerate(images):
        print(f"图像 {i+1} 信息: {image_acq.get_image_info(image)}")

def test_color_conversion():
    """测试颜色空间转换"""
    print("\n=== 测试颜色空间转换 ===")
    
    # 创建图像获取实例
    image_acq = ImageAcquisition()
    
    # 加载测试图像
    test_image_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 
                                  "data", "sample.jpg")
    
    if not os.path.exists(test_image_path):
        print(f"测试图像不存在: {test_image_path}")
        return
    
    # 加载图像
    image = image_acq.load_from_file(test_image_path)
    
    if image is not None:
        # BGR转RGB
        rgb_image = ImageAcquisition.convert_color_space(image, cv2.COLOR_BGR2RGB)
        print("BGR转RGB成功")
        
        # BGR转灰度
        gray_image = ImageAcquisition.convert_color_space(image, cv2.COLOR_BGR2GRAY)
        print("BGR转灰度成功")
        print(f"灰度图像信息: {image_acq.get_image_info(gray_image)}")
        
        # BGR转HSV
        hsv_image = ImageAcquisition.convert_color_space(image, cv2.COLOR_BGR2HSV)
        print("BGR转HSV成功")
        
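        # Save the converted images. Note: if save_image writes via cv2.imwrite
        # (an assumption about ImageAcquisition), arrays are interpreted as BGR,
        # so the RGB and HSV files will show false colors in an image viewer.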
        output_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
        
        # 保存RGB图像
        rgb_path = os.path.join(output_dir, "rgb_image.jpg")
        image_acq.save_image(rgb_image, rgb_path)
        print(f"RGB图像已保存到: {rgb_path}")
        
        # 保存灰度图像
        gray_path = os.path.join(output_dir, "gray_image.jpg")
        image_acq.save_image(gray_image, gray_path)
        print(f"灰度图像已保存到: {gray_path}")
        
        # 保存HSV图像
        hsv_path = os.path.join(output_dir, "hsv_image.jpg")
        image_acq.save_image(hsv_image, hsv_path)
        print(f"HSV图像已保存到: {hsv_path}")
    else:
        print("无法加载测试图像")

def main():
    """主测试函数"""
    print("=== 图像获取模块测试 ===")
    
    # 加载配置
    config_loader = ConfigLoader()
    config = config_loader.load_config("image_acquisition_config")
    
    if config:
        print("成功加载配置:")
        print(f"默认图像尺寸: {config['default_image_size']}")
        print(f"支持的文件扩展名: {config['file_loading']['supported_extensions']}")
    else:
        print("无法加载配置,使用默认设置")
    
    # 运行测试
    test_file_loading()
    
    # 询问是否测试摄像头功能
    response = input("\n是否测试摄像头功能? (y/n): ")
    if response.lower() == 'y':
        test_camera_capture()
    
    # 询问是否测试URL加载功能
    response = input("\n是否测试从URL加载图像? (y/n): ")
    if response.lower() == 'y':
        test_url_loading()
    
    test_batch_loading()
    test_color_conversion()
    
    print("\n=== 测试完成 ===")

if __name__ == "__main__":
    main()
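
The `main()` function above asks interactively whether to run the camera and URL tests, which blocks unattended runs. One possible alternative, sketched here under the assumption that command-line flags are acceptable, is to select the optional tests with `argparse`. The flag names `--camera` and `--url` and the helper `select_tests` are hypothetical.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the optional-test flags; argv=None falls back to sys.argv."""
    parser = argparse.ArgumentParser(description="Image acquisition module tests")
    parser.add_argument("--camera", action="store_true",
                        help="also run the camera capture test")
    parser.add_argument("--url", action="store_true",
                        help="also run the URL loading test")
    return parser.parse_args(argv)


def select_tests(args: argparse.Namespace) -> List[str]:
    """Return the test names in the order main() would run them."""
    selected = ["file_loading"]
    if args.camera:
        selected.append("camera_capture")
    if args.url:
        selected.append("url_loading")
    selected += ["batch_loading", "color_conversion"]
    return selected


print(select_tests(parse_args(["--camera"])))
# ['file_loading', 'camera_capture', 'batch_loading', 'color_conversion']
```

With this approach the script runs without prompts by default, and hardware or network dependent tests are opted into explicitly.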

utils\config_loader.py

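"""
Configuration loader module

Loads and parses configuration files and provides a uniform interface for accessing configuration parameters.
Only JSON files are supported at present; YAML loading is left as a placeholder in load_config.
"""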

import os
import json
import logging
from typing import Dict, Any, Optional

# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ConfigLoader:
    """配置加载器类,用于加载和访问配置文件"""
    
    def __init__(self, config_dir: str = None):
        """
        初始化配置加载器
        
        参数:
            config_dir: 配置文件目录,默认为None,将使用相对路径
        """
        if config_dir is None:
            # 默认使用项目根目录下的config文件夹
            self.config_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'config')
        else:
            self.config_dir = config_dir
            
        self.configs = {}
        logger.info(f"配置加载器初始化完成,配置目录: {self.config_dir}")
    
    def load_config(self, config_name: str) -> Optional[Dict[str, Any]]:
        """
        加载指定名称的配置文件
        
        参数:
            config_name: 配置文件名称,不包含扩展名
            
        返回:
            配置字典,如果加载失败则返回None
        """
        # 如果配置已加载,直接返回
        if config_name in self.configs:
            return self.configs[config_name]
        
        # 尝试加载JSON格式配置
        json_path = os.path.join(self.config_dir, f"{config_name}.json")
        if os.path.exists(json_path):
            try:
                with open(json_path, 'r', encoding='utf-8') as f:
                    config = json.load(f)
                    self.configs[config_name] = config
                    logger.info(f"成功加载配置文件: {json_path}")
                    return config
            except Exception as e:
                logger.error(f"加载JSON配置文件时出错: {str(e)}")
                return None
        
        # 尝试加载YAML格式配置(如果有需要)
        # 这里可以添加YAML配置加载代码
        
        logger.error(f"找不到配置文件: {config_name}")
        return None
    
    def get_config(self, config_name: str, auto_load: bool = True) -> Optional[Dict[str, Any]]:
        """
        获取配置,如果未加载则尝试加载
        
        参数:
            config_name: 配置名称
            auto_load: 如果配置未加载,是否自动加载
            
        返回:
            配置字典
        """
        if config_name not in self.configs and auto_load:
            return self.load_config(config_name)
        
        return self.configs.get(config_name)
    
    def get_value(self, config_name: str, key_path: str, default_value: Any = None) -> Any:
        """
        获取配置中的特定值
        
        参数:
            config_name: 配置名称
            key_path: 键路径,使用点号分隔,如"camera.default_camera_id"
            default_value: 如果键不存在,返回的默认值
            
        返回:
            配置值
        """
        config = self.get_config(config_name)
        if not config:
            return default_value
        
        # 解析键路径
        keys = key_path.split('.')
        value = config
        
        # 逐层获取值
        for key in keys:
            if isinstance(value, dict) and key in value:
                value = value[key]
            else:
                return default_value
        
        return value
    
    def save_config(self, config_name: str, config_data: Dict[str, Any]) -> bool:
        """
        保存配置到文件
        
        参数:
            config_name: 配置名称
            config_data: 配置数据
            
        返回:
            是否保存成功
        """
        try:
            # 确保配置目录存在
            os.makedirs(self.config_dir, exist_ok=True)
            
            # 保存为JSON格式
            json_path = os.path.join(self.config_dir, f"{config_name}.json")
            with open(json_path, 'w', encoding='utf-8') as f:
                json.dump(config_data, f, indent=4, ensure_ascii=False)
            
            # 更新内存中的配置
            self.configs[config_name] = config_data
            
            logger.info(f"配置已保存到: {json_path}")
            return True
            
        except Exception as e:
            logger.error(f"保存配置时出错: {str(e)}")
            return False


# 示例用法
if __name__ == "__main__":
    # 创建配置加载器
    config_loader = ConfigLoader()
    
    # 加载图像获取模块配置
    image_acq_config = config_loader.load_config("image_acquisition_config")
    
    if image_acq_config:
        print("成功加载图像获取模块配置:")
        print(f"默认图像尺寸: {image_acq_config['default_image_size']}")
        print(f"支持的文件扩展名: {image_acq_config['file_loading']['supported_extensions']}")
        
        # 使用get_value方法获取特定配置
        camera_id = config_loader.get_value("image_acquisition_config", "camera.default_camera_id", 0)
        print(f"默认摄像头ID: {camera_id}")
    else:
        print("加载配置失败")
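
The dotted key-path lookup in `get_value` can be tried without the project's config files. The sketch below repeats the same traversal in a standalone function (`get_by_path` is a hypothetical name) and round-trips a made-up configuration through a temporary JSON file. The keys and values are illustrative assumptions, not the project's real configuration.

```python
import json
import os
import tempfile
from typing import Any


def get_by_path(config: dict, key_path: str, default: Any = None) -> Any:
    """Walk a nested dict along a dot-separated key path; return default if any key is missing."""
    value: Any = config
    for key in key_path.split("."):
        if isinstance(value, dict) and key in value:
            value = value[key]
        else:
            return default
    return value


# Made-up configuration for illustration only
sample = {"default_image_size": [640, 480], "camera": {"default_camera_id": 0}}

with tempfile.TemporaryDirectory() as config_dir:
    path = os.path.join(config_dir, "image_acquisition_config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(sample, f, indent=4, ensure_ascii=False)
    with open(path, "r", encoding="utf-8") as f:
        loaded = json.load(f)

print(get_by_path(loaded, "camera.default_camera_id", -1))  # 0
print(get_by_path(loaded, "camera.fps", 30))                # 30
print(get_by_path(loaded, "default_image_size.width"))      # None
```

Because the traversal tests key membership rather than truthiness, a stored value of `0` is returned as is and is not mistaken for a missing key.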