视频讲解:
项目演示:




主要代码:
import os
import numpy as np
import cv2
import tensorflow as tf
from PIL import Image
from flask import Flask, render_template, request, redirect, url_for, session, flash
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
import tensorflow_hub as hub
import hashlib
# 初始化Flask应用
app = Flask(__name__)
app.config['SECRET_KEY'] = 'a_random_secret_key_for_session_management'
app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///agri_disease.db'
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB
# 确保上传文件夹存在
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
# 初始化数据库
db = SQLAlchemy(app)
# 模型路径
MODEL_PATH = 'plant_disease_model.h5'
MODEL_LOADED = False
model = None
# 加载模型函数 - 使用用户提供的代码
def load_model():
global model, MODEL_LOADED
try:
# 尝试加载本地模型
model = tf.keras.models.load_model(MODEL_PATH)
MODEL_LOADED = True
print("成功加载本地模型")
except Exception as e:
print(f"本地模型加载失败: {e}")
print("从TensorFlow Hub下载并创建新模型...")
try:
# 从TensorFlow Hub加载模型并构建
hub_url = "https://tfhub.dev/google/efficientnet/b3/classification/1"
model = tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(300, 300, 3)),
hub.KerasLayer(hub_url),
tf.keras.layers.Dense(38, activation='softmax')
])
# 保存模型供下次使用
model.save(MODEL_PATH)
MODEL_LOADED = True
print("新模型创建并保存成功")
except Exception as e:
print(f"模型创建失败: {e}")
MODEL_LOADED = False
return model
# 加载模型
model = load_model()
# 疾病类别 - 注意:这里需要与您的38个类别对应,请根据实际情况修改
# 38种常见农业病虫害类别(已补充完整)
DISEASE_CLASSES = [
'健康叶片',
# 真菌性病害
'小麦锈病', '小麦白粉病', '玉米大斑病', '玉米小斑病',
'水稻稻瘟病', '水稻纹枯病', '黄瓜霜霉病', '黄瓜白粉病',
'番茄晚疫病', '番茄早疫病', '番茄叶霉病', '马铃薯晚疫病',
'葡萄霜霉病', '苹果腐烂病', '苹果斑点落叶病', '草莓灰霉病',
# 细菌性病害
'黄瓜细菌性角斑病', '番茄青枯病', '白菜软腐病', '水稻白叶枯病',
'马铃薯环腐病', '柑橘溃疡病', '桃树细菌性穿孔病',
# 病毒性病害
'烟草花叶病毒病', '番茄黄化曲叶病毒病', '黄瓜花叶病毒病',
'马铃薯Y病毒病', '水稻条纹叶枯病', '玉米矮花叶病毒病',
# 其他病害及虫害
'小麦赤霉病', '油菜菌核病', '棉花枯萎病', '大豆根腐病',
'花生叶斑病', '辣椒炭疽病', '茄子褐纹病', '西瓜炭疽病',
'茶树茶饼病', '棉花黄萎病'
]
# 用户模型
class User(db.Model):
id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(50), unique=True, nullable=False)
password = db.Column(db.String(100), nullable=False)
created_at = db.Column(db.DateTime, default=datetime.utcnow)
history = db.relationship('DetectionHistory', backref='user', lazy=True)
# 检测历史模型
class DetectionHistory(db.Model):
id = db.Column(db.Integer, primary_key=True)
user_id = db.Column(db.Integer, db.ForeignKey('user.id'), nullable=False)
image_path = db.Column(db.String(200), nullable=False)
result = db.Column(db.String(100), nullable=False)
confidence = db.Column(db.Float, nullable=False)
detected_at = db.Column(db.DateTime, default=datetime.utcnow)
# 创建数据库表
with app.app_context():
db.create_all()
# 密码加密函数
def hash_password(password):
return hashlib.sha256(password.encode()).hexdigest()
# 图像预处理函数 - 调整为300x300以匹配模型输入
def preprocess_image(image_path):
try:
img = cv2.imread(image_path)
img = cv2.resize(img, (300, 300)) # 与模型输入形状匹配
img = img / 255.0 # 归一化
img = np.expand_dims(img, axis=0) # 添加批次维度
return img
except Exception as e:
print(f"图像预处理错误: {e}")
return None
# 预测函数
def predict_disease(image_path):
if not MODEL_LOADED or model is None:
return None, "模型加载失败,无法进行预测"
preprocessed_img = preprocess_image(image_path)
if preprocessed_img is None:
return None, "图像处理失败"
try:
predictions = model.predict(preprocessed_img)
predicted_class = np.argmax(predictions[0])
confidence = float(predictions[0][predicted_class]) * 100
# 检查预测类别是否在已知类别范围内
if predicted_class < len(DISEASE_CLASSES):
disease_name = DISEASE_CLASSES[predicted_class]
else:
disease_name = f"未知类别 ({predicted_class})"
return {
'disease': disease_name,
'confidence': round(confidence, 2)
}, None
except Exception as e:
return None, f"预测过程出错: {str(e)}"
# 路由
@app.route('/')
def home():
if 'user_id' in session:
return redirect(url_for('dashboard'))
return redirect(url_for('login'))
@app.route('/login', methods=['GET', 'POST'])
def login():
if request.method == 'POST':
username = request.form['username']
password = hash_password(request.form['password'])
user = User.query.filter_by(username=username, password=password).first()
if user:
session['user_id'] = user.id
session['username'] = user.username
flash('登录成功!', 'success')
return redirect(url_for('dashboard'))
else:
flash('用户名或密码不正确', 'danger')
return render_template('login.html')
@app.route('/register', methods=['GET', 'POST'])
def register():
if request.method == 'POST':
username = request.form['username']
password = request.form['password']
confirm_password = request.form['confirm_password']
if password != confirm_password:
flash('两次密码输入不一致', 'danger')
return redirect(url_for('register'))
if User.query.filter_by(username=username).first():
flash('用户名已存在', 'danger')
return redirect(url_for('register'))
new_user = User(
username=username,
password=hash_password(password)
)
try:
db.session.add(new_user)
db.session.commit()
flash('注册成功,请登录', 'success')
return redirect(url_for('login'))
except Exception as e:
db.session.rollback()
flash(f'注册失败: {str(e)}', 'danger')
return render_template('register.html')
@app.route('/dashboard')
def dashboard():
if 'user_id' not in session:
flash('请先登录', 'warning')
return redirect(url_for('login'))
return render_template('dashboard.html', username=session['username'], model_loaded=MODEL_LOADED)
@app.route('/detect', methods=['POST'])
def detect():
if 'user_id' not in session:
flash('请先登录', 'warning')
return redirect(url_for('login'))
if 'image' not in request.files:
flash('未选择图像', 'danger')
return redirect(url_for('dashboard'))
image = request.files['image']
if image.filename == '':
flash('未选择图像', 'danger')
return redirect(url_for('dashboard'))
if image and image.filename.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp')):
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
filename = f"{session['user_id']}_{timestamp}_{image.filename}"
image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
image.save(image_path)
result, error = predict_disease(image_path)
if error:
flash(error, 'danger')
return redirect(url_for('dashboard'))
new_history = DetectionHistory(
user_id=session['user_id'],
image_path=filename,
result=result['disease'],
confidence=result['confidence']
)
try:
db.session.add(new_history)
db.session.commit()
flash('检测完成!', 'success')
return render_template('result.html',
result=result,
image_path=os.path.join('uploads', filename))
except Exception as e:
db.session.rollback()
flash(f'保存记录失败: {str(e)}', 'danger')
return redirect(url_for('dashboard'))
else:
flash('请上传有效的图像文件 (png, jpg, jpeg, bmp)', 'danger')
return redirect(url_for('dashboard'))
@app.route('/history')
def history():
if 'user_id' not in session:
flash('请先登录', 'warning')
return redirect(url_for('login'))
history_records = DetectionHistory.query.filter_by(
user_id=session['user_id']
).order_by(
DetectionHistory.detected_at.desc()
).all()
return render_template('history.html', history=history_records)
@app.route('/logout')
def logout():
session.pop('user_id', None)
session.pop('username', None)
flash('已成功退出登录', 'success')
return redirect(url_for('login'))
# 错误处理
@app.errorhandler(404)
def page_not_found(e):
return render_template('error.html', error=404, message='页面未找到'), 404
@app.errorhandler(500)
def internal_server_error(e):
return render_template('error.html', error=500, message='服务器内部错误'), 500
if __name__ == '__main__':
app.run(debug=True)
完整代码下载:
https://download.youkuaiyun.com/download/qq_38735017/91669433
774

被折叠的 条评论
为什么被折叠?



