基于深度学习的农业病虫害识别系统 -完整代码可直接运行

视频讲解:

项目演示:

主要代码:

import os
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from flask import Flask, render_template, request, redirect, url_for, session, flash
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
import tensorflow_hub as hub
import hashlib

# 初始化Flask应用
app = Flask(__name__)
app.config['SECRET_KEY'] = 'a_random_secret_key_for_session_management'
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///agri_disease.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB

# 确保上传文件夹存在
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# 初始化数据库
db = SQLAlchemy(app)

# 模型路径
MODEL_PATH = 'plant_disease_model.h5'
MODEL_LOADED = False
model = None


# 加载模型函数 - 使用用户提供的代码
def load_model():
    global model, MODEL_LOADED
    try:
        # 尝试加载本地模型
        model = tf.keras.models.load_model(MODEL_PATH)
        MODEL_LOADED = True
        print("成功加载本地模型")
    except Exception as e:
        print(f"本地模型加载失败: {e}")
        print("从TensorFlow Hub下载并创建新模型...")
        try:
            # 从TensorFlow Hub加载模型并构建
            hub_url = "https://tfhub.dev/google/efficientnet/b3/classification/1"
            model = tf.keras.Sequential([
                tf.keras.layers.InputLayer(input_shape=(300, 300, 3)),
                hub.KerasLayer(hub_url),
                tf.keras.layers.Dense(38, activation='softmax')
            ])
            # 保存模型供下次使用
            model.save(MODEL_PATH)
            MODEL_LOADED = True
            print("新模型创建并保存成功")
        except Exception as e:
            print(f"模型创建失败: {e}")
            MODEL_LOADED = False
    return model


# 加载模型
model = load_model()

# 疾病类别 - 注意:这里需要与您的38个类别对应,请根据实际情况修改
# 38种常见农业病虫害类别(已补充完整)
DISEASE_CLASSES = [
    '健康叶片',
    # 真菌性病害
    '小麦锈病', '小麦白粉病', '玉米大斑病', '玉米小斑病',
    '水稻稻瘟病', '水稻纹枯病', '黄瓜霜霉病', '黄瓜白粉病',
    '番茄晚疫病', '番茄早疫病', '番茄叶霉病', '马铃薯晚疫病',
    '葡萄霜霉病', '苹果腐烂病', '苹果斑点落叶病', '草莓灰霉病',
    # 细菌性病害
    '黄瓜细菌性角斑病', '番茄青枯病', '白菜软腐病', '水稻白叶枯病',
    '马铃薯环腐病', '柑橘溃疡病', '桃树细菌性穿孔病',
    # 病毒性病害
    '烟草花叶病毒病', '番茄黄化曲叶病毒病', '黄瓜花叶病毒病',
    '马铃薯Y病毒病', '水稻条纹叶枯病', '玉米矮花叶病毒病',
    # 其他病害及虫害
    '小麦赤霉病', '油菜菌核病', '棉花枯萎病', '大豆根腐病',
    '花生叶斑病', '辣椒炭疽病', '茄子褐纹病', '西瓜炭疽病',
    '茶树茶饼病', '棉花黄萎病'
]


# 用户模型
class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    username = db.Column(db.String(50), unique=True, nullable=False)
    password = db.Column(db.String(100), nullable=False)
    created_at = db.Column(db.DateTime, default=datetime.utcnow)
    history = db.relationship('DetectionHistory', backref='user', lazy=True)


# 检测历史模型
class DetectionHistory(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
    image_path = db.Column(db.String(200), nullable=False)
    result = db.Column(db.String(100), nullable=False)
    confidence = db.Column(db.Float, nullable=False)
    detected_at = db.Column(db.DateTime, default=datetime.utcnow)


# 创建数据库表
with app.app_context():
    db.create_all()


# 密码加密函数
def hash_password(password):
    return hashlib.sha256(password.encode()).hexdigest()


# 图像预处理函数 - 调整为300x300以匹配模型输入
def preprocess_image(image_path):
    try:
        img = cv2.imread(image_path)
        img = cv2.resize(img, (300, 300))  # 与模型输入形状匹配
        img = img / 255.0  # 归一化
        img = np.expand_dims(img, axis=0)  # 添加批次维度
        return img
    except Exception as e:
        print(f"图像预处理错误: {e}")
        return None


# 预测函数
def predict_disease(image_path):
    if not MODEL_LOADED or model is None:
        return None, "模型加载失败,无法进行预测"

    preprocessed_img = preprocess_image(image_path)
    if preprocessed_img is None:
        return None, "图像处理失败"

    try:
        predictions = model.predict(preprocessed_img)
        predicted_class = np.argmax(predictions[0])
        confidence = float(predictions[0][predicted_class]) * 100

        # 检查预测类别是否在已知类别范围内
        if predicted_class < len(DISEASE_CLASSES):
            disease_name = DISEASE_CLASSES[predicted_class]
        else:
            disease_name = f"未知类别 ({predicted_class})"

        return {
            'disease': disease_name,
            'confidence': round(confidence, 2)
        }, None
    except Exception as e:
        return None, f"预测过程出错: {str(e)}"


# 路由
@app.route('/')
def home():
    if 'user_id' in session:
        return redirect(url_for('dashboard'))
    return redirect(url_for('login'))


@app.route('/login', methods=['GET', 'POST'])
def login():
    if request.method == 'POST':
        username = request.form['username']
        password = hash_password(request.form['password'])

        user = User.query.filter_by(username=username, password=password).first()

        if user:
            session['user_id'] = user.id
            session['username'] = user.username
            flash('登录成功!', 'success')
            return redirect(url_for('dashboard'))
        else:
            flash('用户名或密码不正确', 'danger')

    return render_template('login.html')


@app.route('/register', methods=['GET', 'POST'])
def register():
    if request.method == 'POST':
        username = request.form['username']
        password = request.form['password']
        confirm_password = request.form['confirm_password']

        if password != confirm_password:
            flash('两次密码输入不一致', 'danger')
            return redirect(url_for('register'))

        if User.query.filter_by(username=username).first():
            flash('用户名已存在', 'danger')
            return redirect(url_for('register'))

        new_user = User(
            username=username,
            password=hash_password(password)
        )

        try:
            db.session.add(new_user)
            db.session.commit()
            flash('注册成功,请登录', 'success')
            return redirect(url_for('login'))
        except Exception as e:
            db.session.rollback()
            flash(f'注册失败: {str(e)}', 'danger')

    return render_template('register.html')


@app.route('/dashboard')
def dashboard():
    if 'user_id' not in session:
        flash('请先登录', 'warning')
        return redirect(url_for('login'))

    return render_template('dashboard.html', username=session['username'], model_loaded=MODEL_LOADED)


@app.route('/detect', methods=['POST'])
def detect():
    if 'user_id' not in session:
        flash('请先登录', 'warning')
        return redirect(url_for('login'))

    if 'image' not in request.files:
        flash('未选择图像', 'danger')
        return redirect(url_for('dashboard'))

    image = request.files['image']

    if image.filename == '':
        flash('未选择图像', 'danger')
        return redirect(url_for('dashboard'))

    if image and image.filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
        timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
        filename = f"{session['user_id']}_{timestamp}_{image.filename}"
        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        image.save(image_path)

        result, error = predict_disease(image_path)

        if error:
            flash(error, 'danger')
            return redirect(url_for('dashboard'))

        new_history = DetectionHistory(
            user_id=session['user_id'],
            image_path=filename,
            result=result['disease'],
            confidence=result['confidence']
        )

        try:
            db.session.add(new_history)
            db.session.commit()
            flash('检测完成!', 'success')
            return render_template('result.html',
                                   result=result,
                                   image_path=os.path.join('uploads', filename))
        except Exception as e:
            db.session.rollback()
            flash(f'保存记录失败: {str(e)}', 'danger')
            return redirect(url_for('dashboard'))
    else:
        flash('请上传有效的图像文件 (png, jpg, jpeg, bmp)', 'danger')
        return redirect(url_for('dashboard'))


@app.route('/history')
def history():
    if 'user_id' not in session:
        flash('请先登录', 'warning')
        return redirect(url_for('login'))

    history_records = DetectionHistory.query.filter_by(
        user_id=session['user_id']
    ).order_by(
        DetectionHistory.detected_at.desc()
    ).all()

    return render_template('history.html', history=history_records)


@app.route('/logout')
def logout():
    session.pop('user_id', None)
    session.pop('username', None)
    flash('已成功退出登录', 'success')
    return redirect(url_for('login'))


# 错误处理
@app.errorhandler(404)
def page_not_found(e):
    return render_template('error.html', error=404, message='页面未找到'), 404


@app.errorhandler(500)
def internal_server_error(e):
    return render_template('error.html', error=500, message='服务器内部错误'), 500


if __name__ == '__main__':
    app.run(debug=True)

完整代码下载:

https://download.youkuaiyun.com/download/qq_38735017/91669433

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

程序员奇奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值