垃圾分类API开发指南:基于ai53_19/garbage_datasets的Flask应用

垃圾分类API开发指南:基于ai53_19/garbage_datasets的Flask应用

【免费下载链接】垃圾分类数据集 【免费下载链接】垃圾分类数据集 项目地址: https://ai.gitcode.com/ai53_19/garbage_datasets

1. 引言

你是否还在为垃圾分类应用开发中的模型集成和API构建而烦恼?本文将详细介绍如何基于ai53_19/garbage_datasets数据集和Flask框架,快速构建一个高效的垃圾分类API服务。通过本文,你将学习到:

  • 数据集结构与类别体系解析
  • 模型加载与预测功能实现
  • Flask API服务构建与优化
  • 完整的部署与测试流程

2. 数据集概述

ai53_19/garbage_datasets是一个包含40个细分类别的垃圾分类图像数据集,适用于目标检测任务。

2.1 数据集结构

garbage_datasets/
├── README.md
├── data.yaml          # 数据集配置文件
├── dataset_infos.json # 数据集元信息
├── garbage_datasets.json
├── garbage_datasets.py # 模型训练与预测代码
└── datasets/
    ├── images/        # 图像数据
    │   ├── train/     # 训练集图像
    │   └── val/       # 验证集图像
    ├── labels/        # 标注数据
    │   ├── train/     # 训练集标注
    │   └── val/       # 验证集标注
    └── videos/        # 视频数据
        └── Cigrette.MP4

2.2 类别体系

数据集包含40个细分类别,分为四大类:

mermaid

详细类别映射

大类细分类别数量
可回收物充电宝、包、化妆品瓶、玩具、塑料碗、塑料衣架、纸袋、插头电线、旧衣物、易拉罐、枕头、毛绒玩具、洗发水瓶、玻璃杯、鞋子、铁砧、纸板箱、调味品瓶、酒瓶、金属食品罐、锅、食用油桶、饮料瓶24
有害垃圾干电池、药膏、过期药品3
厨余垃圾剩饭剩菜、骨头、水果皮、纸浆、茶叶、蔬菜、蛋壳、鱼骨8
其他垃圾快餐盒、污损塑料、烟头、牙签、花盆、竹筷6

3. 环境准备

3.1 安装依赖

# 创建虚拟环境
python -m venv venv
source venv/bin/activate  # Linux/Mac
# 或
venv\Scripts\activate     # Windows

# 安装依赖
pip install flask ultralytics torch matplotlib pillow numpy

3.2 获取项目

git clone https://gitcode.com/ai53_19/garbage_datasets
cd garbage_datasets

4. 模型集成

4.1 模型加载

garbage_datasets.py中已实现GarbageDetector类,用于模型训练和预测。我们需要对其进行封装,以便在API中使用:

from garbage_datasets import GarbageDetector
import os
import torch

class GarbageModel:
    def __init__(self, model_path="best.pt"):
        """初始化模型"""
        self.detector = GarbageDetector()
        self.load_model(model_path)
        # 加载类别名称
        self.class_names = self._load_class_names()
        # 加载类别映射
        self.category_mapping = self._load_category_mapping()
        
    def load_model(self, model_path):
        """加载训练好的模型"""
        if os.path.exists(model_path):
            self.detector.model = YOLO(model_path)
        else:
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")
    
    def _load_class_names(self):
        """从dataset_infos.json加载类别名称"""
        import json
        with open("dataset_infos.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        return {item["id"]: item["name_cn"] for item in data["categories"]}
    
    def _load_category_mapping(self):
        """从data.yaml加载大类映射"""
        import yaml
        with open("data.yaml", "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)
        return data["category_mapping"]
    
    def predict(self, image_path):
        """预测图像中的垃圾类别"""
        results = self.detector.model(image_path)
        
        # 处理预测结果
        predictions = []
        for result in results:
            boxes = result.boxes
            for box in boxes:
                cls = int(box.cls[0])
                conf = float(box.conf[0])
                # 获取中文类别名
                class_name = self.class_names.get(cls, "未知类别")
                # 获取大类
                category = self._get_category(class_name)
                
                predictions.append({
                    "class_id": cls,
                    "class_name": class_name,
                    "confidence": round(conf, 4),
                    "category": category,
                    "bbox": box.xyxy.tolist()[0]  # 边界框坐标
                })
        
        return predictions
    
    def _get_category(self, class_name):
        """根据细分类别获取大类"""
        for category, subclasses in self.category_mapping.items():
            if class_name in subclasses:
                return category
        return "其他垃圾"

4.2 模型训练(可选)

如果没有预训练模型,可以使用以下代码进行训练:

# train_model.py
from garbage_datasets import GarbageDetector

if __name__ == "__main__":
    detector = GarbageDetector()
    # 训练新模型
    detector.train("data.yaml")
    # 或继续训练
    # detector.train("data.yaml", weights_path="runs/detect/train/weights/best.pt")

训练参数说明:

{
    "epochs": 100,           # 训练轮数
    "imgsz": 1024,           # 图片尺寸
    "batch": 32,             # 批次大小
    "workers": 8,            # 数据加载器的工作进程数
    "device": '0' if torch.cuda.is_available() else 'cpu',  # 设备选择
    "optimizer": 'AdamW',    # 优化器
    "lr0": 0.001,            # 初始学习率
    "weight_decay": 0.0005   # 权重衰减
}

5. Flask API开发

5.1 项目结构

garbage_api/
├── app/
│   ├── __init__.py       # Flask应用初始化
│   ├── models/           # 模型相关代码
│   │   └── garbage_model.py  # 模型封装类
│   ├── routes/           # API路由
│   │   └── predict.py    # 预测接口
│   ├── utils/            # 工具函数
│   │   ├── image_processing.py  # 图像处理工具
│   │   └── response.py   # 响应格式化工具
│   └── config.py         # 配置文件
├── uploads/              # 上传图片存储
├── best.pt               # 训练好的模型
├── run.py                # 应用入口
└── requirements.txt      # 依赖列表

5.2 初始化Flask应用

# app/__init__.py
from flask import Flask
from app.routes.predict import predict_bp
import os

def create_app():
    app = Flask(__name__)
    
    # 配置上传文件夹
    app.config['UPLOAD_FOLDER'] = os.path.join(os.path.dirname(__file__), '../uploads')
    app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB
    
    # 创建上传文件夹
    os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
    
    # 注册蓝图
    app.register_blueprint(predict_bp, url_prefix='/api')
    
    return app

5.3 预测接口实现

# app/routes/predict.py
from flask import Blueprint, request, jsonify
from app.models.garbage_model import GarbageModel
from app.utils.response import api_response
import os
import uuid

predict_bp = Blueprint('predict', __name__)

# 初始化模型(全局单例)
model = GarbageModel()

@predict_bp.route('/predict', methods=['POST'])
def predict():
    """垃圾图像预测接口"""
    if 'image' not in request.files:
        return api_response(code=400, message="未提供图像文件")
    
    file = request.files['image']
    if file.filename == '':
        return api_response(code=400, message="未选择文件")
    
    # 保存上传的文件
    filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(file_path)
    
    try:
        # 进行预测
        results = model.predict(file_path)
        
        return api_response(
            code=200, 
            message="预测成功", 
            data={
                "predictions": results,
                "image_filename": filename
            }
        )
    except Exception as e:
        return api_response(code=500, message=f"预测失败: {str(e)}")
    finally:
        # 可选:删除临时文件
        # os.remove(file_path)

@predict_bp.route('/health', methods=['GET'])
def health_check():
    """健康检查接口"""
    return api_response(code=200, message="API服务正常运行")

5.4 响应工具函数

# app/utils/response.py
def api_response(code=200, message="success", data=None):
    """API响应格式化"""
    response = {
        "code": code,
        "message": message
    }
    if data is not None:
        response["data"] = data
    return jsonify(response)

5.5 应用入口

# run.py
from app import create_app

app = create_app()

if __name__ == "__main__":
    app.run(
        host='0.0.0.0', 
        port=5000, 
        debug=False,
        threaded=True  # 启用多线程处理请求
    )

6. API文档与测试

6.1 API接口说明

接口方法描述请求参数响应
/api/healthGET健康检查{"code":200,"message":"API服务正常运行"}
/api/predictPOST垃圾图像预测image: 图像文件预测结果JSON

6.2 使用curl测试

# 健康检查
curl http://localhost:5000/api/health

# 图像预测
curl -X POST -F "image=@test.jpg" http://localhost:5000/api/predict

6.3 预测响应示例

{
  "code": 200,
  "message": "预测成功",
  "data": {
    "predictions": [
      {
        "class_id": 36,
        "class_name": "饮料瓶",
        "confidence": 0.9876,
        "category": "Recyclables",
        "bbox": [120.5, 80.3, 350.2, 420.8]
      }
    ],
    "image_filename": "a1b2c3d4-5678-90ef-ghij-klmnopqrstuv.jpg"
  }
}

7. 性能优化与部署

7.1 性能优化策略

mermaid

7.2 Docker部署

创建Dockerfile

FROM python:3.9-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

# 创建上传目录
RUN mkdir -p uploads

EXPOSE 5000

CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "run:app"]

构建和运行:

# 构建镜像
docker build -t garbage-api .

# 运行容器
docker run -d -p 5000:5000 --name garbage-api-container garbage-api

7.3 生产环境配置

使用Gunicorn作为WSGI服务器:

gunicorn --bind 0.0.0.0:5000 --workers 4 --threads 2 run:app

参数说明:

  • --workers: 工作进程数,建议设置为CPU核心数+1
  • --threads: 每个工作进程的线程数
  • --timeout: 超时时间,单位秒

8. 总结与扩展

8.1 项目总结

本文基于ai53_19/garbage_datasets数据集和Flask框架,构建了一个功能完善的垃圾分类API服务。主要实现了:

  1. 数据集解析与类别体系构建
  2. YOLO模型封装与预测功能实现
  3. RESTful API设计与实现
  4. 性能优化与部署指南

8.2 扩展方向

  1. 前端界面:开发Web界面或移动应用,提供直观的垃圾分类体验
  2. 批量处理:支持多图像批量预测和视频流实时处理
  3. 模型更新:实现模型自动更新机制,定期重新训练
  4. 日志监控:添加详细日志和性能监控
  5. 多语言支持:支持中英文等多语言类别名称

通过本文的指南,你可以快速搭建一个高效、可靠的垃圾分类API服务,为垃圾分类应用提供强大的后端支持。

9. 附录

9.1 完整代码结构

garbage_datasets/
├── app/
│   ├── __init__.py
│   ├── models/
│   │   └── garbage_model.py
│   ├── routes/
│   │   └── predict.py
│   ├── utils/
│   │   └── response.py
│   └── config.py
├── uploads/
├── best.pt
├── data.yaml
├── dataset_infos.json
├── garbage_datasets.py
├── requirements.txt
├── run.py
└── Dockerfile

9.2 requirements.txt

flask==2.3.3
ultralytics==8.0.200
torch==2.0.1
matplotlib==3.7.2
pillow==10.0.0
numpy==1.25.2
pyyaml==6.0.1
gunicorn==21.2.0

【免费下载链接】垃圾分类数据集 【免费下载链接】垃圾分类数据集 项目地址: https://ai.gitcode.com/ai53_19/garbage_datasets

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值