垃圾分类API开发指南:基于ai53_19/garbage_datasets的Flask应用
【免费下载链接】垃圾分类数据集 项目地址: https://ai.gitcode.com/ai53_19/garbage_datasets
1. 引言
你是否还在为垃圾分类应用开发中的模型集成和API构建而烦恼?本文将详细介绍如何基于ai53_19/garbage_datasets数据集和Flask框架,快速构建一个高效的垃圾分类API服务。通过本文,你将学习到:
- 数据集结构与类别体系解析
- 模型加载与预测功能实现
- Flask API服务构建与优化
- 完整的部署与测试流程
2. 数据集概述
ai53_19/garbage_datasets是一个包含40个细分类别的垃圾分类图像数据集,适用于目标检测任务。
2.1 数据集结构
garbage_datasets/
├── README.md
├── data.yaml # 数据集配置文件
├── dataset_infos.json # 数据集元信息
├── garbage_datasets.json
├── garbage_datasets.py # 模型训练与预测代码
└── datasets/
├── images/ # 图像数据
│ ├── train/ # 训练集图像
│ └── val/ # 验证集图像
├── labels/ # 标注数据
│ ├── train/ # 训练集标注
│ └── val/ # 验证集标注
└── videos/ # 视频数据
└── Cigrette.MP4
2.2 类别体系
数据集包含40个细分类别,分为四大类:
详细类别映射:
| 大类 | 细分类别 | 数量 |
|---|---|---|
| 可回收物 | 充电宝、包、化妆品瓶、玩具、塑料碗、塑料衣架、纸袋、插头电线、旧衣物、易拉罐、枕头、毛绒玩具、洗发水瓶、玻璃杯、鞋子、铁砧、纸板箱、调味品瓶、酒瓶、金属食品罐、锅、食用油桶、饮料瓶 | 24 |
| 有害垃圾 | 干电池、药膏、过期药品 | 3 |
| 厨余垃圾 | 剩饭剩菜、骨头、水果皮、纸浆、茶叶、蔬菜、蛋壳、鱼骨 | 8 |
| 其他垃圾 | 快餐盒、污损塑料、烟头、牙签、花盆、竹筷 | 6 |
3. 环境准备
3.1 安装依赖
# 创建虚拟环境
python -m venv venv
source venv/bin/activate # Linux/Mac
# 或
venv\Scripts\activate # Windows
# 安装依赖
pip install flask ultralytics torch matplotlib pillow numpy
3.2 获取项目
git clone https://gitcode.com/ai53_19/garbage_datasets
cd garbage_datasets
4. 模型集成
4.1 模型加载
garbage_datasets.py中已实现GarbageDetector类,用于模型训练和预测。我们需要对其进行封装,以便在API中使用:
from garbage_datasets import GarbageDetector
import os
import torch
class GarbageModel:
def __init__(self, model_path="best.pt"):
"""初始化模型"""
self.detector = GarbageDetector()
self.load_model(model_path)
# 加载类别名称
self.class_names = self._load_class_names()
# 加载类别映射
self.category_mapping = self._load_category_mapping()
def load_model(self, model_path):
"""加载训练好的模型"""
if os.path.exists(model_path):
self.detector.model = YOLO(model_path)
else:
raise FileNotFoundError(f"模型文件 {model_path} 不存在")
def _load_class_names(self):
"""从dataset_infos.json加载类别名称"""
import json
with open("dataset_infos.json", "r", encoding="utf-8") as f:
data = json.load(f)
return {item["id"]: item["name_cn"] for item in data["categories"]}
def _load_category_mapping(self):
"""从data.yaml加载大类映射"""
import yaml
with open("data.yaml", "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
return data["category_mapping"]
def predict(self, image_path):
"""预测图像中的垃圾类别"""
results = self.detector.model(image_path)
# 处理预测结果
predictions = []
for result in results:
boxes = result.boxes
for box in boxes:
cls = int(box.cls[0])
conf = float(box.conf[0])
# 获取中文类别名
class_name = self.class_names.get(cls, "未知类别")
# 获取大类
category = self._get_category(class_name)
predictions.append({
"class_id": cls,
"class_name": class_name,
"confidence": round(conf, 4),
"category": category,
"bbox": box.xyxy.tolist()[0] # 边界框坐标
})
return predictions
def _get_category(self, class_name):
"""根据细分类别获取大类"""
for category, subclasses in self.category_mapping.items():
if class_name in subclasses:
return category
return "其他垃圾"
4.2 模型训练(可选)
如果没有预训练模型,可以使用以下代码进行训练:
# train_model.py
from garbage_datasets import GarbageDetector
if __name__ == "__main__":
detector = GarbageDetector()
# 训练新模型
detector.train("data.yaml")
# 或继续训练
# detector.train("data.yaml", weights_path="runs/detect/train/weights/best.pt")
训练参数说明:
{
"epochs": 100, # 训练轮数
"imgsz": 1024, # 图片尺寸
"batch": 32, # 批次大小
"workers": 8, # 数据加载器的工作进程数
"device": '0' if torch.cuda.is_available() else 'cpu', # 设备选择
"optimizer": 'AdamW', # 优化器
"lr0": 0.001, # 初始学习率
"weight_decay": 0.0005 # 权重衰减
}
5. Flask API开发
5.1 项目结构
garbage_api/
├── app/
│ ├── __init__.py # Flask应用初始化
│ ├── models/ # 模型相关代码
│ │ └── garbage_model.py # 模型封装类
│ ├── routes/ # API路由
│ │ └── predict.py # 预测接口
│ ├── utils/ # 工具函数
│ │ ├── image_processing.py # 图像处理工具
│ │ └── response.py # 响应格式化工具
│ └── config.py # 配置文件
├── uploads/ # 上传图片存储
├── best.pt # 训练好的模型
├── run.py # 应用入口
└── requirements.txt # 依赖列表
5.2 初始化Flask应用
# app/__init__.py
from flask import Flask
from app.routes.predict import predict_bp
import os
def create_app():
app = Flask(__name__)
# 配置上传文件夹
app.config['UPLOAD_FOLDER'] = os.path.join(os.path.dirname(__file__), '../uploads')
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB
# 创建上传文件夹
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# 注册蓝图
app.register_blueprint(predict_bp, url_prefix='/api')
return app
5.3 预测接口实现
# app/routes/predict.py
from flask import Blueprint, request, jsonify
from app.models.garbage_model import GarbageModel
from app.utils.response import api_response
import os
import uuid
predict_bp = Blueprint('predict', __name__)
# 初始化模型(全局单例)
model = GarbageModel()
@predict_bp.route('/predict', methods=['POST'])
def predict():
"""垃圾图像预测接口"""
if 'image' not in request.files:
return api_response(code=400, message="未提供图像文件")
file = request.files['image']
if file.filename == '':
return api_response(code=400, message="未选择文件")
# 保存上传的文件
filename = str(uuid.uuid4()) + os.path.splitext(file.filename)[1]
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
file.save(file_path)
try:
# 进行预测
results = model.predict(file_path)
return api_response(
code=200,
message="预测成功",
data={
"predictions": results,
"image_filename": filename
}
)
except Exception as e:
return api_response(code=500, message=f"预测失败: {str(e)}")
finally:
# 可选:删除临时文件
# os.remove(file_path)
@predict_bp.route('/health', methods=['GET'])
def health_check():
"""健康检查接口"""
return api_response(code=200, message="API服务正常运行")
5.4 响应工具函数
# app/utils/response.py
def api_response(code=200, message="success", data=None):
"""API响应格式化"""
response = {
"code": code,
"message": message
}
if data is not None:
response["data"] = data
return jsonify(response)
5.5 应用入口
# run.py
from app import create_app
app = create_app()
if __name__ == "__main__":
app.run(
host='0.0.0.0',
port=5000,
debug=False,
threaded=True # 启用多线程处理请求
)
6. API文档与测试
6.1 API接口说明
| 接口 | 方法 | 描述 | 请求参数 | 响应 |
|---|---|---|---|---|
/api/health | GET | 健康检查 | 无 | {"code":200,"message":"API服务正常运行"} |
/api/predict | POST | 垃圾图像预测 | image: 图像文件 | 预测结果JSON |
6.2 使用curl测试
# 健康检查
curl http://localhost:5000/api/health
# 图像预测
curl -X POST -F "image=@test.jpg" http://localhost:5000/api/predict
6.3 预测响应示例
{
"code": 200,
"message": "预测成功",
"data": {
"predictions": [
{
"class_id": 36,
"class_name": "饮料瓶",
"confidence": 0.9876,
"category": "Recyclables",
"bbox": [120.5, 80.3, 350.2, 420.8]
}
],
"image_filename": "a1b2c3d4-5678-90ef-ghij-klmnopqrstuv.jpg"
}
}
7. 性能优化与部署
7.1 性能优化策略
7.2 Docker部署
创建Dockerfile:
FROM python:3.9-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
# 创建上传目录
RUN mkdir -p uploads
EXPOSE 5000
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "run:app"]
构建和运行:
# 构建镜像
docker build -t garbage-api .
# 运行容器
docker run -d -p 5000:5000 --name garbage-api-container garbage-api
7.3 生产环境配置
使用Gunicorn作为WSGI服务器:
gunicorn --bind 0.0.0.0:5000 --workers 4 --threads 2 run:app
参数说明:
--workers: 工作进程数,建议设置为CPU核心数+1--threads: 每个工作进程的线程数--timeout: 超时时间,单位秒
8. 总结与扩展
8.1 项目总结
本文基于ai53_19/garbage_datasets数据集和Flask框架,构建了一个功能完善的垃圾分类API服务。主要实现了:
- 数据集解析与类别体系构建
- YOLO模型封装与预测功能实现
- RESTful API设计与实现
- 性能优化与部署指南
8.2 扩展方向
- 前端界面:开发Web界面或移动应用,提供直观的垃圾分类体验
- 批量处理:支持多图像批量预测和视频流实时处理
- 模型更新:实现模型自动更新机制,定期重新训练
- 日志监控:添加详细日志和性能监控
- 多语言支持:支持中英文等多语言类别名称
通过本文的指南,你可以快速搭建一个高效、可靠的垃圾分类API服务,为垃圾分类应用提供强大的后端支持。
9. 附录
9.1 完整代码结构
garbage_datasets/
├── app/
│ ├── __init__.py
│ ├── models/
│ │ └── garbage_model.py
│ ├── routes/
│ │ └── predict.py
│ ├── utils/
│ │ └── response.py
│ └── config.py
├── uploads/
├── best.pt
├── data.yaml
├── dataset_infos.json
├── garbage_datasets.py
├── requirements.txt
├── run.py
└── Dockerfile
9.2 requirements.txt
flask==2.3.3
ultralytics==8.0.200
torch==2.0.1
matplotlib==3.7.2
pillow==10.0.0
numpy==1.25.2
pyyaml==6.0.1
gunicorn==21.2.0
【免费下载链接】垃圾分类数据集 项目地址: https://ai.gitcode.com/ai53_19/garbage_datasets
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



