Python项目:智能医疗辅助诊断系统开发(2)

源代码续

"""
医疗案例数据访问对象 - 提供医疗案例和相关数据的数据库操作
"""

import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from .base_dao import BaseDAO

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class MedicalCaseDAO(BaseDAO):
    """Data access object for medical cases (table ``medical_cases``).

    Builds on the generic ``BaseDAO`` helpers (``find_by_criteria``,
    ``update``) and adds case-specific lookups, joined detail/search queries
    and update helpers that also stamp ``updated_at``.
    """

    # strftime format used for the ``updated_at`` value written by the
    # update helpers below.
    _TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

    def __init__(self):
        """Initialize the DAO for table ``medical_cases`` keyed by ``case_id``."""
        super().__init__('medical_cases', 'case_id')

    @classmethod
    def _now(cls) -> str:
        """Return the current local time formatted for the ``updated_at`` column."""
        return datetime.now().strftime(cls._TIMESTAMP_FORMAT)

    def _find_newest_first(self, criteria: Dict[str, Any], limit: int,
                           offset: int) -> List[Dict[str, Any]]:
        """Run ``find_by_criteria`` ordered by ``created_at`` descending (newest first)."""
        return self.find_by_criteria(criteria, limit=limit, offset=offset,
                                     order_by='created_at', order_direction='DESC')

    def find_cases_by_patient(self, patient_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return a patient's medical cases, newest first.

        Args:
            patient_id: Patient ID.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching cases.
        """
        return self._find_newest_first({'patient_id': patient_id}, limit, offset)

    def find_cases_by_doctor(self, doctor_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return the cases whose primary doctor is ``doctor_id``, newest first.

        Args:
            doctor_id: Doctor ID (matched against ``primary_doctor_id``).
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching cases.
        """
        return self._find_newest_first({'primary_doctor_id': doctor_id}, limit, offset)

    def find_cases_by_status(self, status: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return cases with the given status, newest first.

        Args:
            status: Case status value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching cases.
        """
        return self._find_newest_first({'status': status}, limit, offset)

    def find_cases_by_priority(self, priority: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return cases with the given priority, newest first.

        Args:
            priority: Case priority value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching cases.
        """
        return self._find_newest_first({'priority': priority}, limit, offset)

    def find_case_with_details(self, case_id: int) -> Optional[Dict[str, Any]]:
        """Return one case joined with its patient and (optional) doctor info.

        The doctor tables are LEFT JOINed so that a case with no primary
        doctor assigned yet is still returned; its doctor columns are None.

        Args:
            case_id: Case ID.

        Returns:
            Optional[Dict[str, Any]]: The case row, or None if not found or
            if the query fails (the error is logged).
        """
        query = """
        SELECT mc.*, 
               p.medical_record_number, 
               pu.first_name as patient_first_name, 
               pu.last_name as patient_last_name,
               d.specialization, 
               du.first_name as doctor_first_name, 
               du.last_name as doctor_last_name
        FROM medical_cases mc
        JOIN patients p ON mc.patient_id = p.patient_id
        JOIN users pu ON p.user_id = pu.user_id
        LEFT JOIN doctors d ON mc.primary_doctor_id = d.doctor_id
        LEFT JOIN users du ON d.user_id = du.user_id
        WHERE mc.case_id = %s
        """

        try:
            return self.db.execute_one(query, (case_id,))
        except Exception as e:
            logger.error(f"Error finding case with details: {e}")
            return None

    def find_recent_cases(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the most recently created cases.

        Ordered explicitly by ``created_at`` descending; an unordered
        ``find_all`` does not guarantee which rows come back.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            List[Dict[str, Any]]: Cases, newest first; empty list on error
            (the error is logged).
        """
        query = """
        SELECT * FROM medical_cases
        ORDER BY created_at DESC
        LIMIT %s
        """

        try:
            return self.db.execute_query(query, (limit,))
        except Exception as e:
            logger.error(f"Error finding recent cases: {e}")
            return []

    def update_case_status(self, case_id: int, status: str) -> bool:
        """Set a case's status and refresh ``updated_at``.

        Args:
            case_id: Case ID.
            status: New status.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(case_id, {'status': status, 'updated_at': self._now()})

    def update_case_priority(self, case_id: int, priority: str) -> bool:
        """Set a case's priority and refresh ``updated_at``.

        Args:
            case_id: Case ID.
            priority: New priority.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(case_id, {'priority': priority, 'updated_at': self._now()})

    def assign_doctor(self, case_id: int, doctor_id: int) -> bool:
        """Assign a primary doctor to a case and refresh ``updated_at``.

        Args:
            case_id: Case ID.
            doctor_id: Doctor ID stored in ``primary_doctor_id``.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(case_id, {'primary_doctor_id': doctor_id, 'updated_at': self._now()})

    def search_cases(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """Substring-search cases by title, description, patient or doctor name.

        Cases without an assigned doctor are included (LEFT JOIN). Results
        are ordered newest first so that LIMIT/OFFSET pages are stable.

        Args:
            search_term: Text matched with ``LIKE '%term%'``.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching cases; empty list on error
            (the error is logged).
        """
        query = """
        SELECT mc.*, 
               pu.first_name as patient_first_name, 
               pu.last_name as patient_last_name,
               du.first_name as doctor_first_name, 
               du.last_name as doctor_last_name
        FROM medical_cases mc
        JOIN patients p ON mc.patient_id = p.patient_id
        JOIN users pu ON p.user_id = pu.user_id
        LEFT JOIN doctors d ON mc.primary_doctor_id = d.doctor_id
        LEFT JOIN users du ON d.user_id = du.user_id
        WHERE mc.title LIKE %s OR mc.description LIKE %s
        OR pu.first_name LIKE %s OR pu.last_name LIKE %s
        OR du.first_name LIKE %s OR du.last_name LIKE %s
        ORDER BY mc.created_at DESC
        LIMIT %s OFFSET %s
        """

        search_pattern = f"%{search_term}%"
        # One pattern per LIKE placeholder, then the paging parameters.
        params = (search_pattern,) * 6 + (limit, offset)

        try:
            return self.db.execute_query(query, params)
        except Exception as e:
            logger.error(f"Error searching cases: {e}")
            return []

    def count_cases_by_status(self) -> Dict[str, int]:
        """Count cases per status.

        Returns:
            Dict[str, int]: Mapping of status to case count; empty dict on
            error (the error is logged).
        """
        query = """
        SELECT status, COUNT(*) as count
        FROM medical_cases
        GROUP BY status
        """

        try:
            results = self.db.execute_query(query)
            return {result['status']: result['count'] for result in results}
        except Exception as e:
            logger.error(f"Error counting cases by status: {e}")
            return {}


class MedicalImageDAO(BaseDAO):
    """Data access object for medical images (table ``medical_images``).

    Offers per-case and per-type image listings (most recent upload first),
    a join with stored image analysis results, and a per-type image count.
    """

    def __init__(self):
        """Bind the DAO to table ``medical_images`` with primary key ``image_id``."""
        super().__init__('medical_images', 'image_id')

    def _newest_uploads(self, criteria: Dict[str, Any], limit: int, offset: int) -> List[Dict[str, Any]]:
        """Filter via ``find_by_criteria``, ordered by ``upload_date`` descending."""
        return self.find_by_criteria(
            criteria,
            limit=limit,
            offset=offset,
            order_by='upload_date',
            order_direction='DESC',
        )

    def find_images_by_case(self, case_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """List the images attached to a case, most recent upload first.

        Args:
            case_id: Case ID.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Image rows.
        """
        return self._newest_uploads({'case_id': case_id}, limit, offset)

    def find_images_by_type(self, image_type: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """List images of one type, most recent upload first.

        Args:
            image_type: Image type value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Image rows.
        """
        return self._newest_uploads({'image_type': image_type}, limit, offset)

    def find_images_with_analysis(self, case_id: int) -> List[Dict[str, Any]]:
        """List a case's images together with any analysis results.

        Images without an analysis still appear (LEFT JOIN) with the
        analysis columns set to None.

        Args:
            case_id: Case ID.

        Returns:
            List[Dict[str, Any]]: Image rows with analysis columns; empty
            list if the query fails (the error is logged).
        """
        query = """
        SELECT mi.*, iar.analysis_id, iar.analysis_date, iar.findings, iar.confidence_score
        FROM medical_images mi
        LEFT JOIN image_analysis_results iar ON mi.image_id = iar.image_id
        WHERE mi.case_id = %s
        ORDER BY mi.upload_date DESC
        """

        try:
            rows = self.db.execute_query(query, (case_id,))
        except Exception as exc:
            logger.error(f"Error finding images with analysis: {exc}")
            return []
        return rows

    def count_images_by_type(self) -> Dict[str, int]:
        """Count images per ``image_type``.

        Returns:
            Dict[str, int]: Mapping of image type to count; empty dict if
            the query fails (the error is logged).
        """
        query = """
        SELECT image_type, COUNT(*) as count
        FROM medical_images
        GROUP BY image_type
        """

        try:
            counts: Dict[str, int] = {}
            for row in self.db.execute_query(query):
                counts[row['image_type']] = row['count']
            return counts
        except Exception as exc:
            logger.error(f"Error counting images by type: {exc}")
            return {}


class ImageAnalysisResultDAO(BaseDAO):
    """Data access object for image analysis results (table ``image_analysis_results``)."""

    def __init__(self):
        """Initialize the DAO for ``image_analysis_results`` keyed by ``analysis_id``."""
        super().__init__('image_analysis_results', 'analysis_id')

    def find_by_image(self, image_id: int) -> Optional[Dict[str, Any]]:
        """Return the analysis result for an image.

        Issues a single query (the original ran the same query twice). When an
        image has several analyses, the most recent ``analysis_date`` wins
        instead of an arbitrary row.

        Args:
            image_id: Image ID.

        Returns:
            Optional[Dict[str, Any]]: The analysis row, or None if none exists.
        """
        rows = self.find_by_criteria({'image_id': image_id}, limit=1, offset=0,
                                     order_by='analysis_date', order_direction='DESC')
        return rows[0] if rows else None

    def find_by_model(self, model_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return results produced by one AI model, newest first.

        Args:
            model_id: AI model ID.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Analysis rows.
        """
        return self.find_by_criteria({'model_id': model_id}, limit=limit, offset=offset,
                                     order_by='analysis_date', order_direction='DESC')

    def find_by_confidence_range(self, min_confidence: float, max_confidence: float,
                                 limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return results whose ``confidence_score`` lies in an inclusive range.

        Reversed bounds are swapped. ``BETWEEN hi AND lo`` would otherwise
        silently match nothing.

        Args:
            min_confidence: Lower bound (inclusive).
            max_confidence: Upper bound (inclusive).
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Analysis rows, newest first; empty list on
            error (the error is logged).
        """
        if min_confidence > max_confidence:
            min_confidence, max_confidence = max_confidence, min_confidence

        query = """
        SELECT * FROM image_analysis_results
        WHERE confidence_score BETWEEN %s AND %s
        ORDER BY analysis_date DESC
        LIMIT %s OFFSET %s
        """

        try:
            return self.db.execute_query(query, (min_confidence, max_confidence, limit, offset))
        except Exception as e:
            logger.error(f"Error finding analysis by confidence range: {e}")
            return []

    def find_with_image_info(self, analysis_id: int) -> Optional[Dict[str, Any]]:
        """Return one analysis result joined with its image and model info.

        Args:
            analysis_id: Analysis result ID.

        Returns:
            Optional[Dict[str, Any]]: The joined row, or None if not found or
            on error (the error is logged).
        """
        query = """
        SELECT iar.*, mi.image_path, mi.image_type, mi.upload_date, 
               mi.case_id, am.model_name, am.version
        FROM image_analysis_results iar
        JOIN medical_images mi ON iar.image_id = mi.image_id
        JOIN ai_models am ON iar.model_id = am.model_id
        WHERE iar.analysis_id = %s
        """

        try:
            return self.db.execute_one(query, (analysis_id,))
        except Exception as e:
            logger.error(f"Error finding analysis with image info: {e}")
            return None

    def find_latest_analyses(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the most recent analysis results with image and model info.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            List[Dict[str, Any]]: Joined rows, newest first; empty list on
            error (the error is logged).
        """
        query = """
        SELECT iar.*, mi.image_path, mi.image_type, am.model_name
        FROM image_analysis_results iar
        JOIN medical_images mi ON iar.image_id = mi.image_id
        JOIN ai_models am ON iar.model_id = am.model_id
        ORDER BY iar.analysis_date DESC
        LIMIT %s
        """

        try:
            return self.db.execute_query(query, (limit,))
        except Exception as e:
            logger.error(f"Error finding latest analyses: {e}")
            return []

    def get_average_confidence_by_model(self) -> Dict[str, float]:
        """Return the average ``confidence_score`` per AI model name.

        Results are grouped by ``model_name`` only, so different versions of
        the same model are pooled. A model whose scores are all NULL produces
        a NULL average. Such a model is skipped; previously ``float(None)``
        raised and discarded every other model's average as well.

        Returns:
            Dict[str, float]: Mapping of model name to average confidence;
            empty dict on error (the error is logged).
        """
        query = """
        SELECT am.model_name, AVG(iar.confidence_score) as avg_confidence
        FROM image_analysis_results iar
        JOIN ai_models am ON iar.model_id = am.model_id
        GROUP BY am.model_name
        """

        try:
            results = self.db.execute_query(query)
        except Exception as e:
            logger.error(f"Error getting average confidence by model: {e}")
            return {}

        averages: Dict[str, float] = {}
        for result in results:
            avg_confidence = result['avg_confidence']
            if avg_confidence is None:
                continue
            averages[result['model_name']] = float(avg_confidence)
        return averages

modules\data_management\user_dao.py

"""
用户数据访问对象 - 提供用户、医生和患者相关的数据库操作
"""

import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from .base_dao import BaseDAO

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class UserDAO(BaseDAO):
    """Data access object for user accounts (table ``users``)."""

    def __init__(self):
        """Initialize the DAO for table ``users`` keyed by ``user_id``."""
        super().__init__('users', 'user_id')

    def _find_first(self, criteria: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Return the first row matching ``criteria`` using one query, or None."""
        rows = self.find_by_criteria(criteria, limit=1, offset=0)
        return rows[0] if rows else None

    def find_by_username(self, username: str) -> Optional[Dict[str, Any]]:
        """Look up a user by username.

        Args:
            username: Username.

        Returns:
            Optional[Dict[str, Any]]: The user row, or None if not found.
        """
        return self._find_first({'username': username})

    def find_by_email(self, email: str) -> Optional[Dict[str, Any]]:
        """Look up a user by e-mail address.

        Args:
            email: E-mail address.

        Returns:
            Optional[Dict[str, Any]]: The user row, or None if not found.
        """
        return self._find_first({'email': email})

    def find_by_role(self, role: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return users with the given role.

        Args:
            role: Role value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: User rows.
        """
        return self.find_by_criteria({'role': role}, limit=limit, offset=offset)

    def update_last_login(self, user_id: int) -> bool:
        """Set ``last_login`` to the current local time.

        Args:
            user_id: User ID.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(user_id, {'last_login': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})

    def change_password(self, user_id: int, hashed_password: str) -> bool:
        """Store a new (already hashed) password.

        Args:
            user_id: User ID.
            hashed_password: Password hash; this method does not hash it.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(user_id, {'password': hashed_password})

    def activate_user(self, user_id: int) -> bool:
        """Set ``is_active`` to True.

        Args:
            user_id: User ID.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(user_id, {'is_active': True})

    def deactivate_user(self, user_id: int) -> bool:
        """Set ``is_active`` to False.

        Args:
            user_id: User ID.

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(user_id, {'is_active': False})

    def search_users(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """Substring-search users by username, e-mail, first or last name.

        Results are ordered by ``user_id`` so that LIMIT/OFFSET pages are
        stable between calls.

        NOTE(review): ``SELECT *`` also returns the ``password`` column
        (written by ``change_password``); consider projecting explicit
        columns if these results reach API responses.

        Args:
            search_term: Text matched with ``LIKE '%term%'``.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching users; empty list on error
            (the error is logged).
        """
        query = """
        SELECT * FROM users 
        WHERE username LIKE %s OR email LIKE %s 
        OR first_name LIKE %s OR last_name LIKE %s
        ORDER BY user_id
        LIMIT %s OFFSET %s
        """
        search_pattern = f"%{search_term}%"
        params = (search_pattern,) * 4 + (limit, offset)

        try:
            return self.db.execute_query(query, params)
        except Exception as e:
            logger.error(f"Error searching users: {e}")
            return []


class DoctorDAO(BaseDAO):
    """Data access object for doctors (table ``doctors``, joined to ``users``)."""

    def __init__(self):
        """Initialize the DAO for table ``doctors`` keyed by ``doctor_id``."""
        super().__init__('doctors', 'doctor_id')

    def find_with_user_info(self, doctor_id: Optional[int] = None, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return doctors joined with their user account details.

        Args:
            doctor_id: Restrict to this doctor; None returns all doctors.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Doctor rows ordered by ``doctor_id`` (so
            pagination is stable); empty list on error (the error is logged).
        """
        query = """
        SELECT d.*, u.username, u.email, u.first_name, u.last_name, 
               u.gender, u.profile_image
        FROM doctors d
        JOIN users u ON d.user_id = u.user_id
        """

        params: List[Any] = []
        if doctor_id is not None:
            query += " WHERE d.doctor_id = %s"
            params.append(doctor_id)

        # Without ORDER BY the database may return LIMIT/OFFSET pages in any order.
        query += " ORDER BY d.doctor_id LIMIT %s OFFSET %s"
        params.extend([limit, offset])

        try:
            return self.db.execute_query(query, tuple(params))
        except Exception as e:
            logger.error(f"Error finding doctors with user info: {e}")
            return []

    def find_by_specialization(self, specialization: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return doctors with the given specialization.

        Args:
            specialization: Specialization value.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Doctor rows.
        """
        return self.find_by_criteria({'specialization': specialization}, limit=limit, offset=offset)

    def find_by_department(self, department: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return doctors in the given department.

        Args:
            department: Department name.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Doctor rows.
        """
        return self.find_by_criteria({'department': department}, limit=limit, offset=offset)

    def find_top_rated(self, limit: int = 10) -> List[Dict[str, Any]]:
        """Return the highest-rated doctors.

        Unrated doctors (NULL ``rating``) are sorted last. PostgreSQL puts
        NULLs first for ``DESC`` by default, so they would otherwise head
        the list.

        Args:
            limit: Maximum number of rows to return.

        Returns:
            List[Dict[str, Any]]: Doctor rows, best rating first; empty list
            on error (the error is logged).
        """
        query = """
        SELECT d.*, u.first_name, u.last_name, u.profile_image
        FROM doctors d
        JOIN users u ON d.user_id = u.user_id
        ORDER BY d.rating IS NULL, d.rating DESC
        LIMIT %s
        """

        try:
            return self.db.execute_query(query, (limit,))
        except Exception as e:
            logger.error(f"Error finding top rated doctors: {e}")
            return []

    def update_rating(self, doctor_id: int, new_rating: float) -> bool:
        """Store a new rating for a doctor.

        Args:
            doctor_id: Doctor ID.
            new_rating: New rating value (not range-checked here).

        Returns:
            bool: Result of ``BaseDAO.update``.
        """
        return self.update(doctor_id, {'rating': new_rating})

    def search_doctors(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """Substring-search doctors by specialization, department or name.

        Args:
            search_term: Text matched with ``LIKE '%term%'``.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching doctors ordered by ``doctor_id``;
            empty list on error (the error is logged).
        """
        query = """
        SELECT d.*, u.first_name, u.last_name, u.email, u.profile_image
        FROM doctors d
        JOIN users u ON d.user_id = u.user_id
        WHERE d.specialization LIKE %s OR d.department LIKE %s
        OR u.first_name LIKE %s OR u.last_name LIKE %s
        ORDER BY d.doctor_id
        LIMIT %s OFFSET %s
        """

        search_pattern = f"%{search_term}%"
        params = (search_pattern,) * 4 + (limit, offset)

        try:
            return self.db.execute_query(query, params)
        except Exception as e:
            logger.error(f"Error searching doctors: {e}")
            return []


class PatientDAO(BaseDAO):
    """Data access object for patients (table ``patients``, joined to ``users``)."""

    def __init__(self):
        """Initialize the DAO for table ``patients`` keyed by ``patient_id``."""
        super().__init__('patients', 'patient_id')

    def find_with_user_info(self, patient_id: Optional[int] = None, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return patients joined with their user account details.

        Args:
            patient_id: Restrict to this patient; None returns all patients.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Patient rows ordered by ``patient_id`` (so
            pagination is stable); empty list on error (the error is logged).
        """
        query = """
        SELECT p.*, u.username, u.email, u.first_name, u.last_name, 
               u.gender, u.date_of_birth, u.address, u.profile_image
        FROM patients p
        JOIN users u ON p.user_id = u.user_id
        """

        params: List[Any] = []
        if patient_id is not None:
            query += " WHERE p.patient_id = %s"
            params.append(patient_id)

        # Without ORDER BY the database may return LIMIT/OFFSET pages in any order.
        query += " ORDER BY p.patient_id LIMIT %s OFFSET %s"
        params.extend([limit, offset])

        try:
            return self.db.execute_query(query, tuple(params))
        except Exception as e:
            logger.error(f"Error finding patients with user info: {e}")
            return []

    def find_by_medical_record_number(self, mrn: str) -> Optional[Dict[str, Any]]:
        """Look up a patient by medical record number with a single query.

        Args:
            mrn: Medical record number.

        Returns:
            Optional[Dict[str, Any]]: The patient row, or None if not found.
        """
        rows = self.find_by_criteria({'medical_record_number': mrn}, limit=1, offset=0)
        return rows[0] if rows else None

    def find_patients_by_doctor(self, doctor_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
        """Return the distinct patients that have a case with this primary doctor.

        Args:
            doctor_id: Doctor ID (matched against ``medical_cases.primary_doctor_id``).
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Patient rows ordered by ``patient_id``;
            empty list on error (the error is logged).
        """
        query = """
        SELECT DISTINCT p.*, u.first_name, u.last_name
        FROM patients p
        JOIN users u ON p.user_id = u.user_id
        JOIN medical_cases mc ON p.patient_id = mc.patient_id
        WHERE mc.primary_doctor_id = %s
        ORDER BY p.patient_id
        LIMIT %s OFFSET %s
        """

        try:
            return self.db.execute_query(query, (doctor_id, limit, offset))
        except Exception as e:
            logger.error(f"Error finding patients by doctor: {e}")
            return []

    def search_patients(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
        """Substring-search patients by medical record number or name.

        Args:
            search_term: Text matched with ``LIKE '%term%'``.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            List[Dict[str, Any]]: Matching patients ordered by ``patient_id``;
            empty list on error (the error is logged).
        """
        query = """
        SELECT p.*, u.first_name, u.last_name, u.email
        FROM patients p
        JOIN users u ON p.user_id = u.user_id
        WHERE p.medical_record_number LIKE %s
        OR u.first_name LIKE %s OR u.last_name LIKE %s
        ORDER BY p.patient_id
        LIMIT %s OFFSET %s
        """

        search_pattern = f"%{search_term}%"
        params = (search_pattern,) * 3 + (limit, offset)

        try:
            return self.db.execute_query(query, params)
        except Exception as e:
            logger.error(f"Error searching patients: {e}")
            return []

modules\data_management\__init__.py


modules\diagnosis_engine\diagnosis_engine.py

"""
Diagnosis Engine Module

This module serves as the main entry point for the diagnosis engine, integrating
the knowledge graph, inference engine, and integration engine components to provide
comprehensive diagnostic suggestions based on multiple data sources.
"""

import os
import json
import logging
import time
from typing import Dict, List, Tuple, Any, Optional, Set, Union
from pathlib import Path

from .knowledge_graph import MedicalKnowledgeGraph
from .inference_engine import MedicalInferenceEngine
from .integration_engine import DiagnosticIntegrationEngine

# Configure root logging at import time (INFO level, timestamped format).
# NOTE(review): basicConfig in a library module configures the root logger
# for the whole process (unless it already has handlers); consider moving
# this call to the application entry point.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class DiagnosisEngine:
    """
    Diagnosis Engine for providing comprehensive diagnostic suggestions.
    
    This class is the main entry point for the diagnosis engine. It wires
    together the knowledge graph, inference engine and integration engine
    components and delegates diagnosis, reporting and knowledge-graph
    operations to them.
    """
    
    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the Diagnosis Engine.
        
        Args:
            config_path: Path to a JSON configuration file. If None (or the
                         file does not exist), the packaged default location
                         and then built-in defaults are used. The same path is
                         passed to each sub-component.
        """
        self.config = self._load_config(config_path)
        self.knowledge_graph = MedicalKnowledgeGraph(config_path)
        self.inference_engine = MedicalInferenceEngine(self.knowledge_graph, config_path)
        self.integration_engine = DiagnosticIntegrationEngine(self.inference_engine, config_path)
        
        # Load rules, cases and graph data if the config points at them.
        self._load_knowledge_resources()
        
        logger.info("Diagnosis Engine initialized")
    
    @staticmethod
    def _read_config_file(path: Union[str, Path]) -> Any:
        """Read and parse a UTF-8 JSON configuration file."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    
    @staticmethod
    def _merge_config(defaults: Dict[str, Any], overrides: Any) -> Dict[str, Any]:
        """
        Overlay a loaded configuration on top of the built-in defaults.
        
        Top-level sections that are dicts on both sides are merged key by key
        (override wins); any other value from ``overrides`` replaces the
        default. ``defaults`` is not modified.
        
        Raises:
            ValueError: If ``overrides`` is not a dict (JSON root not an object).
        """
        if not isinstance(overrides, dict):
            raise ValueError("configuration root must be a JSON object")
        merged = dict(defaults)
        for key, value in overrides.items():
            base = merged.get(key)
            if isinstance(base, dict) and isinstance(value, dict):
                merged[key] = {**base, **value}
            else:
                merged[key] = value
        return merged
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, falling back to built-in defaults.
        
        Lookup order: ``config_path`` if it exists; otherwise
        ``<project>/config/diagnosis_engine_config.json``; otherwise the
        built-in defaults. A loaded file is merged over the defaults so that
        sections or keys it omits keep their default values.
        
        Args:
            config_path: Path to the configuration file (optional).
            
        Returns:
            Dictionary containing configuration parameters.
        """
        default_config = {
            "knowledge_resources": {
                "rules_file": "",
                "cases_file": "",
                "knowledge_graph_nodes": "",
                "knowledge_graph_relations": ""
            },
            "explanation": {
                "generate_explanation": True,
                "explanation_detail_level": "detailed",
                "include_evidence": True,
                "include_confidence_scores": True,
                "include_alternative_diagnoses": True
            },
            "performance": {
                "use_gpu": True,
                "batch_processing": True,
                "cache_results": True,
                "cache_expiry_hours": 24
            }
        }
        
        if config_path and os.path.exists(config_path):
            try:
                config = self._merge_config(default_config, self._read_config_file(config_path))
                logger.info(f"Loaded configuration from {config_path}")
                return config
            except Exception as e:
                logger.error(f"Error loading configuration: {e}")
                return default_config
        
        if config_path:
            # An explicit path that does not exist is most likely a mistake.
            logger.warning(f"Configuration file {config_path} not found; trying default location")
        
        # Try to load from the default location
        try:
            base_dir = Path(__file__).parent.parent.parent
            default_path = base_dir / "config" / "diagnosis_engine_config.json"
            if default_path.exists():
                config = self._merge_config(default_config, self._read_config_file(default_path))
                logger.info(f"Loaded configuration from {default_path}")
                return config
        except Exception as e:
            logger.error(f"Error loading default configuration: {e}")
        
        logger.info("Using default diagnosis engine configuration")
        return default_config
    
    def _load_knowledge_resources(self):
        """
        Load the rule, case and knowledge-graph files named in the config.
        
        Each resource is loaded only if its path is non-empty and exists; the
        knowledge graph requires both the nodes and the relations file.
        """
        # ``or {}`` also covers an explicit null in the config file.
        resources = self.config.get("knowledge_resources") or {}
        
        # Load rules
        rules_file = resources.get("rules_file", "")
        if rules_file and os.path.exists(rules_file):
            self.inference_engine.load_rules(rules_file)
        
        # Load cases
        cases_file = resources.get("cases_file", "")
        if cases_file and os.path.exists(cases_file):
            self.inference_engine.load_cases(cases_file)
        
        # Load knowledge graph
        nodes_file = resources.get("knowledge_graph_nodes", "")
        relations_file = resources.get("knowledge_graph_relations", "")
        if nodes_file and relations_file and os.path.exists(nodes_file) and os.path.exists(relations_file):
            self.knowledge_graph.import_from_csv(nodes_file, relations_file)
    
    def diagnose(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]] = None,
                text_analysis: Optional[Dict[str, Any]] = None, 
                lab_results: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Generate diagnostic suggestions based on multiple data sources.
        
        Args:
            patient_data: Dictionary containing patient data, including demographics.
            image_analysis: Results from medical image analysis.
            text_analysis: Results from clinical text analysis.
            lab_results: Laboratory test results.
            
        Returns:
            The integration engine's result dict, with an added
            ``"performance"`` entry holding ``processing_time`` (seconds) and
            ``timestamp`` (epoch seconds at completion).
        """
        start_time = time.time()
        
        # Integrate results from different sources
        diagnosis_results = self.integration_engine.integrate_results(
            patient_data, image_analysis, text_analysis, lab_results
        )
        
        finished = time.time()
        diagnosis_results["performance"] = {
            "processing_time": finished - start_time,
            "timestamp": finished
        }
        
        logger.info(f"Diagnosis completed in {diagnosis_results['performance']['processing_time']:.2f} seconds")
        
        return diagnosis_results
    
    def generate_report(self, diagnosis_results: Dict[str, Any], patient_data: Dict[str, Any],
                       format: str = "text") -> str:
        """
        Generate a diagnostic report based on the results.
        
        Args:
            diagnosis_results: Results from diagnostic integration.
            patient_data: Patient demographic data.
            format: Report format ("text", "html", or "json").
            
        Returns:
            Diagnostic report in the specified format.
        """
        return self.integration_engine.generate_diagnostic_report(
            diagnosis_results, patient_data, format
        )
    
    def visualize_knowledge_graph(self, node_ids: Optional[List[str]] = None, 
                                filename: Optional[str] = None) -> None:
        """
        Visualize the knowledge graph or a subgraph.
        
        Args:
            node_ids: List of node identifiers to visualize (optional).
            filename: Path to save the visualization (optional).
        """
        self.knowledge_graph.visualize(node_ids, filename)
    
    def get_knowledge_graph_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the knowledge graph.
        
        Returns:
            Dictionary containing various statistics about the knowledge graph.
        """
        return self.knowledge_graph.get_statistics()
    
    def export_knowledge_graph(self, nodes_file: str, relations_file: str) -> None:
        """
        Export the knowledge graph to CSV files.
        
        Args:
            nodes_file: Path to save the nodes CSV file.
            relations_file: Path to save the relations CSV file.
        """
        self.knowledge_graph.export_to_csv(nodes_file, relations_file)
    
    def import_external_knowledge(self, source_type: str, source_path: str) -> Dict[str, Any]:
        """
        Import knowledge from external sources (not implemented).
        
        Args:
            source_type: Type of knowledge source ("umls", "snomed_ct", "icd10", "custom").
            source_path: Path to the knowledge source file.
            
        Returns:
            A dict with ``success`` False and an explanatory ``message``.
        """
        # Placeholder: nothing is read from ``source_path`` yet.
        logger.warning(f"Import from {source_type} is not fully implemented")
        
        return {
            "success": False,
            "message": f"Import from {source_type} is not implemented yet",
            "source_type": source_type,
            "source_path": source_path
        }
    
    def clear_knowledge_graph(self) -> None:
        """
        Clear the knowledge graph.
        """
        self.knowledge_graph.clear()
        logger.info("Knowledge graph cleared")

modules\diagnosis_engine\inference_engine.py

"""
Medical Inference Engine Module

This module provides functionality for diagnostic reasoning based on medical knowledge
and patient data, using various inference methods including rule-based reasoning,
Bayesian networks, case-based reasoning, and deep learning approaches.
"""

import os
import json
import logging
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any, Optional, Set, Union
from pathlib import Path
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict

from .knowledge_graph import MedicalKnowledgeGraph

# Module logging: basicConfig configures the root logger at import time
# (it does nothing if the root logger already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class MedicalInferenceEngine:
    """
    Medical Inference Engine for diagnostic reasoning.
    
    This class provides methods for inferring diagnoses based on patient data,
    using various reasoning methods including rule-based, Bayesian networks,
    case-based reasoning, and deep learning approaches.
    """
    
    def __init__(self, knowledge_graph: Optional[MedicalKnowledgeGraph] = None, config_path: Optional[str] = None):
        """
        Initialize the Medical Inference Engine.
        
        Args:
            knowledge_graph: Medical knowledge graph to use for reasoning.
                            If None, a new knowledge graph will be created.
            config_path: Path to the configuration file for the inference engine.
                         If None, default configuration will be used.
        """
        self.knowledge_graph = knowledge_graph if knowledge_graph else MedicalKnowledgeGraph()
        self.config = self._load_config(config_path)
        self.reasoning_methods = {
            "rule_based": self._rule_based_reasoning,
            "bayesian_network": self._bayesian_network_reasoning,
            "case_based": self._case_based_reasoning,
            "deep_learning": self._deep_learning_reasoning,
            "hybrid": self._hybrid_reasoning
        }
        self.case_database = []  # For case-based reasoning
        self.rules = []  # For rule-based reasoning
        logger.info("Medical Inference Engine initialized")
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file or use default configuration.
        
        Args:
            config_path: Path to the configuration file.
            
        Returns:
            Dictionary containing configuration parameters.
        """
        default_config = {
            "reasoning_methods": ["rule_based", "bayesian_network", "case_based", "deep_learning"],
            "default_method": "hybrid",
            "confidence_threshold": 0.65,
            "max_differential_diagnoses": 5
        }
        
        if config_path and os.path.exists(config_path):
            try:
                with open(config_path, 'r') as f:
                    config = json.load(f).get("inference_engine", {})
                    logger.info(f"Loaded configuration from {config_path}")
                    return config
            except Exception as e:
                logger.error(f"Error loading configuration: {e}")
                return default_config
        else:
            # Try to load from the default location
            try:
                base_dir = Path(__file__).parent.parent.parent
                default_path = base_dir / "config" / "diagnosis_engine_config.json"
                if os.path.exists(default_path):
                    with open(default_path, 'r') as f:
                        config = json.load(f).get("inference_engine", {})
                        logger.info(f"Loaded configuration from {default_path}")
                        return config
            except Exception as e:
                logger.error(f"Error loading default configuration: {e}")
            
            logger.info("Using default inference engine configuration")
            return default_config
    
    def load_rules(self, rules_file: str) -> None:
        """
        Load diagnostic rules from a JSON file.
        
        Args:
            rules_file: Path to the JSON file containing diagnostic rules.
        """
        try:
            with open(rules_file, 'r') as f:
                self.rules = json.load(f)
            logger.info(f"Loaded {len(self.rules)} diagnostic rules from {rules_file}")
        except Exception as e:
            logger.error(f"Error loading rules: {e}")
    
    def load_cases(self, cases_file: str) -> None:
        """
        Load previous cases for case-based reasoning.
        
        Args:
            cases_file: Path to the CSV file containing previous cases.
        """
        try:
            self.case_database = pd.read_csv(cases_file).to_dict('records')
            logger.info(f"Loaded {len(self.case_database)} cases from {cases_file}")
        except Exception as e:
            logger.error(f"Error loading cases: {e}")
    
    def diagnose(self, patient_data: Dict[str, Any], method: Optional[str] = None) -> Dict[str, Any]:
        """
        Generate diagnostic suggestions based on patient data.
        
        Args:
            patient_data: Dictionary containing patient data, including:
                         - demographics: age, gender, etc.
                         - symptoms: list of symptoms
                         - medical_history: past medical conditions
                         - lab_results: laboratory test results
                         - image_analysis: results from image analysis
                         - text_analysis: results from clinical text analysis
            method: Reasoning method to use. If None, the default method from
                   configuration will be used.
                   
        Returns:
            Dictionary containing diagnostic results, including:
            - primary_diagnosis: most likely diagnosis
            - differential_diagnoses: alternative diagnoses
            - confidence_scores: confidence scores for each diagnosis
            - evidence: supporting evidence for diagnoses
            - explanation: explanation of the reasoning process
        """
        if method is None:
            method = self.config.get("default_method", "hybrid")
        
        if method not in self.reasoning_methods:
            logger.warning(f"Unknown reasoning method: {method}, using default")
            method = self.config.get("default_method", "hybrid")
        
        # Preprocess patient data
        processed_data = self._preprocess_patient_data(patient_data)
        
        # Apply the selected reasoning method
        reasoning_func = self.reasoning_methods[method]
        diagnosis_results = reasoning_func(processed_data)
        
        # Filter results based on confidence threshold
        confidence_threshold = self.config.get("confidence_threshold", 0.65)
        max_differential = self.config.get("max_differential_diagnoses", 5)
        
        filtered_diagnoses = [d for d in diagnosis_results["diagnoses"] 
                             if d["confidence"] >= confidence_threshold]
        
        # Sort by confidence score and limit to max_differential
        sorted_diagnoses = sorted(filtered_diagnoses, key=lambda x: x["confidence"], reverse=True)
        primary_diagnosis = sorted_diagnoses[0] if sorted_diagnoses else None
        differential_diagnoses = sorted_diagnoses[1:max_differential+1] if len(sorted_diagnoses) > 1 else []
        
        # Generate explanation
        explanation = self._generate_explanation(primary_diagnosis, differential_diagnoses, 
                                               diagnosis_results["evidence"], method)
        
        return {
            "success": True,
            "primary_diagnosis": primary_diagnosis,
            "differential_diagnoses": differential_diagnoses,
            "evidence": diagnosis_results["evidence"],
            "explanation": explanation,
            "method": method,
            "confidence_threshold": confidence_threshold
        }
    
    def _preprocess_patient_data(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess patient data for diagnostic reasoning.
        
        Args:
            patient_data: Raw patient data.
            
        Returns:
            Preprocessed patient data.
        """
        processed_data = patient_data.copy()
        
        # Normalize symptom names
        if "symptoms" in processed_data:
            processed_data["symptoms"] = [self._normalize_term(s) for s in processed_data["symptoms"]]
        
        # Normalize disease names in medical history
        if "medical_history" in processed_data:
            processed_data["medical_history"] = [self._normalize_term(d) for d in processed_data["medical_history"]]
        
        # Extract entities from text analysis if available
        if "text_analysis" in processed_data:
            text_analysis = processed_data["text_analysis"]
            
            # Extract symptoms from text analysis if not already provided
            if "symptoms" not in processed_data and "entities" in text_analysis:
                symptoms = [e["text"] for e in text_analysis["entities"] 
                           if e["label"].lower() == "symptom"]
                processed_data["symptoms"] = symptoms
            
            # Extract medical history from text analysis if not already provided
            if "medical_history" not in processed_data and "entities" in text_analysis:
                diseases = [e["text"] for e in text_analysis["entities"] 
                           if e["label"].lower() == "disease"]
                processed_data["medical_history"] = diseases
        
        return processed_data
    
    def _normalize_term(self, term: str) -> str:
        """
        Normalize medical terms for consistent matching.
        
        Args:
            term: Medical term to normalize.
            
        Returns:
            Normalized term.
        """
        # Simple normalization: lowercase and remove trailing whitespace
        normalized = term.lower().strip()
        
        # TODO: Implement more sophisticated normalization using medical ontologies
        
        return normalized
    
    def _rule_based_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply rule-based reasoning for diagnosis.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary containing diagnoses and supporting evidence.
        """
        diagnoses = []
        evidence = defaultdict(list)
        
        # Extract patient symptoms
        patient_symptoms = set(patient_data.get("symptoms", []))
        
        # Check each rule
        for rule in self.rules:
            rule_name = rule.get("name", "Unknown Rule")
            disease = rule.get("disease", "")
            required_symptoms = set(rule.get("required_symptoms", []))
            supporting_symptoms = set(rule.get("supporting_symptoms", []))
            exclusion_symptoms = set(rule.get("exclusion_symptoms", []))
            
            # Check if any exclusion symptoms are present
            if patient_symptoms.intersection(exclusion_symptoms):
                continue
            
            # Check if all required symptoms are present
            if not required_symptoms.issubset(patient_symptoms):
                continue
            
            # Calculate confidence based on supporting symptoms
            supporting_count = len(patient_symptoms.intersection(supporting_symptoms))
            total_supporting = len(supporting_symptoms)
            
            if total_supporting > 0:
                support_confidence = supporting_count / total_supporting
            else:
                support_confidence = 1.0 if required_symptoms else 0.0
            
            # Adjust confidence based on required symptoms
            required_confidence = 1.0 if required_symptoms else 0.5
            
            # Final confidence is a weighted combination
            confidence = 0.7 * required_confidence + 0.3 * support_confidence
            
            # Add to diagnoses if confidence is high enough
            if confidence > 0:
                diagnoses.append({
                    "disease": disease,
                    "confidence": confidence,
                    "rule": rule_name
                })
                
                # Record evidence
                for symptom in required_symptoms:
                    if symptom in patient_symptoms:
                        evidence[disease].append({
                            "type": "required_symptom",
                            "item": symptom,
                            "weight": "high"
                        })
                
                for symptom in supporting_symptoms:
                    if symptom in patient_symptoms:
                        evidence[disease].append({
                            "type": "supporting_symptom",
                            "item": symptom,
                            "weight": "medium"
                        })
        
        return {
            "diagnoses": diagnoses,
            "evidence": dict(evidence)
        }
    
    def _bayesian_network_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply Bayesian network reasoning for diagnosis.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary containing diagnoses and supporting evidence.
        """
        diagnoses = []
        evidence = defaultdict(list)
        
        # Extract patient symptoms and other findings
        patient_symptoms = set(patient_data.get("symptoms", []))
        patient_findings = {
            **{f"symptom_{s}": 1 for s in patient_symptoms},
            **{f"lab_{k}": v for k, v in patient_data.get("lab_results", {}).items()}
        }
        
        # Get all disease nodes from the knowledge graph
        disease_nodes = [n for n, data in self.knowledge_graph.graph.nodes(data=True) 
                        if data.get("type") == "disease"]
        
        for disease_id in disease_nodes:
            # Get related symptoms for this disease
            related_nodes = self.knowledge_graph.query_related_nodes(
                disease_id, relation_type="has_symptom", direction="outgoing"
            )
            
            disease_symptoms = [node["id"] for node in related_nodes]
            
            # Skip if no symptoms are defined for this disease
            if not disease_symptoms:
                continue
            
            # Calculate prior probability (can be refined based on prevalence data)
            prior_prob = 0.01  # Default prior probability
            
            # Calculate likelihood using naive Bayes approach
            likelihood = 1.0
            symptom_evidence = []
            
            for symptom in disease_symptoms:
                # Get the conditional probability of the symptom given the disease
                # This should come from the knowledge graph or a probability table
                # For now, we'll use a simplified approach
                edge_data = self.knowledge_graph.graph.get_edge_data(disease_id, symptom) or {}
                conditional_prob = edge_data.get("probability", 0.7)  # Default conditional probability
                
                if symptom in patient_symptoms:
                    likelihood *= conditional_prob
                    symptom_evidence.append({
                        "type": "symptom",
                        "item": symptom,
                        "weight": "high" if conditional_prob > 0.8 else "medium",
                        "probability": conditional_prob
                    })
                else:
                    likelihood *= (1 - conditional_prob)
            
            # Calculate posterior probability using Bayes' rule
            # P(disease|symptoms) = P(symptoms|disease) * P(disease) / P(symptoms)
            # Since P(symptoms) is the same for all diseases, we can ignore it for ranking
            posterior_prob = likelihood * prior_prob
            
            # Add to diagnoses if probability is above zero
            if posterior_prob > 0:
                disease_data = self.knowledge_graph.graph.nodes[disease_id]
                diagnoses.append({
                    "disease": disease_data.get("name", disease_id),
                    "confidence": posterior_prob,
                    "method": "bayesian"
                })
                
                # Record evidence
                evidence[disease_data.get("name", disease_id)] = symptom_evidence
        
        # Normalize confidence scores to [0, 1] range
        max_confidence = max([d["confidence"] for d in diagnoses]) if diagnoses else 1
        for diagnosis in diagnoses:
            diagnosis["confidence"] = diagnosis["confidence"] / max_confidence
        
        return {
            "diagnoses": diagnoses,
            "evidence": dict(evidence)
        }
    
    def _case_based_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply case-based reasoning for diagnosis.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary containing diagnoses and supporting evidence.
        """
        diagnoses = []
        evidence = defaultdict(list)
        
        # Skip if no case database is available
        if not self.case_database:
            logger.warning("No case database available for case-based reasoning")
            return {"diagnoses": [], "evidence": {}}
        
        # Extract patient features
        patient_features = self._extract_case_features(patient_data)
        
        # Find similar cases
        similar_cases = []
        for case in self.case_database:
            case_features = {k: v for k, v in case.items() if k != "diagnosis"}
            similarity = self._calculate_case_similarity(patient_features, case_features)
            similar_cases.append({
                "case": case,
                "similarity": similarity
            })
        
        # Sort by similarity
        similar_cases.sort(key=lambda x: x["similarity"], reverse=True)
        
        # Get top similar cases
        top_cases = similar_cases[:10]  # Consider top 10 most similar cases
        
        # Count diagnoses in top cases
        diagnosis_counts = defaultdict(int)
        diagnosis_similarities = defaultdict(list)
        
        for case_data in top_cases:
            case = case_data["case"]
            similarity = case_data["similarity"]
            
            diagnosis = case.get("diagnosis", "")
            if diagnosis:
                diagnosis_counts[diagnosis] += 1
                diagnosis_similarities[diagnosis].append(similarity)
        
        # Calculate confidence for each diagnosis
        for diagnosis, count in diagnosis_counts.items():
            # Confidence is based on frequency and average similarity
            frequency = count / len(top_cases)
            avg_similarity = sum(diagnosis_similarities[diagnosis]) / len(diagnosis_similarities[diagnosis])
            confidence = 0.7 * frequency + 0.3 * avg_similarity
            
            diagnoses.append({
                "disease": diagnosis,
                "confidence": confidence,
                "method": "case_based",
                "similar_cases": count
            })
            
            # Record evidence from similar cases
            for i, case_data in enumerate(top_cases):
                case = case_data["case"]
                if case.get("diagnosis") == diagnosis:
                    evidence[diagnosis].append({
                        "type": "similar_case",
                        "item": f"Case {i+1}",
                        "similarity": case_data["similarity"],
                        "weight": "high" if case_data["similarity"] > 0.8 else "medium"
                    })
        
        return {
            "diagnoses": diagnoses,
            "evidence": dict(evidence)
        }
    
    def _extract_case_features(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Extract features from patient data for case-based reasoning.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary of features for case comparison.
        """
        features = {}
        
        # Extract demographic features
        demographics = patient_data.get("demographics", {})
        features["age"] = demographics.get("age", 0)
        features["gender"] = demographics.get("gender", "unknown")
        
        # Extract symptom features (as binary indicators)
        for symptom in patient_data.get("symptoms", []):
            features[f"symptom_{symptom}"] = 1
        
        # Extract lab result features
        for lab_test, value in patient_data.get("lab_results", {}).items():
            features[f"lab_{lab_test}"] = value
        
        return features
    
    def _calculate_case_similarity(self, features1: Dict[str, Any], features2: Dict[str, Any]) -> float:
        """
        Calculate similarity between two cases based on their features.
        
        Args:
            features1: Features of the first case.
            features2: Features of the second case.
            
        Returns:
            Similarity score between 0 and 1.
        """
        # Combine all feature keys
        all_keys = set(features1.keys()) | set(features2.keys())
        
        # Initialize similarity components
        matches = 0
        total = 0
        
        for key in all_keys:
            # Skip if key is not relevant for comparison
            if key in ["patient_id", "case_id"]:
                continue
            
            total += 1
            
            # Check if both cases have this feature
            if key in features1 and key in features2:
                value1 = features1[key]
                value2 = features2[key]
                
                # Calculate similarity based on feature type
                if isinstance(value1, (int, float)) and isinstance(value2, (int, float)):
                    # Numeric feature: calculate normalized difference
                    max_val = max(abs(value1), abs(value2))
                    if max_val > 0:
                        similarity = 1 - abs(value1 - value2) / max_val
                    else:
                        similarity = 1.0
                elif isinstance(value1, str) and isinstance(value2, str):
                    # String feature: exact match
                    similarity = 1.0 if value1 == value2 else 0.0
                else:
                    # Other types: exact match
                    similarity = 1.0 if value1 == value2 else 0.0
                
                matches += similarity
        
        # Calculate overall similarity
        if total > 0:
            return matches / total
        else:
            return 0.0
    
    def _deep_learning_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply deep learning reasoning for diagnosis.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary containing diagnoses and supporting evidence.
        """
        # This is a placeholder for deep learning-based diagnosis
        # In a real implementation, this would use a trained neural network model
        
        logger.warning("Deep learning reasoning is not fully implemented")
        
        # Return empty results for now
        return {
            "diagnoses": [],
            "evidence": {}
        }
    
    def _hybrid_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply hybrid reasoning combining multiple methods.
        
        Args:
            patient_data: Preprocessed patient data.
            
        Returns:
            Dictionary containing diagnoses and supporting evidence.
        """
        all_diagnoses = []
        all_evidence = defaultdict(list)
        
        # Apply each enabled reasoning method
        enabled_methods = self.config.get("reasoning_methods", ["rule_based", "bayesian_network"])
        
        for method_name in enabled_methods:
            if method_name == "hybrid":
                continue  # Skip to avoid recursion
                
            if method_name in self.reasoning_methods:
                method_func = self.reasoning_methods[method_name]
                result = method_func(patient_data)
                
                # Add method identifier to each diagnosis
                for diagnosis in result["diagnoses"]:
                    diagnosis["method"] = method_name
                
                all_diagnoses.extend(result["diagnoses"])
                
                # Merge evidence
                for disease, evidence_list in result["evidence"].items():
                    all_evidence[disease].extend(evidence_list)
        
        # Combine diagnoses from different methods
        combined_diagnoses = self._combine_diagnoses(all_diagnoses)
        
        return {
            "diagnoses": combined_diagnoses,
            "evidence": dict(all_evidence)
        }
    
    def _combine_diagnoses(self, diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Combine diagnoses from different reasoning methods.
        
        Args:
            diagnoses: List of diagnoses from different methods.
            
        Returns:
            Combined list of diagnoses with updated confidence scores.
        """
        # Group diagnoses by disease
        disease_groups = defaultdict(list)
        for diagnosis in diagnoses:
            disease = diagnosis["disease"]
            disease_groups[disease].append(diagnosis)
        
        # Combine diagnoses for each disease
        combined_diagnoses = []
        for disease, group in disease_groups.items():
            # Calculate combined confidence
            # Strategy: weighted average of confidence scores
            method_weights = {
                "rule_based": 0.3,
                "bayesian_network": 0.4,
                "case_based": 0.2,
                "deep_learning": 0.1
            }
            
            total_weight = 0
            weighted_confidence = 0
            
            for diagnosis in group:
                method = diagnosis.get("method", "unknown")
                weight = method_weights.get(method, 0.1)
                confidence = diagnosis.get("confidence", 0)
                
                weighted_confidence += weight * confidence
                total_weight += weight
            
            if total_weight > 0:
                combined_confidence = weighted_confidence / total_weight
            else:
                combined_confidence = 0
            
            # Create combined diagnosis
            combined_diagnosis = {
                "disease": disease,
                "confidence": combined_confidence,
                "methods": [d.get("method", "unknown") for d in group],
                "individual_confidences": {d.get("method", f"method_{i}"): d.get("confidence", 0) 
                                          for i, d in enumerate(group)}
            }
            
            combined_diagnoses.append(combined_diagnosis)
        
        return combined_diagnoses
    
    def _generate_explanation(self, primary_diagnosis: Optional[Dict[str, Any]],
                             differential_diagnoses: List[Dict[str, Any]],
                             evidence: Dict[str, List[Dict[str, Any]]],
                             method: str) -> Dict[str, Any]:
        """
        Generate an explanation for the diagnostic results.
        
        Args:
            primary_diagnosis: Primary diagnosis.
            differential_diagnoses: List of differential diagnoses.
            evidence: Supporting evidence for diagnoses.
            method: Reasoning method used.
            
        Returns:
            Dictionary containing explanation components.
        """
        explanation = {
            "reasoning_method": method,
            "summary": "",
            "primary_diagnosis_explanation": "",
            "differential_diagnoses_explanation": [],
            "evidence_summary": {}
        }
        
        # Generate summary
        if primary_diagnosis:
            disease = primary_diagnosis["disease"]
            confidence = primary_diagnosis["confidence"]
            confidence_text = self._confidence_to_text(confidence)
            
            explanation["summary"] = (
                f"The most likely diagnosis is {disease} ({confidence_text} confidence: {confidence:.2f}). "
                f"This diagnosis was determined using {method} reasoning."
            )
            
            # Generate primary diagnosis explanation
            disease_evidence = evidence.get(disease, [])
            if disease_evidence:
                evidence_text = self._summarize_evidence(disease_evidence)
                explanation["primary_diagnosis_explanation"] = (
                    f"The diagnosis of {disease} is supported by the following evidence: {evidence_text}"
                )
        else:
            explanation["summary"] = "No confident diagnosis could be determined based on the available information."
        
        # Generate differential diagnoses explanation
        for diagnosis in differential_diagnoses:
            disease = diagnosis["disease"]
            confidence = diagnosis["confidence"]
            confidence_text = self._confidence_to_text(confidence)
            
            disease_evidence = evidence.get(disease, [])
            evidence_text = self._summarize_evidence(disease_evidence)
            
            diff_explanation = {
                "disease": disease,
                "confidence": confidence,
                "confidence_text": confidence_text,
                "explanation": f"Alternative diagnosis: {disease} ({confidence_text} confidence: {confidence:.2f}). "
                              f"Supporting evidence: {evidence_text}"
            }
            
            explanation["differential_diagnoses_explanation"].append(diff_explanation)
        
        # Generate evidence summary
        for disease, disease_evidence in evidence.items():
            explanation["evidence_summary"][disease] = self._categorize_evidence(disease_evidence)
        
        return explanation
    
    def _confidence_to_text(self, confidence: float) -> str:
        """
        Convert confidence score to text description.
        
        Args:
            confidence: Confidence score between 0 and 1.
            
        Returns:
            Text description of confidence.
        """
        if confidence >= 0.9:
            return "very high"
        elif confidence >= 0.75:
            return "high"
        elif confidence >= 0.6:
            return "moderate"
        elif confidence >= 0.4:
            return "low"
        else:
            return "very low"
    
    def _summarize_evidence(self, evidence_list: List[Dict[str, Any]]) -> str:
        """
        Summarize evidence items into a text description.
        
        Args:
            evidence_list: List of evidence items.
            
        Returns:
            Text summary of evidence.
        """
        if not evidence_list:
            return "No specific evidence available."
        
        # Group evidence by type and weight
        high_evidence = [e["item"] for e in evidence_list if e.get("weight") == "high"]
        medium_evidence = [e["item"] for e in evidence_list if e.get("weight") == "medium"]
        other_evidence = [e["item"] for e in evidence_list if e.get("weight") not in ["high", "medium"]]
        
        evidence_parts = []
        
        if high_evidence:
            evidence_parts.append(f"Strong indicators: {', '.join(high_evidence)}")
        
        if medium_evidence:
            evidence_parts.append(f"Supporting factors: {', '.join(medium_evidence)}")
        
        if other_evidence:
            evidence_parts.append(f"Other factors: {', '.join(other_evidence)}")
        
        return ". ".join(evidence_parts)
    
    def _categorize_evidence(self, evidence_list: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Categorize evidence items by type and weight.
        
        Args:
            evidence_list: List of evidence items.
            
        Returns:
            Dictionary with categorized evidence.
        """
        categorized = {
            "symptoms": [],
            "lab_results": [],
            "imaging": [],
            "patient_history": [],
            "other": []
        }
        
        for evidence in evidence_list:
            item = evidence["item"]
            evidence_type = evidence.get("type", "")
            
            if "symptom" in evidence_type:
                categorized["symptoms"].append(item)
            elif "lab" in evidence_type:
                categorized["lab_results"].append(item)
            elif "image" in evidence_type:
                categorized["imaging"].append(item)
            elif "history" in evidence_type:
                categorized["patient_history"].append(item)
            else:
                categorized["other"].append(item)
        
        return categorized

modules\diagnosis_engine\integration_engine.py

"""
Diagnostic Integration Engine Module

This module provides functionality for integrating results from different analysis modules,
including medical image analysis, clinical text analysis, and lab results, to provide
comprehensive diagnostic suggestions.
"""

import html
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional, Set, Union

import numpy as np
import pandas as pd

from .knowledge_graph import MedicalKnowledgeGraph
from .inference_engine import MedicalInferenceEngine

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class DiagnosticIntegrationEngine:
    """
    Diagnostic Integration Engine for combining results from different analysis modules.

    This class provides methods for integrating results from medical image analysis,
    clinical text analysis, and lab results to provide comprehensive diagnostic suggestions.

    This class's own code never modifies the dictionaries passed to it; derived
    structures are returned as new (shallowly copied) objects.
    """

    # Built-in configuration. Values loaded from a config file are layered on top
    # of these, so a partial file only overrides the keys it actually contains.
    _DEFAULT_CONFIG: Dict[str, Any] = {
        "image_analysis_weight": 0.5,
        "text_analysis_weight": 0.5,
        "lab_results_weight": 0.3,
        "patient_history_weight": 0.4,
        "fusion_method": "weighted_average",
    }

    def __init__(self, inference_engine: Optional["MedicalInferenceEngine"] = None,
                 config_path: Optional[str] = None):
        """
        Initialize the Diagnostic Integration Engine.

        Args:
            inference_engine: Medical inference engine to use for reasoning.
                             If None, a new inference engine will be created.
            config_path: Path to the configuration file for the integration engine.
                         If None, the default location and then the built-in
                         defaults are used.
        """
        # `is not None` so that an engine object which happens to be falsy is still honoured.
        self.inference_engine = (
            inference_engine if inference_engine is not None else MedicalInferenceEngine()
        )
        self.config = self._load_config(config_path)
        logger.info("Diagnostic Integration Engine initialized")

    # ------------------------------------------------------------------ config

    def _config_weight(self, key: str) -> Any:
        """Return the configured value for ``key``, falling back to the built-in default."""
        return self.config.get(key, self._DEFAULT_CONFIG[key])

    @staticmethod
    def _read_integration_section(path: Union[str, Path]) -> Dict[str, Any]:
        """
        Read a JSON config file and return its ``"integration"`` section.

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If the file is not valid JSON.
            TypeError: If the file or its "integration" section is not a JSON object.
        """
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise TypeError(f"Configuration root in {path} is not a JSON object")
        section = data.get("integration", {})
        if not isinstance(section, dict):
            raise TypeError(f"'integration' section in {path} is not a JSON object")
        return section

    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file or use default configuration.

        ``config_path`` is tried first; if it is not given or does not exist,
        ``<project root>/config/diagnosis_engine_config.json`` is tried. The
        file's ``"integration"`` section is merged over the built-in defaults.
        Load errors are logged and the defaults are returned (best effort).

        Args:
            config_path: Path to the configuration file.

        Returns:
            Dictionary containing configuration parameters.
        """
        # Fresh copy so later edits to self.config never touch the class constant.
        default_config = dict(self._DEFAULT_CONFIG)

        if config_path and os.path.exists(config_path):
            try:
                loaded = self._read_integration_section(config_path)
            except Exception as e:
                logger.error(f"Error loading configuration: {e}")
                return default_config
            logger.info(f"Loaded configuration from {config_path}")
            return {**default_config, **loaded}

        if config_path:
            logger.warning(f"Configuration file not found: {config_path}; trying default location")

        # Try to load from the default location
        try:
            base_dir = Path(__file__).parent.parent.parent
            default_path = base_dir / "config" / "diagnosis_engine_config.json"
            if default_path.exists():
                loaded = self._read_integration_section(default_path)
                logger.info(f"Loaded configuration from {default_path}")
                return {**default_config, **loaded}
        except Exception as e:
            logger.error(f"Error loading default configuration: {e}")

        logger.info("Using default integration engine configuration")
        return default_config

    # ------------------------------------------------------------- integration

    def integrate_results(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]] = None,
                         text_analysis: Optional[Dict[str, Any]] = None,
                         lab_results: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Integrate results from different analysis modules.

        Args:
            patient_data: Dictionary containing patient demographic data.
            image_analysis: Results from medical image analysis.
            text_analysis: Results from clinical text analysis.
            lab_results: Laboratory test results.

        Returns:
            Dictionary containing integrated diagnostic results, including a
            "data_sources" list and, when a primary diagnosis exists, its
            per-source "confidence_breakdown".
        """
        # Prepare integrated data for inference
        integrated_data = self._prepare_integrated_data(patient_data, image_analysis, text_analysis, lab_results)

        # Run diagnostic inference
        diagnosis_results = self.inference_engine.diagnose(integrated_data)

        # Enhance results with source information
        return self._enhance_results_with_sources(diagnosis_results, image_analysis, text_analysis, lab_results)

    def _prepare_integrated_data(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]],
                               text_analysis: Optional[Dict[str, Any]],
                               lab_results: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Prepare integrated data for diagnostic inference.

        Symptom entities and image abnormality descriptions become "symptoms";
        image findings and abnormal lab values become "findings"; disease
        entities are appended to "medical_history". Each list is de-duplicated
        keeping first-seen order.

        Args:
            patient_data: Patient demographic data.
            image_analysis: Results from medical image analysis.
            text_analysis: Results from clinical text analysis.
            lab_results: Laboratory test results.

        Returns:
            Integrated data ready for diagnostic inference.
        """
        integrated_data = {
            "demographics": patient_data.get("demographics", {}),
            "symptoms": [],
            "findings": [],
            # Copy: extending this list below must not mutate the caller's patient_data.
            "medical_history": list(patient_data.get("medical_history") or []),
        }

        if text_analysis:
            # Import text analysis results to knowledge graph
            if hasattr(self.inference_engine, "knowledge_graph"):
                self.inference_engine.knowledge_graph.import_from_text_analysis(text_analysis)

            # Symptoms go to "symptoms"; disease mentions extend the medical history.
            for entity in text_analysis.get("entities") or []:
                label = (entity.get("label") or "").lower()
                if label == "symptom":
                    integrated_data["symptoms"].append(entity["text"])
                elif label == "disease":
                    integrated_data["medical_history"].append(entity["text"])

            integrated_data["text_analysis"] = text_analysis

        if image_analysis:
            integrated_data["findings"].extend(image_analysis.get("findings") or [])

            # Abnormality descriptions are treated as potential symptoms.
            for abnormality in image_analysis.get("abnormalities") or []:
                if "description" in abnormality:
                    integrated_data["symptoms"].append(abnormality["description"])

            integrated_data["image_analysis"] = image_analysis

        if lab_results:
            integrated_data["lab_results"] = lab_results

            # Abnormal lab values are surfaced as textual findings.
            for test, result in (lab_results.get("abnormal_results") or {}).items():
                integrated_data["findings"].append(f"Abnormal {test}: {result}")

        # De-duplicate preserving first-seen order (set() ordering varies between runs).
        for key in ("symptoms", "findings", "medical_history"):
            integrated_data[key] = list(dict.fromkeys(integrated_data[key]))

        return integrated_data

    def _enhance_results_with_sources(self, diagnosis_results: Dict[str, Any],
                                     image_analysis: Optional[Dict[str, Any]],
                                     text_analysis: Optional[Dict[str, Any]],
                                     lab_results: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Enhance diagnostic results with information about data sources.

        Args:
            diagnosis_results: Results from diagnostic inference (not modified).
            image_analysis: Results from medical image analysis.
            text_analysis: Results from clinical text analysis.
            lab_results: Laboratory test results.

        Returns:
            A new results dict with a "data_sources" list and, when a primary
            diagnosis is present, a copy of it carrying "confidence_breakdown".
        """
        enhanced_results = dict(diagnosis_results)

        sources = []
        if image_analysis:
            sources.append({
                "type": "image_analysis",
                "modality": image_analysis.get("modality", "unknown"),
                "body_part": image_analysis.get("body_part", "unknown"),
                "confidence": image_analysis.get("confidence", 1.0),
                "weight": self._config_weight("image_analysis_weight"),
            })

        if text_analysis:
            sources.append({
                "type": "text_analysis",
                "document_type": text_analysis.get("document_type", "unknown"),
                "confidence": text_analysis.get("confidence", 1.0),
                "weight": self._config_weight("text_analysis_weight"),
            })

        if lab_results:
            sources.append({
                "type": "lab_results",
                "test_count": len(lab_results.get("results", {})),
                "abnormal_count": len(lab_results.get("abnormal_results", {})),
                "weight": self._config_weight("lab_results_weight"),
            })

        enhanced_results["data_sources"] = sources

        primary_diagnosis = enhanced_results.get("primary_diagnosis")
        if primary_diagnosis:
            # Copy the nested dict: the shallow copy above still shares it with the caller.
            primary_diagnosis = dict(primary_diagnosis)
            primary_diagnosis["confidence_breakdown"] = self._calculate_confidence_breakdown(
                primary_diagnosis, image_analysis, text_analysis, lab_results
            )
            enhanced_results["primary_diagnosis"] = primary_diagnosis

        return enhanced_results

    def _calculate_confidence_breakdown(self, diagnosis: Dict[str, Any],
                                      image_analysis: Optional[Dict[str, Any]],
                                      text_analysis: Optional[Dict[str, Any]],
                                      lab_results: Optional[Dict[str, Any]]) -> Dict[str, float]:
        """
        Calculate confidence breakdown by data source.

        Heuristics (each raw score is multiplied by the source's configured weight):
        image 0.8 if the disease name occurs in a finding, else 0.5 if any word
        of it occurs in an abnormality description; text 0.9 if a disease entity
        equals the name, else min(0.7, 0.3 + 0.1 * symptom entities); labs
        min(0.6, 0.2 + 0.1 * abnormal results).

        Args:
            diagnosis: Diagnosis to calculate confidence breakdown for.
            image_analysis: Results from medical image analysis.
            text_analysis: Results from clinical text analysis.
            lab_results: Laboratory test results.

        Returns:
            Dictionary with confidence contribution from each source provided.
        """
        # A missing/empty disease name must not match everything via `"" in text`.
        disease = str(diagnosis.get("disease") or "").lower()
        confidence_breakdown = {}

        if image_analysis:
            image_confidence = 0.0
            if disease:
                findings = image_analysis.get("findings") or []
                if any(disease in str(finding).lower() for finding in findings):
                    image_confidence = 0.8  # explicitly mentioned
                else:
                    # Simplified relatedness check; a real system would consult the knowledge graph.
                    terms = disease.split()
                    for abnormality in image_analysis.get("abnormalities") or []:
                        if "description" in abnormality and any(
                            term in str(abnormality["description"]).lower() for term in terms
                        ):
                            image_confidence = 0.5  # related abnormality
                            break

            confidence_breakdown["image_analysis"] = image_confidence * self._config_weight("image_analysis_weight")

        if text_analysis:
            text_confidence = 0.0
            entities = text_analysis.get("entities") or []

            explicitly_mentioned = bool(disease) and any(
                (entity.get("label") or "").lower() == "disease"
                and str(entity.get("text", "")).lower() == disease
                for entity in entities
            )
            if explicitly_mentioned:
                text_confidence = 0.9
            else:
                # Simplified: every symptom counts, related or not.
                symptom_count = sum(
                    1 for entity in entities if (entity.get("label") or "").lower() == "symptom"
                )
                if symptom_count > 0:
                    text_confidence = min(0.7, 0.3 + 0.1 * symptom_count)

            confidence_breakdown["text_analysis"] = text_confidence * self._config_weight("text_analysis_weight")

        if lab_results:
            lab_confidence = 0.0
            # Simplified: every abnormal result counts, related or not.
            abnormal_count = len(lab_results.get("abnormal_results") or {})
            if abnormal_count > 0:
                lab_confidence = min(0.6, 0.2 + 0.1 * abnormal_count)

            confidence_breakdown["lab_results"] = lab_confidence * self._config_weight("lab_results_weight")

        return confidence_breakdown

    # ----------------------------------------------------------------- reports

    def generate_diagnostic_report(self, diagnosis_results: Dict[str, Any],
                                  patient_data: Dict[str, Any],
                                  format: str = "text") -> str:
        """
        Generate a diagnostic report based on the integrated results.

        Args:
            diagnosis_results: Results from diagnostic integration.
            patient_data: Patient data; its "demographics" (name, age, gender) are shown.
            format: Report format ("text", "html", or "json"); any other value
                    yields the text report. (Name shadows the builtin but is kept
                    for keyword-argument compatibility.)

        Returns:
            Diagnostic report in the specified format.
        """
        if format == "json":
            return json.dumps(diagnosis_results, indent=2)
        if format == "html":
            return self._generate_html_report(diagnosis_results, patient_data)
        return self._generate_text_report(diagnosis_results, patient_data)

    @staticmethod
    def _patient_summary(patient_data: Dict[str, Any]) -> Tuple[Any, Any, Any]:
        """Return (name, age, gender) from the patient's demographics, defaulting to "Unknown"."""
        demographics = patient_data.get("demographics") or {}
        return (
            demographics.get("name", "Unknown"),
            demographics.get("age", "Unknown"),
            demographics.get("gender", "Unknown"),
        )

    @staticmethod
    def _as_float(value: Any) -> float:
        """Coerce a confidence/weight to float for formatting; non-numeric values become 0.0."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _confidence_class(confidence: float) -> str:
        """Map a confidence score to the CSS class used to colour it in the HTML report."""
        if confidence >= 0.75:
            return "high"
        if confidence >= 0.5:
            return "moderate"
        return "low"

    @staticmethod
    def _format_source_details(source: Dict[str, Any]) -> str:
        """Render a data source's fields other than type/weight as 'key: value' pairs."""
        return ", ".join(f"{k}: {v}" for k, v in source.items() if k not in ("type", "weight"))

    @staticmethod
    def _escape_html(value: Any) -> str:
        """Convert ``value`` to text and escape it for safe inclusion in HTML."""
        return html.escape(str(value))

    def _html_evidence_items(self, disease_evidence: List[Dict[str, Any]]) -> str:
        """Render evidence entries as escaped ``<li>item: type</li>`` elements."""
        esc = self._escape_html
        return "".join(
            f"<li>{esc(ev.get('item', ''))}: {esc(ev.get('type', ''))}</li>" for ev in disease_evidence
        )

    def _generate_html_report(self, diagnosis_results: Dict[str, Any], patient_data: Dict[str, Any]) -> str:
        """
        Render the report as a standalone HTML page.

        All values taken from the inputs are HTML-escaped: names, diagnoses and
        evidence can originate from free-text clinical analysis and must not be
        able to inject markup or scripts into the page.
        """
        esc = self._escape_html
        patient_name, patient_age, patient_gender = self._patient_summary(patient_data)
        primary_diagnosis = diagnosis_results.get("primary_diagnosis") or {}
        differential_diagnoses = diagnosis_results.get("differential_diagnoses") or []
        evidence = diagnosis_results.get("evidence") or {}
        explanation = diagnosis_results.get("explanation") or {}
        summary = esc(explanation.get("summary", "No diagnostic summary available."))

        parts = [f"""<!DOCTYPE html>
<html>
<head>
    <title>Diagnostic Report for {esc(patient_name)}</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        h1, h2, h3 {{ color: #2c3e50; }}
        .section {{ margin-bottom: 20px; }}
        .diagnosis {{ background-color: #f8f9fa; padding: 10px; border-radius: 5px; }}
        .primary {{ border-left: 5px solid #27ae60; }}
        .differential {{ border-left: 5px solid #3498db; }}
        .evidence {{ margin-left: 20px; }}
        .confidence {{ font-weight: bold; }}
        .high {{ color: #27ae60; }}
        .moderate {{ color: #f39c12; }}
        .low {{ color: #e74c3c; }}
        table {{ border-collapse: collapse; width: 100%; }}
        th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
        th {{ background-color: #f2f2f2; }}
    </style>
</head>
<body>
    <h1>Diagnostic Report</h1>
    
    <div class="section">
        <h2>Patient Information</h2>
        <table>
            <tr><th>Name</th><td>{esc(patient_name)}</td></tr>
            <tr><th>Age</th><td>{esc(patient_age)}</td></tr>
            <tr><th>Gender</th><td>{esc(patient_gender)}</td></tr>
        </table>
    </div>
    
    <div class="section">
        <h2>Diagnostic Summary</h2>
        <p>{summary}</p>
    </div>
"""]

        if primary_diagnosis:
            disease = primary_diagnosis.get("disease", "Unknown")
            confidence = self._as_float(primary_diagnosis.get("confidence", 0))
            primary_explanation = esc(explanation.get("primary_diagnosis_explanation", ""))
            parts.append(f"""
    <div class="section">
        <h2>Primary Diagnosis</h2>
        <div class="diagnosis primary">
            <h3>{esc(disease)}</h3>
            <p class="confidence">Confidence: <span class="{self._confidence_class(confidence)}">{confidence:.2f}</span></p>
            <p>{primary_explanation}</p>
            
            <h4>Supporting Evidence:</h4>
            <div class="evidence">
""")
            disease_evidence = evidence.get(disease) or []
            if disease_evidence:
                parts.append("<ul>" + self._html_evidence_items(disease_evidence) + "</ul>")
            else:
                parts.append("<p>No specific evidence available.</p>")
            parts.append("""
            </div>
        </div>
    </div>
""")

        if differential_diagnoses:
            parts.append("""
    <div class="section">
        <h2>Differential Diagnoses</h2>
""")
            for diagnosis in differential_diagnoses:
                disease = diagnosis.get("disease", "Unknown")
                confidence = self._as_float(diagnosis.get("confidence", 0))
                parts.append(f"""
        <div class="diagnosis differential">
            <h3>{esc(disease)}</h3>
            <p class="confidence">Confidence: <span class="{self._confidence_class(confidence)}">{confidence:.2f}</span></p>
""")
                disease_evidence = evidence.get(disease) or []
                if disease_evidence:
                    parts.append('<h4>Supporting Evidence:</h4><div class="evidence"><ul>'
                                 + self._html_evidence_items(disease_evidence) + "</ul></div>")
                parts.append("""
        </div>
""")
            parts.append("""
    </div>
""")

        data_sources = diagnosis_results.get("data_sources") or []
        if data_sources:
            parts.append("""
    <div class="section">
        <h2>Data Sources</h2>
        <table>
            <tr>
                <th>Source Type</th>
                <th>Details</th>
                <th>Weight</th>
            </tr>
""")
            for source in data_sources:
                source_type = esc(source.get("type", "Unknown"))
                details = esc(self._format_source_details(source))
                weight = self._as_float(source.get("weight", 0))
                parts.append(f"""
            <tr>
                <td>{source_type}</td>
                <td>{details}</td>
                <td>{weight:.2f}</td>
            </tr>
""")
            parts.append("""
        </table>
    </div>
""")

        parts.append("""
    <div class="section">
        <p><em>This report was generated by the Medical Diagnostic System. It is intended to assist healthcare professionals and should not replace clinical judgment.</em></p>
    </div>
</body>
</html>
""")
        return "".join(parts)

    def _generate_text_report(self, diagnosis_results: Dict[str, Any], patient_data: Dict[str, Any]) -> str:
        """Render the report as plain text."""
        patient_name, patient_age, patient_gender = self._patient_summary(patient_data)
        primary_diagnosis = diagnosis_results.get("primary_diagnosis") or {}
        differential_diagnoses = diagnosis_results.get("differential_diagnoses") or []
        evidence = diagnosis_results.get("evidence") or {}
        explanation = diagnosis_results.get("explanation") or {}
        summary = explanation.get("summary", "No diagnostic summary available.")

        def evidence_lines(items: List[Dict[str, Any]]) -> List[str]:
            # One "- item: type" line per evidence entry.
            return [f"- {ev.get('item', '')}: {ev.get('type', '')}\n" for ev in items]

        parts = [f"""DIAGNOSTIC REPORT
================

PATIENT INFORMATION
------------------
Name: {patient_name}
Age: {patient_age}
Gender: {patient_gender}

DIAGNOSTIC SUMMARY
-----------------
{summary}

"""]

        if primary_diagnosis:
            disease = primary_diagnosis.get("disease", "Unknown")
            confidence = self._as_float(primary_diagnosis.get("confidence", 0))
            primary_explanation = explanation.get("primary_diagnosis_explanation", "")
            parts.append(f"""PRIMARY DIAGNOSIS
----------------
{disease} (Confidence: {confidence:.2f})

{primary_explanation}

Supporting Evidence:
""")
            disease_evidence = evidence.get(disease) or []
            if disease_evidence:
                parts.extend(evidence_lines(disease_evidence))
            else:
                parts.append("No specific evidence available.\n")
            parts.append("\n")

        if differential_diagnoses:
            parts.append("""DIFFERENTIAL DIAGNOSES
----------------------
""")
            for diagnosis in differential_diagnoses:
                disease = diagnosis.get("disease", "Unknown")
                confidence = self._as_float(diagnosis.get("confidence", 0))
                parts.append(f"{disease} (Confidence: {confidence:.2f})\n")

                disease_evidence = evidence.get(disease) or []
                if disease_evidence:
                    parts.append("Supporting Evidence:\n")
                    parts.extend(evidence_lines(disease_evidence))
                parts.append("\n")

        data_sources = diagnosis_results.get("data_sources") or []
        if data_sources:
            parts.append("""DATA SOURCES
------------
""")
            for source in data_sources:
                source_type = source.get("type", "Unknown")
                details = self._format_source_details(source)
                weight = self._as_float(source.get("weight", 0))
                parts.append(f"{source_type} (Weight: {weight:.2f}) - {details}\n")
            parts.append("\n")

        parts.append("""NOTE: This report was generated by the Medical Diagnostic System. It is intended to assist healthcare professionals and should not replace clinical judgment.
""")
        return "".join(parts)

modules\diagnosis_engine\knowledge_graph.py

"""
Medical Knowledge Graph Module

This module provides functionality for building and querying a medical knowledge graph
that serves as the foundation for the diagnostic reasoning engine.
"""

import os
import json
import logging
from typing import Dict, List, Tuple, Any, Optional, Set, Union
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path

# Module logging setup.
# NOTE(review): basicConfig() configures the root logger as an import-time side
# effect; consider moving it to the application entry point.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

class MedicalKnowledgeGraph:
    """
    Medical Knowledge Graph for storing and querying medical knowledge.

    This class provides methods for building a knowledge graph from various sources,
    including medical ontologies, clinical guidelines, and extracted relations from
    clinical text analysis.

    Storage is a ``networkx.DiGraph``. Each node carries ``id``, ``type`` and
    ``visual`` attributes plus caller-supplied properties. Each edge carries
    ``type``, ``confidence`` and ``visual`` plus caller-supplied properties.
    Because the graph is a plain ``DiGraph``, at most one relation is stored per
    ordered (source, target) pair; adding another one updates that edge.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the Medical Knowledge Graph.

        Args:
            config_path: Path to the configuration file for the knowledge graph.
                         If None (or missing), the project default file is tried,
                         then the built-in defaults.
        """
        self.graph = nx.DiGraph()
        self.config = self._load_config(config_path)
        # Known node types and their drawing styles (matplotlib color / marker).
        self.node_types = {
            "disease": {"color": "red", "shape": "o"},
            "symptom": {"color": "blue", "shape": "s"},
            "medication": {"color": "green", "shape": "^"},
            "procedure": {"color": "purple", "shape": "d"},
            "lab_test": {"color": "orange", "shape": "p"},
            "anatomy": {"color": "brown", "shape": "h"},
            "risk_factor": {"color": "pink", "shape": "*"}
        }
        # Known relation types and their drawing styles (color / line style).
        self.relation_types = {
            "causes": {"color": "red", "style": "solid"},
            "treats": {"color": "green", "style": "solid"},
            "diagnoses": {"color": "blue", "style": "dashed"},
            "contraindicates": {"color": "orange", "style": "dotted"},
            "has_symptom": {"color": "purple", "style": "solid"},
            "located_in": {"color": "brown", "style": "dashed"},
            "increases_risk": {"color": "pink", "style": "dotted"}
        }
        logger.info("Medical Knowledge Graph initialized")

    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file or fall back to the default configuration.

        Lookup order: ``config_path`` (if given and it exists), then
        ``config/diagnosis_engine_config.json`` three directories above this
        module, then the built-in defaults. The file's ``knowledge_graph``
        section is layered over the defaults. Keys missing from the file
        (notably ``relation_confidence_threshold``) therefore keep their
        default values instead of disappearing.

        Args:
            config_path: Path to the configuration file.

        Returns:
            Dictionary containing configuration parameters.
        """
        default_config = {
            "database_type": "in_memory",
            "external_sources": ["umls", "snomed_ct", "icd10"],
            "relation_confidence_threshold": 0.7
        }

        explicit = bool(config_path) and os.path.exists(config_path)
        if explicit:
            path = Path(config_path)
        else:
            # Three levels up from this module; presumably the project root.
            path = Path(__file__).parent.parent.parent / "config" / "diagnosis_engine_config.json"

        if explicit or path.exists():
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    section = json.load(f).get("knowledge_graph", {})
                if not isinstance(section, dict):
                    raise ValueError("'knowledge_graph' section must be a JSON object")
                logger.info(f"Loaded configuration from {path}")
                return {**default_config, **section}
            except Exception as e:
                label = "configuration" if explicit else "default configuration"
                logger.error(f"Error loading {label}: {e}")

        logger.info("Using default knowledge graph configuration")
        return default_config

    def add_node(self, node_id: str, node_type: str, properties: Dict[str, Any] = None) -> None:
        """
        Add a node to the knowledge graph (or update an existing node's attributes).

        Args:
            node_id: Unique identifier for the node.
            node_type: Type of the node (e.g., disease, symptom, medication).
                       Types not in ``self.node_types`` get a gray circle style.
            properties: Additional properties of the node. They are merged last,
                        so a property named ``id``/``type``/``visual`` overrides
                        the corresponding core attribute.
        """
        if properties is None:
            properties = {}

        if node_type not in self.node_types:
            logger.warning(f"Unknown node type: {node_type}, using default properties")
            node_visual = {"color": "gray", "shape": "o"}
        else:
            node_visual = self.node_types[node_type]

        node_data = {
            "id": node_id,
            "type": node_type,
            "visual": node_visual,
            **properties
        }

        self.graph.add_node(node_id, **node_data)
        logger.debug(f"Added node: {node_id} of type {node_type}")

    def add_relation(self, source_id: str, target_id: str, relation_type: str,
                    confidence: float = 1.0, properties: Dict[str, Any] = None) -> None:
        """
        Add a relation between two nodes in the knowledge graph.

        Relations whose confidence is below ``relation_confidence_threshold``
        are rejected *before* anything is created. Rejected relations therefore
        never leave placeholder nodes behind. For accepted relations, missing
        endpoints are created as placeholder nodes of type ``"unknown"``.

        Args:
            source_id: Identifier of the source node.
            target_id: Identifier of the target node.
            relation_type: Type of the relation (e.g., causes, treats).
            confidence: Confidence score for the relation (0.0 to 1.0).
            properties: Additional properties of the relation.
        """
        if properties is None:
            properties = {}

        # Threshold check first, so rejected relations have no side effects.
        if confidence < self.config.get("relation_confidence_threshold", 0.0):
            logger.debug(f"Relation confidence {confidence} below threshold, not adding relation")
            return

        if source_id not in self.graph:
            logger.warning(f"Source node {source_id} does not exist, creating placeholder")
            self.add_node(source_id, "unknown")

        if target_id not in self.graph:
            logger.warning(f"Target node {target_id} does not exist, creating placeholder")
            self.add_node(target_id, "unknown")

        if relation_type not in self.relation_types:
            logger.warning(f"Unknown relation type: {relation_type}, using default properties")
            relation_visual = {"color": "gray", "style": "solid"}
        else:
            relation_visual = self.relation_types[relation_type]

        relation_data = {
            "type": relation_type,
            "confidence": confidence,
            "visual": relation_visual,
            **properties
        }

        self.graph.add_edge(source_id, target_id, **relation_data)
        logger.debug(f"Added relation: {source_id} --[{relation_type}]--> {target_id} (conf: {confidence})")

    def import_from_csv(self, nodes_file: str, relations_file: str) -> None:
        """
        Import knowledge graph from CSV files.

        The nodes file needs ``id`` and ``type`` columns. The relations file
        needs ``source_id``, ``target_id`` and ``type``, and optionally
        ``confidence``. Any other column becomes a property. Empty cells (read
        as NaN by pandas) are skipped, so they never show up as ``nan``
        properties. A missing confidence defaults to 1.0. Rows without an id /
        endpoint are skipped with a warning.

        Errors are logged and the import stops; they are not raised.

        Args:
            nodes_file: Path to CSV file containing node data.
            relations_file: Path to CSV file containing relation data.
        """
        try:
            nodes_df = pd.read_csv(nodes_file)
            for _, row in nodes_df.iterrows():
                node_id = row['id']
                if pd.isna(node_id):
                    logger.warning("Skipping node row without an id")
                    continue
                properties = {col: row[col] for col in nodes_df.columns
                              if col not in ['id', 'type'] and not pd.isna(row[col])}
                self.add_node(node_id, row['type'], properties)

            relations_df = pd.read_csv(relations_file)
            for _, row in relations_df.iterrows():
                source_id = row['source_id']
                target_id = row['target_id']
                if pd.isna(source_id) or pd.isna(target_id):
                    logger.warning("Skipping relation row without source_id/target_id")
                    continue
                raw_confidence = row.get('confidence', 1.0)
                confidence = 1.0 if pd.isna(raw_confidence) else float(raw_confidence)
                properties = {col: row[col] for col in relations_df.columns
                              if col not in ['source_id', 'target_id', 'type', 'confidence']
                              and not pd.isna(row[col])}
                self.add_relation(source_id, target_id, row['type'], confidence, properties)

            logger.info(f"Imported {len(nodes_df)} nodes and {len(relations_df)} relations from CSV")
        except Exception as e:
            logger.error(f"Error importing from CSV: {e}")

    def import_from_text_analysis(self, analysis_results: Dict[str, Any]) -> None:
        """
        Import knowledge from clinical text analysis results.

        Expects ``analysis_results["entities"]`` (dicts with ``id`` or ``text``,
        optional ``label`` and ``confidence``) and ``analysis_results["relations"]``
        (dicts with ``entity1``/``entity2`` entity dicts, ``relation_type`` and
        ``confidence``). Entities or relations without a usable identifier are
        skipped. A ``None`` confidence is treated as 1.0.

        Errors are logged and the import stops; they are not raised.

        Args:
            analysis_results: Results from clinical text analysis containing
                             entities and relations.
        """
        def entity_key(entity: Dict[str, Any]) -> Any:
            # Prefer an explicit id, fall back to the surface text.
            return entity.get("id", entity.get("text", ""))

        try:
            entities = analysis_results.get("entities", [])
            for entity in entities:
                entity_id = entity_key(entity)
                if entity_id is None or entity_id == "":
                    logger.warning("Skipping entity without id or text")
                    continue
                properties = {
                    "text": entity.get("text", ""),
                    "source": "text_analysis",
                    "confidence": entity.get("confidence", 1.0)
                }
                self.add_node(entity_id, entity.get("label", "unknown"), properties)

            relations = analysis_results.get("relations", [])
            for relation in relations:
                source_id = entity_key(relation.get("entity1", {}))
                target_id = entity_key(relation.get("entity2", {}))
                if source_id in (None, "") or target_id in (None, ""):
                    logger.warning("Skipping relation with missing entity identifiers")
                    continue
                confidence = relation.get("confidence", 1.0)
                if confidence is None:
                    confidence = 1.0
                self.add_relation(source_id, target_id,
                                  relation.get("relation_type", "unknown"),
                                  confidence, {"source": "text_analysis"})

            logger.info(f"Imported {len(entities)} entities and {len(relations)} relations from text analysis")
        except Exception as e:
            logger.error(f"Error importing from text analysis: {e}")

    def _iter_neighbor_edges(self, node_id: str, direction: str):
        """
        Yield ``(neighbor_id, direction_label, edge_data)`` for edges of ``node_id``.

        Outgoing edges (``direction`` "outgoing"/"both") are yielded before
        incoming ones (``direction`` "incoming"/"both").
        """
        if direction in ["outgoing", "both"]:
            for target_id in self.graph.successors(node_id):
                yield target_id, "outgoing", self.graph.get_edge_data(node_id, target_id)
        if direction in ["incoming", "both"]:
            for source_id in self.graph.predecessors(node_id):
                yield source_id, "incoming", self.graph.get_edge_data(source_id, node_id)

    def _node_summary(self, node_id: str) -> Dict[str, Any]:
        """Return ``{"id", "type", "properties"}`` for a node; ``properties`` excludes id/type."""
        node_data = self.graph.nodes[node_id]
        return {
            "id": node_id,
            "type": node_data.get("type"),
            "properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]}
        }

    def query_related_nodes(self, node_id: str, relation_type: Optional[str] = None,
                           direction: str = "outgoing", max_depth: int = 1) -> List[Dict[str, Any]]:
        """
        Query nodes related to a given node, up to ``max_depth`` hops away.

        Depth 1 lists every direct relation of ``node_id``: outgoing first,
        then incoming. A neighbor linked in both directions therefore appears
        twice under ``direction="both"``. Deeper levels expand from the nodes
        reached at the previous level and report only nodes not reached at a
        shallower depth. The ``direction`` and ``relation_type`` filters apply
        at every hop. Values of ``max_depth`` below 1 are treated as 1.

        Args:
            node_id: Identifier of the node to query.
            relation_type: Type of relation to filter by (optional).
            direction: Direction of relations ('outgoing', 'incoming', or 'both').
            max_depth: Maximum depth of traversal.

        Returns:
            List of related nodes. Each entry has ``id``, ``type``, ``relation``
            (type/direction/confidence of the traversed edge), ``properties``,
            ``depth`` (hop count) and ``via`` (the node the edge was traversed
            from; ``node_id`` itself at depth 1).
        """
        if node_id not in self.graph:
            logger.warning(f"Node {node_id} not found in knowledge graph")
            return []

        related_nodes = []
        visited = {node_id}
        frontier = [node_id]

        for depth in range(1, max(1, max_depth) + 1):
            next_frontier = []
            for current_id in frontier:
                for neighbor_id, rel_direction, edge_data in self._iter_neighbor_edges(current_id, direction):
                    if relation_type is not None and edge_data.get("type") != relation_type:
                        continue
                    # Beyond the first hop, only report newly reached nodes.
                    if depth > 1 and neighbor_id in visited:
                        continue
                    node_data = self.graph.nodes[neighbor_id]
                    related_nodes.append({
                        "id": neighbor_id,
                        "type": node_data.get("type"),
                        "relation": {
                            "type": edge_data.get("type"),
                            "direction": rel_direction,
                            "confidence": edge_data.get("confidence", 1.0)
                        },
                        "properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]},
                        "depth": depth,
                        "via": current_id
                    })
                    if neighbor_id not in visited:
                        visited.add(neighbor_id)
                        next_frontier.append(neighbor_id)
            if not next_frontier:
                break
            frontier = next_frontier

        return related_nodes

    def find_path(self, source_id: str, target_id: str, max_length: int = 5) -> List[List[Dict[str, Any]]]:
        """
        Find all simple directed paths between two nodes.

        Args:
            source_id: Identifier of the source node.
            target_id: Identifier of the target node.
            max_length: Maximum path length (number of edges) to consider.

        Returns:
            List of paths. Each path is a list of steps
            ``{"node": {...}, "relation": {...}}``, where ``relation`` describes
            the edge to the next step and is ``None`` on the final node.
            Returns an empty list if either node is missing or on error.
        """
        if source_id not in self.graph or target_id not in self.graph:
            logger.warning("Source or target node not found in knowledge graph")
            return []

        try:
            formatted_paths = []
            for path in nx.all_simple_paths(self.graph, source_id, target_id, cutoff=max_length):
                formatted_path = []
                for current_node, next_node in zip(path, path[1:]):
                    edge_data = self.graph.get_edge_data(current_node, next_node)
                    formatted_path.append({
                        "node": self._node_summary(current_node),
                        "relation": {
                            "type": edge_data.get("type"),
                            "confidence": edge_data.get("confidence", 1.0),
                            "properties": {k: v for k, v in edge_data.items()
                                           if k not in ["type", "confidence"]}
                        }
                    })

                formatted_path.append({"node": self._node_summary(path[-1]), "relation": None})
                formatted_paths.append(formatted_path)

            return formatted_paths
        except Exception as e:
            logger.error(f"Error finding path: {e}")
            return []

    def get_node_info(self, node_id: str) -> Dict[str, Any]:
        """
        Get detailed information about a node and its direct relations.

        Args:
            node_id: Identifier of the node.

        Returns:
            Dictionary with ``id``, ``type``, ``properties``, ``outgoing_relations``
            and ``incoming_relations``. Neighbor names fall back to the neighbor
            id when no ``name`` property exists. Returns ``{}`` if the node is
            unknown.
        """
        if node_id not in self.graph:
            logger.warning(f"Node {node_id} not found in knowledge graph")
            return {}

        node_data = self.graph.nodes[node_id]

        outgoing_relations = []
        for target_id in self.graph.successors(node_id):
            edge_data = self.graph.get_edge_data(node_id, target_id)
            target_data = self.graph.nodes[target_id]
            outgoing_relations.append({
                "target": {
                    "id": target_id,
                    "type": target_data.get("type"),
                    "name": target_data.get("name", target_id)
                },
                "type": edge_data.get("type"),
                "confidence": edge_data.get("confidence", 1.0)
            })

        incoming_relations = []
        for source_id in self.graph.predecessors(node_id):
            edge_data = self.graph.get_edge_data(source_id, node_id)
            source_data = self.graph.nodes[source_id]
            incoming_relations.append({
                "source": {
                    "id": source_id,
                    "type": source_data.get("type"),
                    "name": source_data.get("name", source_id)
                },
                "type": edge_data.get("type"),
                "confidence": edge_data.get("confidence", 1.0)
            })

        return {
            "id": node_id,
            "type": node_data.get("type"),
            "properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]},
            "outgoing_relations": outgoing_relations,
            "incoming_relations": incoming_relations
        }

    def find_nodes_by_property(self, property_name: str, property_value: Any) -> List[Dict[str, Any]]:
        """
        Find nodes whose attribute ``property_name`` equals ``property_value``.

        Args:
            property_name: Name of the property to search for.
            property_value: Value of the property to match (compared with ``==``).

        Returns:
            List of ``{"id", "type", "properties"}`` dicts for matching nodes.
        """
        return [
            self._node_summary(node_id)
            for node_id, node_data in self.graph.nodes(data=True)
            if property_name in node_data and node_data[property_name] == property_value
        ]

    def get_subgraph(self, node_ids: List[str], include_relations: bool = True) -> nx.DiGraph:
        """
        Extract a subgraph containing specified nodes and, optionally, their relations.

        Ids not present in the graph are ignored.

        Args:
            node_ids: List of node identifiers to include in the subgraph.
            include_relations: Whether to include relations between the nodes.

        Returns:
            An independent NetworkX DiGraph (a copy, not a view). It is empty if
            none of the ids exist.
        """
        valid_nodes = [node_id for node_id in node_ids if node_id in self.graph]

        if not valid_nodes:
            logger.warning("No valid nodes provided for subgraph extraction")
            return nx.DiGraph()

        if include_relations:
            return self.graph.subgraph(valid_nodes).copy()

        subgraph = nx.DiGraph()
        for node_id in valid_nodes:
            subgraph.add_node(node_id, **self.graph.nodes[node_id])
        return subgraph

    def visualize(self, node_ids: Optional[List[str]] = None,
                 filename: Optional[str] = None,
                 show_labels: bool = True,
                 figsize: Tuple[int, int] = (12, 10)) -> None:
        """
        Visualize the knowledge graph or a subgraph with matplotlib.

        Nodes and edges are styled by type, and both a node-type legend and a
        relation-type legend are drawn outside the axes. In non-interactive
        matplotlib mode the figure is closed after ``plt.show()`` so repeated
        calls do not accumulate open figures.

        Args:
            node_ids: List of node identifiers to visualize (optional).
            filename: Path to save the visualization (optional).
            show_labels: Whether to show node labels (``name`` property, else id).
            figsize: Figure size as (width, height) in inches.
        """
        if node_ids:
            graph_to_viz = self.get_subgraph(node_ids)
            # Count nodes actually drawn; unknown ids are dropped by get_subgraph.
            title = f"Medical Knowledge Subgraph ({len(graph_to_viz.nodes)} nodes)"
        else:
            graph_to_viz = self.graph
            title = f"Medical Knowledge Graph ({len(self.graph.nodes)} nodes, {len(self.graph.edges)} edges)"

        if len(graph_to_viz.nodes) == 0:
            logger.warning("No nodes to visualize")
            return

        fig = plt.figure(figsize=figsize)
        ax = fig.gca()
        ax.set_title(title)

        # Fixed seed so the layout is reproducible between calls.
        pos = nx.spring_layout(graph_to_viz, seed=42)
        node_size = 500

        for node_type, style in self.node_types.items():
            node_list = [n for n, data in graph_to_viz.nodes(data=True) if data.get("type") == node_type]
            if node_list:
                nx.draw_networkx_nodes(
                    graph_to_viz, pos,
                    nodelist=node_list,
                    node_color=style["color"],
                    node_shape=style["shape"],
                    node_size=node_size,
                    alpha=0.8,
                    ax=ax
                )

        unknown_nodes = [n for n, data in graph_to_viz.nodes(data=True)
                         if data.get("type") not in self.node_types]
        if unknown_nodes:
            nx.draw_networkx_nodes(
                graph_to_viz, pos,
                nodelist=unknown_nodes,
                node_color="gray",
                node_shape="o",
                node_size=node_size,
                alpha=0.8,
                ax=ax
            )

        # node_size is passed so arrowheads stop at the node boundary.
        for rel_type, style in self.relation_types.items():
            edge_list = [(u, v) for u, v, data in graph_to_viz.edges(data=True)
                         if data.get("type") == rel_type]
            if edge_list:
                nx.draw_networkx_edges(
                    graph_to_viz, pos,
                    edgelist=edge_list,
                    edge_color=style["color"],
                    style=style["style"],
                    alpha=0.7,
                    arrowsize=15,
                    node_size=node_size,
                    ax=ax
                )

        unknown_edges = [(u, v) for u, v, data in graph_to_viz.edges(data=True)
                         if data.get("type") not in self.relation_types]
        if unknown_edges:
            nx.draw_networkx_edges(
                graph_to_viz, pos,
                edgelist=unknown_edges,
                edge_color="gray",
                style="solid",
                alpha=0.7,
                arrowsize=15,
                node_size=node_size,
                ax=ax
            )

        if show_labels:
            labels = {n: data.get("name", n) for n, data in graph_to_viz.nodes(data=True)}
            nx.draw_networkx_labels(graph_to_viz, pos, labels, font_size=10, ax=ax)

        node_handles = [
            plt.Line2D([0], [0], marker=style["shape"], color='w',
                       markerfacecolor=style["color"], markersize=10)
            for style in self.node_types.values()
        ]
        node_labels = list(self.node_types)

        # Anything that is neither "solid" nor "dashed" is drawn dotted.
        linestyles = {"solid": "-", "dashed": "--"}
        edge_handles = [
            plt.Line2D([0], [1], color=style["color"], linestyle=linestyles.get(style["style"], ":"))
            for style in self.relation_types.values()
        ]
        edge_labels = list(self.relation_types)

        node_legend = ax.legend(node_handles, node_labels, title="Node Types",
                                loc="upper left", bbox_to_anchor=(1, 1))
        # A second legend() call replaces the first, so re-attach the node legend
        # as a plain artist to keep both on the axes.
        ax.add_artist(node_legend)
        ax.legend(edge_handles, edge_labels, title="Relation Types",
                  loc="lower left", bbox_to_anchor=(1, 0))

        ax.axis('off')
        fig.tight_layout()

        if filename:
            fig.savefig(filename, bbox_inches='tight')
            logger.info(f"Visualization saved to {filename}")

        plt.show()
        if not plt.isinteractive():
            plt.close(fig)

    def export_to_csv(self, nodes_file: str, relations_file: str) -> None:
        """
        Export the knowledge graph to CSV files (the ``visual`` styling is omitted).

        Errors are logged, not raised.

        Args:
            nodes_file: Path to save the nodes CSV file.
            relations_file: Path to save the relations CSV file.
        """
        try:
            nodes_data = []
            for node_id, node_data in self.graph.nodes(data=True):
                node_dict = {"id": node_id, "type": node_data.get("type", "unknown")}
                for k, v in node_data.items():
                    if k not in ["id", "type", "visual"]:
                        node_dict[k] = v
                nodes_data.append(node_dict)

            pd.DataFrame(nodes_data).to_csv(nodes_file, index=False)

            relations_data = []
            for source_id, target_id, edge_data in self.graph.edges(data=True):
                relation_dict = {
                    "source_id": source_id,
                    "target_id": target_id,
                    "type": edge_data.get("type", "unknown"),
                    "confidence": edge_data.get("confidence", 1.0)
                }
                for k, v in edge_data.items():
                    if k not in ["type", "confidence", "visual"]:
                        relation_dict[k] = v
                relations_data.append(relation_dict)

            pd.DataFrame(relations_data).to_csv(relations_file, index=False)

            logger.info(f"Exported {len(nodes_data)} nodes to {nodes_file}")
            logger.info(f"Exported {len(relations_data)} relations to {relations_file}")
        except Exception as e:
            logger.error(f"Error exporting to CSV: {e}")

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get statistics about the knowledge graph.

        Returns:
            Dictionary with ``num_nodes``, ``num_edges``, per-type counts for
            nodes and relations, ``avg_degree`` (2 * edges / nodes, i.e. mean
            in+out degree; 0 for an empty graph) and ``density`` from
            ``nx.density``.
        """
        stats = {
            "num_nodes": len(self.graph.nodes),
            "num_edges": len(self.graph.edges),
            "node_types": {},
            "relation_types": {},
            "avg_degree": 0,
            "density": nx.density(self.graph)
        }

        for _, data in self.graph.nodes(data=True):
            node_type = data.get("type", "unknown")
            stats["node_types"][node_type] = stats["node_types"].get(node_type, 0) + 1

        for _, _, data in self.graph.edges(data=True):
            relation_type = data.get("type", "unknown")
            stats["relation_types"][relation_type] = stats["relation_types"].get(relation_type, 0) + 1

        if stats["num_nodes"] > 0:
            stats["avg_degree"] = 2 * stats["num_edges"] / stats["num_nodes"]

        return stats

    def clear(self) -> None:
        """
        Remove all nodes and relations from the knowledge graph.
        """
        self.graph.clear()
        logger.info("Knowledge graph cleared")

modules\image_analysis\image_analysis.py

"""
Medical Image Analysis Module Main Script

This script demonstrates how to use the image analysis module components
for processing, segmenting, and classifying medical images.
"""

import os
import logging
import argparse
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, Tuple

# Import image analysis modules
from modules.image_analysis.image_processor import (
    MedicalImageProcessor, XRayProcessor, CTProcessor, MRIProcessor
)
from modules.image_analysis.image_classifier import (
    MedicalImageClassifier, ChestXRayClassifier, BrainMRIClassifier
)
from modules.image_analysis.image_segmentation import (
    MedicalImageSegmenter, LungSegmenter, BrainSegmenter, OrganSegmenter, DeepLearningSegmenter
)

# Send log records both to "medical_image_analysis.log" (relative to the current
# working directory) and to the console, at INFO level and above.
_log_handlers = [
    logging.FileHandler("medical_image_analysis.log"),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

def _to_uint8_image(image: np.ndarray) -> np.ndarray:
    """
    Convert an image scaled to [0, 1] into a uint8 image in [0, 255].

    Values are clipped before the cast, so out-of-range pixels saturate at
    0/255 instead of wrapping around in uint8.
    NOTE(review): assumes processors return float images in [0, 1]; confirm.
    """
    return np.clip(np.asarray(image) * 255.0, 0, 255).astype(np.uint8)


def _mask_to_uint8(mask: np.ndarray) -> np.ndarray:
    """
    Render a binary or integer label mask as a uint8 grayscale image.

    Masks whose maximum is <= 1 (boolean / 0-1) map nonzero to 255. Masks
    with larger values (multi-label masks, or 0/255 masks) are scaled so the
    largest value becomes 255. This keeps distinct labels distinguishable
    instead of overflowing uint8 (e.g. label 2 * 255 wrapping to 254).
    """
    mask = np.asarray(mask)
    max_value = mask.max() if mask.size else 0
    if max_value <= 1:
        return mask.astype(np.uint8) * 255
    scaled = mask.astype(np.float64) / float(max_value) * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _to_json_safe(value: Any) -> Any:
    """Recursively convert NumPy scalars/arrays (inside dicts/lists/tuples) into JSON-serializable Python types."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, dict):
        return {k: _to_json_safe(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_to_json_safe(v) for v in value]
    if isinstance(value, tuple):
        return tuple(_to_json_safe(v) for v in value)
    return value


def analyze_medical_image(image_path: str, image_type: str, output_dir: str, 
                         config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Analyze a medical image: preprocess, segment, extract features, save results.

    Files written to ``output_dir`` (created if needed):
        - ``preprocessed.png``, ``segmentation_mask.png``, ``analysis_results.json``;
        - ``segmentation_result.png`` via the segmenter's ``visualize_segmentation``;
        - xray: ``enhanced.png``; mri: ``tumor_mask.png`` (only if a tumor region
          is found); ct: ``liver_mask.png``, ``kidney_mask.png``, ``bone_mask.png``.

    Args:
        image_path: Path to the medical image
        image_type: Type of medical image ('xray', 'ct', 'mri'); anything else
            uses the generic processor and threshold segmentation.
        output_dir: Directory to save results
        config_path: Optional path to configuration file

    Returns:
        Dict[str, Any]: Analysis results (``success``, paths, ``features``,
        ``regions``), or ``{"success": False, "error": "Failed to load image"}``
        if the image cannot be loaded.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Pick processor/segmenter pair by modality.
    if image_type == 'xray':
        processor = XRayProcessor(config_path)
        segmenter = LungSegmenter(config_path)
    elif image_type == 'ct':
        processor = CTProcessor(config_path)
        segmenter = OrganSegmenter(config_path)
    elif image_type == 'mri':
        processor = MRIProcessor(config_path)
        segmenter = BrainSegmenter(config_path)
    else:
        processor = MedicalImageProcessor(config_path)
        segmenter = MedicalImageSegmenter(config_path)

    logger.info(f"Loading image: {image_path}")
    image = processor.load_image(image_path)

    if image is None:
        logger.error(f"Failed to load image: {image_path}")
        return {"success": False, "error": "Failed to load image"}

    logger.info("Preprocessing image...")
    preprocessed = processor.preprocess(image)

    preprocessed_path = os.path.join(output_dir, "preprocessed.png")
    preprocessed_u8 = _to_uint8_image(preprocessed)
    if len(preprocessed.shape) == 2:
        cv2.imwrite(preprocessed_path, preprocessed_u8)
    else:
        # OpenCV writes BGR; the processed image is treated as RGB.
        cv2.imwrite(preprocessed_path, cv2.cvtColor(preprocessed_u8, cv2.COLOR_RGB2BGR))

    logger.info("Segmenting image...")
    if image_type == 'xray':
        segmentation_mask = segmenter.segment_lungs(preprocessed)
        enhanced = segmenter.enhance_lung_structures(preprocessed, segmentation_mask)
        enhanced_path = os.path.join(output_dir, "enhanced.png")
        cv2.imwrite(enhanced_path, _to_uint8_image(enhanced))
    elif image_type == 'mri':
        segmentation_mask = segmenter.segment_brain(preprocessed)
        tumor_mask = segmenter.segment_tumor(preprocessed, segmentation_mask)
        if np.any(tumor_mask):
            tumor_path = os.path.join(output_dir, "tumor_mask.png")
            cv2.imwrite(tumor_path, _mask_to_uint8(tumor_mask))
    elif image_type == 'ct':
        liver_mask = segmenter.segment_liver(preprocessed)
        kidney_mask = segmenter.segment_kidneys(preprocessed)
        bone_mask = segmenter.segment_bones(preprocessed)

        # Combined label mask: 1 = liver, 2 = kidneys, 3 = bones (later labels win on overlap).
        base = preprocessed[:, :, 0] if len(preprocessed.shape) > 2 else preprocessed
        segmentation_mask = np.zeros_like(base, dtype=np.uint8)
        # Cast to bool so integer 0/1 masks act as boolean masks, not fancy row indices.
        segmentation_mask[np.asarray(liver_mask, dtype=bool)] = 1
        segmentation_mask[np.asarray(kidney_mask, dtype=bool)] = 2
        segmentation_mask[np.asarray(bone_mask, dtype=bool)] = 3

        cv2.imwrite(os.path.join(output_dir, 'liver_mask.png'), _mask_to_uint8(liver_mask))
        cv2.imwrite(os.path.join(output_dir, 'kidney_mask.png'), _mask_to_uint8(kidney_mask))
        cv2.imwrite(os.path.join(output_dir, 'bone_mask.png'), _mask_to_uint8(bone_mask))
    else:
        segmentation_mask = segmenter.threshold_segmentation(preprocessed)

    regions = segmenter.extract_region_properties(segmentation_mask)

    logger.info("Extracting features...")
    features = processor.extract_features(preprocessed, segmentation_mask)

    logger.info("Generating visualization...")
    visualization_path = os.path.join(output_dir, "segmentation_result.png")
    segmenter.visualize_segmentation(preprocessed, segmentation_mask, regions, visualization_path)

    mask_path = os.path.join(output_dir, "segmentation_mask.png")
    cv2.imwrite(mask_path, _mask_to_uint8(segmentation_mask))

    results = {
        "success": True,
        "image_path": image_path,
        "image_type": image_type,
        # Convert every NumPy value (not just float32/64) so json.dump cannot fail.
        "features": _to_json_safe(features),
        "regions": [
            {
                "label": int(region["label"]),
                "area": float(region["area"]),
                "centroid": [float(c) for c in region["centroid"]],
                "equivalent_diameter": float(region["equivalent_diameter"])
            }
            for region in regions
        ],
        "visualization_path": visualization_path,
        "mask_path": mask_path,
        "preprocessed_path": preprocessed_path
    }

    results_path = os.path.join(output_dir, "analysis_results.json")
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    logger.info(f"Results saved to {results_path}")

    return results

def numpy_json_default(obj: Any) -> Any:
    """Make NumPy values JSON-serializable; intended as ``json.dump(..., default=...)``.

    The standard ``json`` encoder cannot encode NumPy scalars such as
    ``np.float32`` / ``np.int64`` or ``np.ndarray`` objects. This hook converts
    scalars to the equivalent native Python value (``.item()``) and arrays to
    nested lists (``.tolist()``).

    Args:
        obj: An object the JSON encoder could not serialize natively.

    Returns:
        Any: A JSON-native equivalent of ``obj``.

    Raises:
        TypeError: If ``obj`` is neither a NumPy scalar nor a NumPy array,
            mirroring the encoder's own behavior for unsupported types.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def classify_medical_image(image_path: str, image_type: str, model_path: str, 
                          output_dir: str, config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Classify a medical image with a trained model and save the results.

    Args:
        image_path: Path to the medical image
        image_type: Type of medical image ('xray', 'ct', 'mri'); 'xray' and
            'mri' select specialised classifiers, anything else uses the
            generic MedicalImageClassifier
        model_path: Path to the trained model
        output_dir: Directory to save results (created if missing)
        config_path: Optional path to configuration file

    Returns:
        Dict[str, Any]: On success, the classifier's prediction dict with an
        added "visualization_path" key; it is also written to
        ``<output_dir>/classification_results.json``. On failure, either
        ``{"success": False, "error": ...}`` or the classifier's own
        unsuccessful prediction dict.
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Fail fast on a missing model before constructing any classifier.
    if not os.path.exists(model_path):
        logger.error(f"Model not found: {model_path}")
        return {"success": False, "error": "Model not found"}

    # Initialize appropriate classifier based on image type
    if image_type == 'xray':
        classifier = ChestXRayClassifier(config_path)
    elif image_type == 'mri':
        classifier = BrainMRIClassifier(config_path)
    else:
        classifier = MedicalImageClassifier(config_path)

    logger.info(f"Loading model from {model_path}")
    if not classifier.load_model(model_path):
        logger.error(f"Failed to load model from {model_path}")
        return {"success": False, "error": "Failed to load model"}

    # Load image (OpenCV returns None rather than raising on unreadable files)
    logger.info(f"Loading image: {image_path}")
    image = cv2.imread(image_path)

    if image is None:
        logger.error(f"Failed to load image: {image_path}")
        return {"success": False, "error": "Failed to load image"}

    # OpenCV loads BGR; the classifier is fed RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Classify image
    logger.info("Classifying image...")
    prediction = classifier.predict(image)

    if not prediction["success"]:
        logger.error(f"Classification failed: {prediction.get('error', 'Unknown error')}")
        return prediction

    # Visualize prediction
    logger.info("Generating visualization...")
    visualization_path = os.path.join(output_dir, "classification_result.png")
    classifier.visualize_predictions(image, prediction, visualization_path)

    # Add visualization path to results
    prediction["visualization_path"] = visualization_path

    # Save results. The prediction dict comes straight from the classifier and
    # may hold NumPy scalars/arrays, which plain json.dump rejects mid-write.
    results_path = os.path.join(output_dir, "classification_results.json")
    with open(results_path, 'w') as f:
        json.dump(prediction, f, indent=2, default=numpy_json_default)

    logger.info(f"Results saved to {results_path}")

    return prediction

def segment_with_deep_learning(image_path: str, model_path: str, output_dir: str, 
                              config_path: Optional[str] = None) -> Dict[str, Any]:
    """
    Run a trained deep-learning segmentation model on one image and persist the outputs.

    Writes three artifacts into ``output_dir``: an overlay visualization
    (``deep_segmentation_result.png``), the mask scaled by 255
    (``deep_segmentation_mask.png``) and a JSON summary
    (``deep_segmentation_results.json``).

    Args:
        image_path: Path to the medical image
        model_path: Path to the trained segmentation model
        output_dir: Directory to save results (created if missing)
        config_path: Optional path to configuration file

    Returns:
        Dict[str, Any]: The summary that was written to JSON, or
        ``{"success": False, "error": ...}`` if the image could not be read or
        the model produced an empty mask.
    """
    os.makedirs(output_dir, exist_ok=True)
    segmenter = DeepLearningSegmenter(config_path, model_path)

    logger.info(f"Loading image: {image_path}")
    bgr_image = cv2.imread(image_path)
    # cv2.imread signals an unreadable file by returning None
    if bgr_image is None:
        logger.error(f"Failed to load image: {image_path}")
        return {"success": False, "error": "Failed to load image"}
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    logger.info("Segmenting image with deep learning model...")
    mask = segmenter.segment_with_model(rgb_image)
    # An empty array is this segmenter's failure signal
    if mask.size == 0:
        logger.error("Segmentation failed")
        return {"success": False, "error": "Segmentation failed"}

    region_props = segmenter.extract_region_properties(mask)

    logger.info("Generating visualization...")
    overlay_file = os.path.join(output_dir, "deep_segmentation_result.png")
    segmenter.visualize_segmentation(rgb_image, mask, region_props, overlay_file)

    mask_file = os.path.join(output_dir, "deep_segmentation_mask.png")
    cv2.imwrite(mask_file, mask.astype(np.uint8) * 255)

    # Cast region measurements to plain Python numbers so json can encode them
    region_summaries = []
    for props in region_props:
        region_summaries.append({
            "label": int(props["label"]),
            "area": float(props["area"]),
            "centroid": list(map(float, props["centroid"])),
            "equivalent_diameter": float(props["equivalent_diameter"]),
        })

    summary = {
        "success": True,
        "image_path": image_path,
        "model_path": model_path,
        "regions": region_summaries,
        "visualization_path": overlay_file,
        "mask_path": mask_file,
    }

    results_path = os.path.join(output_dir, "deep_segmentation_results.json")
    with open(results_path, 'w') as f:
        json.dump(summary, f, indent=2)

    logger.info(f"Results saved to {results_path}")
    return summary

def main() -> int:
    """Command-line entry point for the medical image analysis module.

    Parses arguments, dispatches to analyze / classify / segment mode, and
    prints a human-readable summary of the results.

    Returns:
        int: Process exit status. It is 0 when the selected analysis succeeded.
        It is 1 when the image is missing, ``--model`` is absent for a mode
        that needs it, or the analysis reported failure. Returning a status
        (instead of None) lets ``raise SystemExit(main())`` signal failure to
        the calling shell.
    """
    # Parse command line arguments
    parser = argparse.ArgumentParser(description='Medical Image Analysis Module')
    parser.add_argument('--image', type=str, required=True,
                        help='Path to medical image')
    parser.add_argument('--type', type=str, choices=['xray', 'ct', 'mri'], required=True,
                        help='Type of medical image')
    parser.add_argument('--mode', type=str, choices=['analyze', 'classify', 'segment'], required=True,
                        help='Analysis mode')
    parser.add_argument('--output', type=str, default='./output',
                        help='Directory to save results')
    parser.add_argument('--config', type=str, default=None,
                        help='Path to configuration file')
    parser.add_argument('--model', type=str, default=None,
                        help='Path to trained model (for classification or deep segmentation mode)')

    args = parser.parse_args()

    # Check if image exists
    if not os.path.isfile(args.image):
        logger.error(f"Image not found: {args.image}")
        return 1

    # Create output directory
    os.makedirs(args.output, exist_ok=True)

    # Run analysis based on mode
    if args.mode == 'analyze':
        results = analyze_medical_image(args.image, args.type, args.output, args.config)
    elif args.mode == 'classify':
        if args.model is None:
            logger.error("Model path is required for classification mode")
            return 1
        results = classify_medical_image(args.image, args.type, args.model, args.output, args.config)
    else:  # segment (args.type is not used by this mode)
        if args.model is None:
            logger.error("Model path is required for deep segmentation mode")
            return 1
        results = segment_with_deep_learning(args.image, args.model, args.output, args.config)

    if not results["success"]:
        print(f"Error: {results.get('error', 'Unknown error')}")
        return 1

    # Print summary
    print("\n===== Medical Image Analysis Summary =====")
    print(f"Image: {args.image}")
    print(f"Type: {args.type}")
    print(f"Mode: {args.mode}")

    if args.mode == 'analyze':
        print("\nFeatures:")
        for feature, value in results["features"].items():
            print(f"  - {feature}: {value}")

        print(f"\nDetected {len(results['regions'])} regions")

    elif args.mode == 'classify':
        print(f"\nPredicted class: {results['predicted_class']}")
        print(f"Confidence: {results['confidence']:.4f}")

        print("\nClass probabilities:")
        for cls, prob in results["probabilities"].items():
            print(f"  - {cls}: {prob:.4f}")

    else:  # segment
        print(f"\nDetected {len(results['regions'])} regions")

    print(f"\nResults saved to {args.output}")
    return 0

if __name__ == "__main__":
    # Propagate main()'s status as the process exit code; SystemExit(None)
    # exits with 0, so this is also safe if main() returns nothing.
    raise SystemExit(main())

modules\image_analysis\image_classifier.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天天进步2015

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值