"""
Medical case data access objects - database operations for medical cases and related data
"""
import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from .base_dao import BaseDAO
# Configure logging
logger = logging.getLogger(__name__)
class MedicalCaseDAO(BaseDAO):
"""医疗案例数据访问对象类"""
def __init__(self):
"""初始化医疗案例DAO"""
super().__init__('medical_cases', 'case_id')
def find_cases_by_patient(self, patient_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询患者的医疗案例
Args:
patient_id: 患者ID
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
return self.find_by_criteria({'patient_id': patient_id}, limit=limit, offset=offset,
order_by='created_at', order_direction='DESC')
def find_cases_by_doctor(self, doctor_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询医生负责的医疗案例
Args:
doctor_id: 医生ID
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
return self.find_by_criteria({'primary_doctor_id': doctor_id}, limit=limit, offset=offset,
order_by='created_at', order_direction='DESC')
def find_cases_by_status(self, status: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据状态查询医疗案例
Args:
status: 案例状态
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
return self.find_by_criteria({'status': status}, limit=limit, offset=offset,
order_by='created_at', order_direction='DESC')
def find_cases_by_priority(self, priority: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据优先级查询医疗案例
Args:
priority: 案例优先级
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
return self.find_by_criteria({'priority': priority}, limit=limit, offset=offset,
order_by='created_at', order_direction='DESC')
def find_case_with_details(self, case_id: int) -> Optional[Dict[str, Any]]:
"""查询医疗案例详情,包含患者和医生信息
Args:
case_id: 案例ID
Returns:
Dict[str, Any]: 医疗案例详情,如果没有找到则返回None
"""
query = """
SELECT mc.*,
p.medical_record_number,
pu.first_name as patient_first_name,
pu.last_name as patient_last_name,
d.specialization,
du.first_name as doctor_first_name,
du.last_name as doctor_last_name
FROM medical_cases mc
JOIN patients p ON mc.patient_id = p.patient_id
JOIN users pu ON p.user_id = pu.user_id
JOIN doctors d ON mc.primary_doctor_id = d.doctor_id
JOIN users du ON d.user_id = du.user_id
WHERE mc.case_id = %s
"""
try:
return self.db.execute_one(query, (case_id,))
except Exception as e:
logger.error(f"Error finding case with details: {e}")
return None
def find_recent_cases(self, limit: int = 10) -> List[Dict[str, Any]]:
"""查询最近的医疗案例
Args:
limit: 返回记录数量限制
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
return self.find_all(limit=limit, offset=0)
def update_case_status(self, case_id: int, status: str) -> bool:
"""更新案例状态
Args:
case_id: 案例ID
status: 新状态
Returns:
bool: 更新是否成功
"""
return self.update(case_id, {'status': status, 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def update_case_priority(self, case_id: int, priority: str) -> bool:
"""更新案例优先级
Args:
case_id: 案例ID
priority: 新优先级
Returns:
bool: 更新是否成功
"""
return self.update(case_id, {'priority': priority, 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def assign_doctor(self, case_id: int, doctor_id: int) -> bool:
"""为案例分配医生
Args:
case_id: 案例ID
doctor_id: 医生ID
Returns:
bool: 更新是否成功
"""
return self.update(case_id, {'primary_doctor_id': doctor_id, 'updated_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def search_cases(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
"""搜索医疗案例
Args:
search_term: 搜索关键词
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医疗案例列表
"""
query = """
SELECT mc.*,
pu.first_name as patient_first_name,
pu.last_name as patient_last_name,
du.first_name as doctor_first_name,
du.last_name as doctor_last_name
FROM medical_cases mc
JOIN patients p ON mc.patient_id = p.patient_id
JOIN users pu ON p.user_id = pu.user_id
JOIN doctors d ON mc.primary_doctor_id = d.doctor_id
JOIN users du ON d.user_id = du.user_id
WHERE mc.title LIKE %s OR mc.description LIKE %s
OR pu.first_name LIKE %s OR pu.last_name LIKE %s
OR du.first_name LIKE %s OR du.last_name LIKE %s
LIMIT %s OFFSET %s
"""
search_pattern = f"%{search_term}%"
params = (search_pattern, search_pattern, search_pattern, search_pattern,
search_pattern, search_pattern, limit, offset)
try:
return self.db.execute_query(query, params)
except Exception as e:
logger.error(f"Error searching cases: {e}")
return []
def count_cases_by_status(self) -> Dict[str, int]:
"""统计各状态的案例数量
Returns:
Dict[str, int]: 状态及对应的案例数量
"""
query = """
SELECT status, COUNT(*) as count
FROM medical_cases
GROUP BY status
"""
try:
results = self.db.execute_query(query)
return {result['status']: result['count'] for result in results}
except Exception as e:
logger.error(f"Error counting cases by status: {e}")
return {}
class MedicalImageDAO(BaseDAO):
"""医学图像数据访问对象类"""
def __init__(self):
"""初始化医学图像DAO"""
super().__init__('medical_images', 'image_id')
def find_images_by_case(self, case_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询案例的医学图像
Args:
case_id: 案例ID
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医学图像列表
"""
return self.find_by_criteria({'case_id': case_id}, limit=limit, offset=offset,
order_by='upload_date', order_direction='DESC')
def find_images_by_type(self, image_type: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据图像类型查询医学图像
Args:
image_type: 图像类型
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医学图像列表
"""
return self.find_by_criteria({'image_type': image_type}, limit=limit, offset=offset,
order_by='upload_date', order_direction='DESC')
def find_images_with_analysis(self, case_id: int) -> List[Dict[str, Any]]:
"""查询带有分析结果的医学图像
Args:
case_id: 案例ID
Returns:
List[Dict[str, Any]]: 医学图像及分析结果列表
"""
query = """
SELECT mi.*, iar.analysis_id, iar.analysis_date, iar.findings, iar.confidence_score
FROM medical_images mi
LEFT JOIN image_analysis_results iar ON mi.image_id = iar.image_id
WHERE mi.case_id = %s
ORDER BY mi.upload_date DESC
"""
try:
return self.db.execute_query(query, (case_id,))
except Exception as e:
logger.error(f"Error finding images with analysis: {e}")
return []
def count_images_by_type(self) -> Dict[str, int]:
"""统计各类型的图像数量
Returns:
Dict[str, int]: 图像类型及对应的数量
"""
query = """
SELECT image_type, COUNT(*) as count
FROM medical_images
GROUP BY image_type
"""
try:
results = self.db.execute_query(query)
return {result['image_type']: result['count'] for result in results}
except Exception as e:
logger.error(f"Error counting images by type: {e}")
return {}
class ImageAnalysisResultDAO(BaseDAO):
"""图像分析结果数据访问对象类"""
def __init__(self):
"""初始化图像分析结果DAO"""
super().__init__('image_analysis_results', 'analysis_id')
def find_by_image(self, image_id: int) -> Optional[Dict[str, Any]]:
"""查询图像的分析结果
Args:
image_id: 图像ID
Returns:
Dict[str, Any]: 分析结果,如果没有找到则返回None
"""
return self.find_by_criteria({'image_id': image_id}, limit=1, offset=0)[0] if self.find_by_criteria({'image_id': image_id}, limit=1, offset=0) else None
def find_by_model(self, model_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询指定AI模型的分析结果
Args:
model_id: AI模型ID
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 分析结果列表
"""
return self.find_by_criteria({'model_id': model_id}, limit=limit, offset=offset,
order_by='analysis_date', order_direction='DESC')
def find_by_confidence_range(self, min_confidence: float, max_confidence: float,
limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据置信度范围查询分析结果
Args:
min_confidence: 最小置信度
max_confidence: 最大置信度
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 分析结果列表
"""
query = """
SELECT * FROM image_analysis_results
WHERE confidence_score BETWEEN %s AND %s
ORDER BY analysis_date DESC
LIMIT %s OFFSET %s
"""
try:
return self.db.execute_query(query, (min_confidence, max_confidence, limit, offset))
except Exception as e:
logger.error(f"Error finding analysis by confidence range: {e}")
return []
def find_with_image_info(self, analysis_id: int) -> Optional[Dict[str, Any]]:
"""查询分析结果,包含图像信息
Args:
analysis_id: 分析结果ID
Returns:
Dict[str, Any]: 分析结果及图像信息,如果没有找到则返回None
"""
query = """
SELECT iar.*, mi.image_path, mi.image_type, mi.upload_date,
mi.case_id, am.model_name, am.version
FROM image_analysis_results iar
JOIN medical_images mi ON iar.image_id = mi.image_id
JOIN ai_models am ON iar.model_id = am.model_id
WHERE iar.analysis_id = %s
"""
try:
return self.db.execute_one(query, (analysis_id,))
except Exception as e:
logger.error(f"Error finding analysis with image info: {e}")
return None
def find_latest_analyses(self, limit: int = 10) -> List[Dict[str, Any]]:
"""查询最新的分析结果
Args:
limit: 返回记录数量限制
Returns:
List[Dict[str, Any]]: 分析结果列表
"""
query = """
SELECT iar.*, mi.image_path, mi.image_type, am.model_name
FROM image_analysis_results iar
JOIN medical_images mi ON iar.image_id = mi.image_id
JOIN ai_models am ON iar.model_id = am.model_id
ORDER BY iar.analysis_date DESC
LIMIT %s
"""
try:
return self.db.execute_query(query, (limit,))
except Exception as e:
logger.error(f"Error finding latest analyses: {e}")
return []
def get_average_confidence_by_model(self) -> Dict[str, float]:
"""获取各AI模型的平均置信度
Returns:
Dict[str, float]: 模型名称及对应的平均置信度
"""
query = """
SELECT am.model_name, AVG(iar.confidence_score) as avg_confidence
FROM image_analysis_results iar
JOIN ai_models am ON iar.model_id = am.model_id
GROUP BY am.model_name
"""
try:
results = self.db.execute_query(query)
return {result['model_name']: float(result['avg_confidence']) for result in results}
except Exception as e:
logger.error(f"Error getting average confidence by model: {e}")
return {}
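# Usage sketch (illustrative only): assumes BaseDAO supplies a configured `self.db`
# connection and that the tables referenced above exist; the IDs below are
# hypothetical.
if __name__ == "__main__":
    case_dao = MedicalCaseDAO()
    image_dao = MedicalImageDAO()
    analysis_dao = ImageAnalysisResultDAO()
    # Pull a patient's most recent cases, then the images (with any analysis
    # results) attached to the first one.
    cases = case_dao.find_cases_by_patient(patient_id=1, limit=5)
    if cases:
        images = image_dao.find_images_with_analysis(cases[0]['case_id'])
        print(f"Case {cases[0]['case_id']}: {len(images)} image rows")
    # Table-wide aggregates.
    print(case_dao.count_cases_by_status())
    print(analysis_dao.get_average_confidence_by_model())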
modules\data_management\user_dao.py
"""
User data access objects - database operations for users, doctors, and patients
"""
import logging
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from .base_dao import BaseDAO
# Configure logging
logger = logging.getLogger(__name__)
class UserDAO(BaseDAO):
"""用户数据访问对象类"""
def __init__(self):
"""初始化用户DAO"""
super().__init__('users', 'user_id')
def find_by_username(self, username: str) -> Optional[Dict[str, Any]]:
"""根据用户名查找用户
Args:
username: 用户名
Returns:
Dict[str, Any]: 用户信息,如果没有找到则返回None
"""
return self.find_by_criteria({'username': username}, limit=1, offset=0)[0] if self.find_by_criteria({'username': username}, limit=1, offset=0) else None
def find_by_email(self, email: str) -> Optional[Dict[str, Any]]:
"""根据邮箱查找用户
Args:
email: 邮箱地址
Returns:
Dict[str, Any]: 用户信息,如果没有找到则返回None
"""
return self.find_by_criteria({'email': email}, limit=1, offset=0)[0] if self.find_by_criteria({'email': email}, limit=1, offset=0) else None
def find_by_role(self, role: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据角色查找用户
Args:
role: 用户角色
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 用户列表
"""
return self.find_by_criteria({'role': role}, limit=limit, offset=offset)
def update_last_login(self, user_id: int) -> bool:
"""更新用户最后登录时间
Args:
user_id: 用户ID
Returns:
bool: 更新是否成功
"""
return self.update(user_id, {'last_login': datetime.now().strftime('%Y-%m-%d %H:%M:%S')})
def change_password(self, user_id: int, hashed_password: str) -> bool:
"""更改用户密码
Args:
user_id: 用户ID
hashed_password: 加密后的密码
Returns:
bool: 更新是否成功
"""
return self.update(user_id, {'password': hashed_password})
def activate_user(self, user_id: int) -> bool:
"""激活用户
Args:
user_id: 用户ID
Returns:
bool: 更新是否成功
"""
return self.update(user_id, {'is_active': True})
def deactivate_user(self, user_id: int) -> bool:
"""停用用户
Args:
user_id: 用户ID
Returns:
bool: 更新是否成功
"""
return self.update(user_id, {'is_active': False})
def search_users(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
"""搜索用户
Args:
search_term: 搜索关键词
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 用户列表
"""
query = """
SELECT * FROM users
WHERE username LIKE %s OR email LIKE %s
OR first_name LIKE %s OR last_name LIKE %s
LIMIT %s OFFSET %s
"""
search_pattern = f"%{search_term}%"
params = (search_pattern, search_pattern, search_pattern, search_pattern, limit, offset)
try:
return self.db.execute_query(query, params)
except Exception as e:
logger.error(f"Error searching users: {e}")
return []
class DoctorDAO(BaseDAO):
"""医生数据访问对象类"""
def __init__(self):
"""初始化医生DAO"""
super().__init__('doctors', 'doctor_id')
def find_with_user_info(self, doctor_id: Optional[int] = None, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询医生信息,包含用户基本信息
Args:
doctor_id: 医生ID,如果为None则查询所有医生
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医生信息列表
"""
query = """
SELECT d.*, u.username, u.email, u.first_name, u.last_name,
u.gender, u.profile_image
FROM doctors d
JOIN users u ON d.user_id = u.user_id
"""
params = []
if doctor_id is not None:
query += " WHERE d.doctor_id = %s"
params.append(doctor_id)
query += " LIMIT %s OFFSET %s"
params.extend([limit, offset])
try:
return self.db.execute_query(query, tuple(params))
except Exception as e:
logger.error(f"Error finding doctors with user info: {e}")
return []
def find_by_specialization(self, specialization: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据专业查询医生
Args:
specialization: 专业领域
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医生信息列表
"""
return self.find_by_criteria({'specialization': specialization}, limit=limit, offset=offset)
def find_by_department(self, department: str, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""根据科室查询医生
Args:
department: 科室名称
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医生信息列表
"""
return self.find_by_criteria({'department': department}, limit=limit, offset=offset)
def find_top_rated(self, limit: int = 10) -> List[Dict[str, Any]]:
"""查询评分最高的医生
Args:
limit: 返回记录数量限制
Returns:
List[Dict[str, Any]]: 医生信息列表
"""
query = """
SELECT d.*, u.first_name, u.last_name, u.profile_image
FROM doctors d
JOIN users u ON d.user_id = u.user_id
ORDER BY d.rating DESC
LIMIT %s
"""
try:
return self.db.execute_query(query, (limit,))
except Exception as e:
logger.error(f"Error finding top rated doctors: {e}")
return []
def update_rating(self, doctor_id: int, new_rating: float) -> bool:
"""更新医生评分
Args:
doctor_id: 医生ID
new_rating: 新评分
Returns:
bool: 更新是否成功
"""
return self.update(doctor_id, {'rating': new_rating})
def search_doctors(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
"""搜索医生
Args:
search_term: 搜索关键词
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 医生信息列表
"""
query = """
SELECT d.*, u.first_name, u.last_name, u.email, u.profile_image
FROM doctors d
JOIN users u ON d.user_id = u.user_id
WHERE d.specialization LIKE %s OR d.department LIKE %s
OR u.first_name LIKE %s OR u.last_name LIKE %s
LIMIT %s OFFSET %s
"""
search_pattern = f"%{search_term}%"
params = (search_pattern, search_pattern, search_pattern, search_pattern, limit, offset)
try:
return self.db.execute_query(query, params)
except Exception as e:
logger.error(f"Error searching doctors: {e}")
return []
class PatientDAO(BaseDAO):
"""患者数据访问对象类"""
def __init__(self):
"""初始化患者DAO"""
super().__init__('patients', 'patient_id')
def find_with_user_info(self, patient_id: Optional[int] = None, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询患者信息,包含用户基本信息
Args:
patient_id: 患者ID,如果为None则查询所有患者
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 患者信息列表
"""
query = """
SELECT p.*, u.username, u.email, u.first_name, u.last_name,
u.gender, u.date_of_birth, u.address, u.profile_image
FROM patients p
JOIN users u ON p.user_id = u.user_id
"""
params = []
if patient_id is not None:
query += " WHERE p.patient_id = %s"
params.append(patient_id)
query += " LIMIT %s OFFSET %s"
params.extend([limit, offset])
try:
return self.db.execute_query(query, tuple(params))
except Exception as e:
logger.error(f"Error finding patients with user info: {e}")
return []
def find_by_medical_record_number(self, mrn: str) -> Optional[Dict[str, Any]]:
"""根据病历号查询患者
Args:
mrn: 病历号
Returns:
Dict[str, Any]: 患者信息,如果没有找到则返回None
"""
return self.find_by_criteria({'medical_record_number': mrn}, limit=1, offset=0)[0] if self.find_by_criteria({'medical_record_number': mrn}, limit=1, offset=0) else None
def find_patients_by_doctor(self, doctor_id: int, limit: int = 1000, offset: int = 0) -> List[Dict[str, Any]]:
"""查询医生的患者
Args:
doctor_id: 医生ID
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 患者信息列表
"""
query = """
SELECT DISTINCT p.*, u.first_name, u.last_name
FROM patients p
JOIN users u ON p.user_id = u.user_id
JOIN medical_cases mc ON p.patient_id = mc.patient_id
WHERE mc.primary_doctor_id = %s
LIMIT %s OFFSET %s
"""
try:
return self.db.execute_query(query, (doctor_id, limit, offset))
except Exception as e:
logger.error(f"Error finding patients by doctor: {e}")
return []
def search_patients(self, search_term: str, limit: int = 100, offset: int = 0) -> List[Dict[str, Any]]:
"""搜索患者
Args:
search_term: 搜索关键词
limit: 返回记录数量限制
offset: 起始偏移量
Returns:
List[Dict[str, Any]]: 患者信息列表
"""
query = """
SELECT p.*, u.first_name, u.last_name, u.email
FROM patients p
JOIN users u ON p.user_id = u.user_id
WHERE p.medical_record_number LIKE %s
OR u.first_name LIKE %s OR u.last_name LIKE %s
LIMIT %s OFFSET %s
"""
search_pattern = f"%{search_term}%"
params = (search_pattern, search_pattern, search_pattern, limit, offset)
try:
return self.db.execute_query(query, params)
except Exception as e:
logger.error(f"Error searching patients: {e}")
return []
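# Usage sketch (illustrative only): assumes a configured database connection via
# BaseDAO; the username and IDs below are hypothetical.
if __name__ == "__main__":
    user_dao = UserDAO()
    doctor_dao = DoctorDAO()
    patient_dao = PatientDAO()
    user = user_dao.find_by_username("jdoe")
    if user:
        user_dao.update_last_login(user['user_id'])
    # Look up doctors by specialty and the patients they manage.
    cardiologists = doctor_dao.find_by_specialization("cardiology", limit=10)
    for patient in patient_dao.find_patients_by_doctor(doctor_id=1, limit=10):
        print(patient['medical_record_number'], patient['first_name'])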
modules\data_management\__init__.py
modules\diagnosis_engine\diagnosis_engine.py
"""
Diagnosis Engine Module
This module serves as the main entry point for the diagnosis engine, integrating
the knowledge graph, inference engine, and integration engine components to provide
comprehensive diagnostic suggestions based on multiple data sources.
"""
import os
import json
import logging
import time
from typing import Dict, List, Tuple, Any, Optional, Set, Union
from pathlib import Path
from .knowledge_graph import MedicalKnowledgeGraph
from .inference_engine import MedicalInferenceEngine
from .integration_engine import DiagnosticIntegrationEngine
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class DiagnosisEngine:
"""
Diagnosis Engine for providing comprehensive diagnostic suggestions.
This class serves as the main entry point for the diagnosis engine, integrating
the knowledge graph, inference engine, and integration engine components to provide
comprehensive diagnostic suggestions based on multiple data sources.
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the Diagnosis Engine.
Args:
config_path: Path to the configuration file for the diagnosis engine.
If None, default configuration will be used.
"""
self.config = self._load_config(config_path)
self.knowledge_graph = MedicalKnowledgeGraph(config_path)
self.inference_engine = MedicalInferenceEngine(self.knowledge_graph, config_path)
self.integration_engine = DiagnosticIntegrationEngine(self.inference_engine, config_path)
# Load rules and cases if specified in config
self._load_knowledge_resources()
logger.info("Diagnosis Engine initialized")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file or use default configuration.
Args:
config_path: Path to the configuration file.
Returns:
Dictionary containing configuration parameters.
"""
default_config = {
"knowledge_resources": {
"rules_file": "",
"cases_file": "",
"knowledge_graph_nodes": "",
"knowledge_graph_relations": ""
},
"explanation": {
"generate_explanation": True,
"explanation_detail_level": "detailed",
"include_evidence": True,
"include_confidence_scores": True,
"include_alternative_diagnoses": True
},
"performance": {
"use_gpu": True,
"batch_processing": True,
"cache_results": True,
"cache_expiry_hours": 24
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return default_config
else:
# Try to load from the default location
try:
base_dir = Path(__file__).parent.parent.parent
default_path = base_dir / "config" / "diagnosis_engine_config.json"
if os.path.exists(default_path):
with open(default_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {default_path}")
return config
except Exception as e:
logger.error(f"Error loading default configuration: {e}")
logger.info("Using default diagnosis engine configuration")
return default_config
def _load_knowledge_resources(self):
"""
Load knowledge resources specified in the configuration.
"""
resources = self.config.get("knowledge_resources", {})
# Load rules
rules_file = resources.get("rules_file", "")
if rules_file and os.path.exists(rules_file):
self.inference_engine.load_rules(rules_file)
# Load cases
cases_file = resources.get("cases_file", "")
if cases_file and os.path.exists(cases_file):
self.inference_engine.load_cases(cases_file)
# Load knowledge graph
nodes_file = resources.get("knowledge_graph_nodes", "")
relations_file = resources.get("knowledge_graph_relations", "")
if nodes_file and relations_file and os.path.exists(nodes_file) and os.path.exists(relations_file):
self.knowledge_graph.import_from_csv(nodes_file, relations_file)
def diagnose(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]] = None,
text_analysis: Optional[Dict[str, Any]] = None,
lab_results: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Generate diagnostic suggestions based on multiple data sources.
Args:
patient_data: Dictionary containing patient data, including demographics.
image_analysis: Results from medical image analysis.
text_analysis: Results from clinical text analysis.
lab_results: Laboratory test results.
Returns:
Dictionary containing diagnostic results.
"""
start_time = time.time()
# Integrate results from different sources
diagnosis_results = self.integration_engine.integrate_results(
patient_data, image_analysis, text_analysis, lab_results
)
# Add performance metrics
diagnosis_results["performance"] = {
"processing_time": time.time() - start_time,
"timestamp": time.time()
}
logger.info(f"Diagnosis completed in {diagnosis_results['performance']['processing_time']:.2f} seconds")
return diagnosis_results
def generate_report(self, diagnosis_results: Dict[str, Any], patient_data: Dict[str, Any],
format: str = "text") -> str:
"""
Generate a diagnostic report based on the results.
Args:
diagnosis_results: Results from diagnostic integration.
patient_data: Patient demographic data.
format: Report format ("text", "html", or "json").
Returns:
Diagnostic report in the specified format.
"""
return self.integration_engine.generate_diagnostic_report(
diagnosis_results, patient_data, format
)
def visualize_knowledge_graph(self, node_ids: Optional[List[str]] = None,
filename: Optional[str] = None) -> None:
"""
Visualize the knowledge graph or a subgraph.
Args:
node_ids: List of node identifiers to visualize (optional).
filename: Path to save the visualization (optional).
"""
self.knowledge_graph.visualize(node_ids, filename)
def get_knowledge_graph_statistics(self) -> Dict[str, Any]:
"""
Get statistics about the knowledge graph.
Returns:
Dictionary containing various statistics about the knowledge graph.
"""
return self.knowledge_graph.get_statistics()
def export_knowledge_graph(self, nodes_file: str, relations_file: str) -> None:
"""
Export the knowledge graph to CSV files.
Args:
nodes_file: Path to save the nodes CSV file.
relations_file: Path to save the relations CSV file.
"""
self.knowledge_graph.export_to_csv(nodes_file, relations_file)
def import_external_knowledge(self, source_type: str, source_path: str) -> Dict[str, Any]:
"""
Import knowledge from external sources.
Args:
source_type: Type of knowledge source ("umls", "snomed_ct", "icd10", "custom").
source_path: Path to the knowledge source file.
Returns:
Dictionary containing import results.
"""
# This is a placeholder for importing from external knowledge sources
# In a real implementation, this would connect to medical ontologies and terminologies
logger.warning(f"Import from {source_type} is not fully implemented")
return {
"success": False,
"message": f"Import from {source_type} is not implemented yet",
"source_type": source_type,
"source_path": source_path
}
def clear_knowledge_graph(self) -> None:
"""
Clear the knowledge graph.
"""
self.knowledge_graph.clear()
logger.info("Knowledge graph cleared")
modules\diagnosis_engine\inference_engine.py
"""
Medical Inference Engine Module
This module provides functionality for diagnostic reasoning based on medical knowledge
and patient data, using various inference methods including rule-based reasoning,
Bayesian networks, case-based reasoning, and deep learning approaches.
"""
import os
import json
import logging
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any, Optional, Set, Union
from pathlib import Path
import networkx as nx
from sklearn.metrics.pairwise import cosine_similarity
from collections import defaultdict
from .knowledge_graph import MedicalKnowledgeGraph
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MedicalInferenceEngine:
"""
Medical Inference Engine for diagnostic reasoning.
This class provides methods for inferring diagnoses based on patient data,
using various reasoning methods including rule-based, Bayesian networks,
case-based reasoning, and deep learning approaches.
"""
def __init__(self, knowledge_graph: Optional[MedicalKnowledgeGraph] = None, config_path: Optional[str] = None):
"""
Initialize the Medical Inference Engine.
Args:
knowledge_graph: Medical knowledge graph to use for reasoning.
If None, a new knowledge graph will be created.
config_path: Path to the configuration file for the inference engine.
If None, default configuration will be used.
"""
self.knowledge_graph = knowledge_graph if knowledge_graph else MedicalKnowledgeGraph()
self.config = self._load_config(config_path)
self.reasoning_methods = {
"rule_based": self._rule_based_reasoning,
"bayesian_network": self._bayesian_network_reasoning,
"case_based": self._case_based_reasoning,
"deep_learning": self._deep_learning_reasoning,
"hybrid": self._hybrid_reasoning
}
self.case_database = [] # For case-based reasoning
self.rules = [] # For rule-based reasoning
logger.info("Medical Inference Engine initialized")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file or use default configuration.
Args:
config_path: Path to the configuration file.
Returns:
Dictionary containing configuration parameters.
"""
default_config = {
"reasoning_methods": ["rule_based", "bayesian_network", "case_based", "deep_learning"],
"default_method": "hybrid",
"confidence_threshold": 0.65,
"max_differential_diagnoses": 5
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f).get("inference_engine", {})
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return default_config
else:
# Try to load from the default location
try:
base_dir = Path(__file__).parent.parent.parent
default_path = base_dir / "config" / "diagnosis_engine_config.json"
if os.path.exists(default_path):
with open(default_path, 'r') as f:
config = json.load(f).get("inference_engine", {})
logger.info(f"Loaded configuration from {default_path}")
return config
except Exception as e:
logger.error(f"Error loading default configuration: {e}")
logger.info("Using default inference engine configuration")
return default_config
def load_rules(self, rules_file: str) -> None:
"""
Load diagnostic rules from a JSON file.
Args:
rules_file: Path to the JSON file containing diagnostic rules.
"""
try:
with open(rules_file, 'r') as f:
self.rules = json.load(f)
logger.info(f"Loaded {len(self.rules)} diagnostic rules from {rules_file}")
except Exception as e:
logger.error(f"Error loading rules: {e}")
def load_cases(self, cases_file: str) -> None:
"""
Load previous cases for case-based reasoning.
Args:
cases_file: Path to the CSV file containing previous cases.
"""
try:
self.case_database = pd.read_csv(cases_file).to_dict('records')
logger.info(f"Loaded {len(self.case_database)} cases from {cases_file}")
except Exception as e:
logger.error(f"Error loading cases: {e}")
def diagnose(self, patient_data: Dict[str, Any], method: Optional[str] = None) -> Dict[str, Any]:
"""
Generate diagnostic suggestions based on patient data.
Args:
patient_data: Dictionary containing patient data, including:
- demographics: age, gender, etc.
- symptoms: list of symptoms
- medical_history: past medical conditions
- lab_results: laboratory test results
- image_analysis: results from image analysis
- text_analysis: results from clinical text analysis
method: Reasoning method to use. If None, the default method from
configuration will be used.
Returns:
Dictionary containing diagnostic results, including:
- primary_diagnosis: most likely diagnosis
- differential_diagnoses: alternative diagnoses
- confidence_scores: confidence scores for each diagnosis
- evidence: supporting evidence for diagnoses
- explanation: explanation of the reasoning process
"""
if method is None:
method = self.config.get("default_method", "hybrid")
if method not in self.reasoning_methods:
logger.warning(f"Unknown reasoning method: {method}, using default")
method = self.config.get("default_method", "hybrid")
# Preprocess patient data
processed_data = self._preprocess_patient_data(patient_data)
# Apply the selected reasoning method
reasoning_func = self.reasoning_methods[method]
diagnosis_results = reasoning_func(processed_data)
# Filter results based on confidence threshold
confidence_threshold = self.config.get("confidence_threshold", 0.65)
max_differential = self.config.get("max_differential_diagnoses", 5)
filtered_diagnoses = [d for d in diagnosis_results["diagnoses"]
if d["confidence"] >= confidence_threshold]
# Sort by confidence score and limit to max_differential
sorted_diagnoses = sorted(filtered_diagnoses, key=lambda x: x["confidence"], reverse=True)
primary_diagnosis = sorted_diagnoses[0] if sorted_diagnoses else None
differential_diagnoses = sorted_diagnoses[1:max_differential+1] if len(sorted_diagnoses) > 1 else []
# Generate explanation
explanation = self._generate_explanation(primary_diagnosis, differential_diagnoses,
diagnosis_results["evidence"], method)
return {
"success": True,
"primary_diagnosis": primary_diagnosis,
"differential_diagnoses": differential_diagnoses,
"evidence": diagnosis_results["evidence"],
"explanation": explanation,
"method": method,
"confidence_threshold": confidence_threshold
}
def _preprocess_patient_data(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Preprocess patient data for diagnostic reasoning.
Args:
patient_data: Raw patient data.
Returns:
Preprocessed patient data.
"""
processed_data = patient_data.copy()
# Normalize symptom names
if "symptoms" in processed_data:
processed_data["symptoms"] = [self._normalize_term(s) for s in processed_data["symptoms"]]
# Normalize disease names in medical history
if "medical_history" in processed_data:
processed_data["medical_history"] = [self._normalize_term(d) for d in processed_data["medical_history"]]
# Extract entities from text analysis if available
if "text_analysis" in processed_data:
text_analysis = processed_data["text_analysis"]
# Extract symptoms from text analysis if not already provided
if "symptoms" not in processed_data and "entities" in text_analysis:
symptoms = [e["text"] for e in text_analysis["entities"]
if e["label"].lower() == "symptom"]
processed_data["symptoms"] = symptoms
# Extract medical history from text analysis if not already provided
if "medical_history" not in processed_data and "entities" in text_analysis:
diseases = [e["text"] for e in text_analysis["entities"]
if e["label"].lower() == "disease"]
processed_data["medical_history"] = diseases
return processed_data
def _normalize_term(self, term: str) -> str:
"""
Normalize medical terms for consistent matching.
Args:
term: Medical term to normalize.
Returns:
Normalized term.
"""
# Simple normalization: lowercase and remove trailing whitespace
normalized = term.lower().strip()
# TODO: Implement more sophisticated normalization using medical ontologies
return normalized
def _rule_based_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply rule-based reasoning for diagnosis.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary containing diagnoses and supporting evidence.
"""
diagnoses = []
evidence = defaultdict(list)
# Extract patient symptoms
patient_symptoms = set(patient_data.get("symptoms", []))
# Check each rule
for rule in self.rules:
rule_name = rule.get("name", "Unknown Rule")
disease = rule.get("disease", "")
required_symptoms = set(rule.get("required_symptoms", []))
supporting_symptoms = set(rule.get("supporting_symptoms", []))
exclusion_symptoms = set(rule.get("exclusion_symptoms", []))
# Check if any exclusion symptoms are present
if patient_symptoms.intersection(exclusion_symptoms):
continue
# Check if all required symptoms are present
if not required_symptoms.issubset(patient_symptoms):
continue
# Calculate confidence based on supporting symptoms
supporting_count = len(patient_symptoms.intersection(supporting_symptoms))
total_supporting = len(supporting_symptoms)
if total_supporting > 0:
support_confidence = supporting_count / total_supporting
else:
support_confidence = 1.0 if required_symptoms else 0.0
# Adjust confidence based on required symptoms
required_confidence = 1.0 if required_symptoms else 0.5
# Final confidence is a weighted combination
confidence = 0.7 * required_confidence + 0.3 * support_confidence
# Add to diagnoses if confidence is high enough
if confidence > 0:
diagnoses.append({
"disease": disease,
"confidence": confidence,
"rule": rule_name
})
# Record evidence
for symptom in required_symptoms:
if symptom in patient_symptoms:
evidence[disease].append({
"type": "required_symptom",
"item": symptom,
"weight": "high"
})
for symptom in supporting_symptoms:
if symptom in patient_symptoms:
evidence[disease].append({
"type": "supporting_symptom",
"item": symptom,
"weight": "medium"
})
return {
"diagnoses": diagnoses,
"evidence": dict(evidence)
}
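# Example of the rule schema this method consumes (a sketch inferred from the
# keys read above; the disease and symptom names are hypothetical):
#   [
#     {
#       "name": "Influenza rule",
#       "disease": "influenza",
#       "required_symptoms": ["fever"],
#       "supporting_symptoms": ["cough", "fatigue", "muscle aches"],
#       "exclusion_symptoms": ["rash"]
#     }
#   ]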
def _bayesian_network_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply Bayesian network reasoning for diagnosis.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary containing diagnoses and supporting evidence.
"""
diagnoses = []
evidence = defaultdict(list)
# Extract patient symptoms and other findings
patient_symptoms = set(patient_data.get("symptoms", []))
patient_findings = {
**{f"symptom_{s}": 1 for s in patient_symptoms},
**{f"lab_{k}": v for k, v in patient_data.get("lab_results", {}).items()}
}
# Get all disease nodes from the knowledge graph
disease_nodes = [n for n, data in self.knowledge_graph.graph.nodes(data=True)
if data.get("type") == "disease"]
for disease_id in disease_nodes:
# Get related symptoms for this disease
related_nodes = self.knowledge_graph.query_related_nodes(
disease_id, relation_type="has_symptom", direction="outgoing"
)
disease_symptoms = [node["id"] for node in related_nodes]
# Skip if no symptoms are defined for this disease
if not disease_symptoms:
continue
# Calculate prior probability (can be refined based on prevalence data)
prior_prob = 0.01 # Default prior probability
# Calculate likelihood using naive Bayes approach
likelihood = 1.0
symptom_evidence = []
for symptom in disease_symptoms:
# Get the conditional probability of the symptom given the disease
# This should come from the knowledge graph or a probability table
# For now, we'll use a simplified approach
edge_data = self.knowledge_graph.graph.get_edge_data(disease_id, symptom) or {}
conditional_prob = edge_data.get("probability", 0.7) # Default conditional probability
if symptom in patient_symptoms:
likelihood *= conditional_prob
symptom_evidence.append({
"type": "symptom",
"item": symptom,
"weight": "high" if conditional_prob > 0.8 else "medium",
"probability": conditional_prob
})
else:
likelihood *= (1 - conditional_prob)
# Calculate posterior probability using Bayes' rule
# P(disease|symptoms) = P(symptoms|disease) * P(disease) / P(symptoms)
# Since P(symptoms) is the same for all diseases, we can ignore it for ranking
posterior_prob = likelihood * prior_prob
# Add to diagnoses if probability is above zero
if posterior_prob > 0:
disease_data = self.knowledge_graph.graph.nodes[disease_id]
diagnoses.append({
"disease": disease_data.get("name", disease_id),
"confidence": posterior_prob,
"method": "bayesian"
})
# Record evidence
evidence[disease_data.get("name", disease_id)] = symptom_evidence
# Normalize confidence scores to [0, 1] range
max_confidence = max([d["confidence"] for d in diagnoses]) if diagnoses else 1
for diagnosis in diagnoses:
diagnosis["confidence"] = diagnosis["confidence"] / max_confidence
return {
"diagnoses": diagnoses,
"evidence": dict(evidence)
}
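# Worked example (hypothetical numbers): a disease linked to symptoms A and B,
# each at the default conditional probability 0.7, where the patient reports A
# but not B, gives likelihood = 0.7 * (1 - 0.7) = 0.21 and posterior
# 0.21 * 0.01 = 0.0021; the loop above then rescales every posterior by the
# maximum across diseases.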
def _case_based_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply case-based reasoning for diagnosis.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary containing diagnoses and supporting evidence.
"""
diagnoses = []
evidence = defaultdict(list)
# Skip if no case database is available
if not self.case_database:
logger.warning("No case database available for case-based reasoning")
return {"diagnoses": [], "evidence": {}}
# Extract patient features
patient_features = self._extract_case_features(patient_data)
# Find similar cases
similar_cases = []
for case in self.case_database:
case_features = {k: v for k, v in case.items() if k != "diagnosis"}
similarity = self._calculate_case_similarity(patient_features, case_features)
similar_cases.append({
"case": case,
"similarity": similarity
})
# Sort by similarity
similar_cases.sort(key=lambda x: x["similarity"], reverse=True)
# Get top similar cases
top_cases = similar_cases[:10] # Consider top 10 most similar cases
# Count diagnoses in top cases
diagnosis_counts = defaultdict(int)
diagnosis_similarities = defaultdict(list)
for case_data in top_cases:
case = case_data["case"]
similarity = case_data["similarity"]
diagnosis = case.get("diagnosis", "")
if diagnosis:
diagnosis_counts[diagnosis] += 1
diagnosis_similarities[diagnosis].append(similarity)
# Calculate confidence for each diagnosis
for diagnosis, count in diagnosis_counts.items():
# Confidence is based on frequency and average similarity
frequency = count / len(top_cases)
avg_similarity = sum(diagnosis_similarities[diagnosis]) / len(diagnosis_similarities[diagnosis])
confidence = 0.7 * frequency + 0.3 * avg_similarity
diagnoses.append({
"disease": diagnosis,
"confidence": confidence,
"method": "case_based",
"similar_cases": count
})
# Record evidence from similar cases
for i, case_data in enumerate(top_cases):
case = case_data["case"]
if case.get("diagnosis") == diagnosis:
evidence[diagnosis].append({
"type": "similar_case",
"item": f"Case {i+1}",
"similarity": case_data["similarity"],
"weight": "high" if case_data["similarity"] > 0.8 else "medium"
})
return {
"diagnoses": diagnoses,
"evidence": dict(evidence)
}
def _extract_case_features(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract features from patient data for case-based reasoning.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary of features for case comparison.
"""
features = {}
# Extract demographic features
demographics = patient_data.get("demographics", {})
features["age"] = demographics.get("age", 0)
features["gender"] = demographics.get("gender", "unknown")
# Extract symptom features (as binary indicators)
for symptom in patient_data.get("symptoms", []):
features[f"symptom_{symptom}"] = 1
# Extract lab result features
for lab_test, value in patient_data.get("lab_results", {}).items():
features[f"lab_{lab_test}"] = value
return features
def _calculate_case_similarity(self, features1: Dict[str, Any], features2: Dict[str, Any]) -> float:
"""
Calculate similarity between two cases based on their features.
Args:
features1: Features of the first case.
features2: Features of the second case.
Returns:
Similarity score between 0 and 1.
"""
# Combine all feature keys
all_keys = set(features1.keys()) | set(features2.keys())
# Initialize similarity components
matches = 0
total = 0
for key in all_keys:
# Skip if key is not relevant for comparison
if key in ["patient_id", "case_id"]:
continue
total += 1
# Check if both cases have this feature
if key in features1 and key in features2:
value1 = features1[key]
value2 = features2[key]
# Calculate similarity based on feature type
if isinstance(value1, (int, float)) and isinstance(value2, (int, float)):
# Numeric feature: calculate normalized difference
max_val = max(abs(value1), abs(value2))
if max_val > 0:
similarity = 1 - abs(value1 - value2) / max_val
else:
similarity = 1.0
elif isinstance(value1, str) and isinstance(value2, str):
# String feature: exact match
similarity = 1.0 if value1 == value2 else 0.0
else:
# Other types: exact match
similarity = 1.0 if value1 == value2 else 0.0
matches += similarity
# Calculate overall similarity
if total > 0:
return matches / total
else:
return 0.0
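# Worked example (hypothetical values): comparing {"age": 50, "gender": "male",
# "symptom_fever": 1} against {"age": 40, "gender": "male"} iterates over 3 keys:
# age -> 1 - |50 - 40| / 50 = 0.8, gender -> 1.0, and symptom_fever -> 0 (present
# in only one case), so the overall similarity is (0.8 + 1.0 + 0) / 3 = 0.6.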
def _deep_learning_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply deep learning reasoning for diagnosis.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary containing diagnoses and supporting evidence.
"""
# This is a placeholder for deep learning-based diagnosis
# In a real implementation, this would use a trained neural network model
logger.warning("Deep learning reasoning is not fully implemented")
# Return empty results for now
return {
"diagnoses": [],
"evidence": {}
}
def _hybrid_reasoning(self, patient_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Apply hybrid reasoning combining multiple methods.
Args:
patient_data: Preprocessed patient data.
Returns:
Dictionary containing diagnoses and supporting evidence.
"""
all_diagnoses = []
all_evidence = defaultdict(list)
# Apply each enabled reasoning method
enabled_methods = self.config.get("reasoning_methods", ["rule_based", "bayesian_network"])
for method_name in enabled_methods:
if method_name == "hybrid":
continue # Skip to avoid recursion
if method_name in self.reasoning_methods:
method_func = self.reasoning_methods[method_name]
result = method_func(patient_data)
# Add method identifier to each diagnosis
for diagnosis in result["diagnoses"]:
diagnosis["method"] = method_name
all_diagnoses.extend(result["diagnoses"])
# Merge evidence
for disease, evidence_list in result["evidence"].items():
all_evidence[disease].extend(evidence_list)
# Combine diagnoses from different methods
combined_diagnoses = self._combine_diagnoses(all_diagnoses)
return {
"diagnoses": combined_diagnoses,
"evidence": dict(all_evidence)
}
def _combine_diagnoses(self, diagnoses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Combine diagnoses from different reasoning methods.
Args:
diagnoses: List of diagnoses from different methods.
Returns:
Combined list of diagnoses with updated confidence scores.
"""
# Group diagnoses by disease
disease_groups = defaultdict(list)
for diagnosis in diagnoses:
disease = diagnosis["disease"]
disease_groups[disease].append(diagnosis)
# Combine diagnoses for each disease
combined_diagnoses = []
for disease, group in disease_groups.items():
# Calculate combined confidence
# Strategy: weighted average of confidence scores
method_weights = {
"rule_based": 0.3,
"bayesian_network": 0.4,
"case_based": 0.2,
"deep_learning": 0.1
}
total_weight = 0
weighted_confidence = 0
for diagnosis in group:
method = diagnosis.get("method", "unknown")
weight = method_weights.get(method, 0.1)
confidence = diagnosis.get("confidence", 0)
weighted_confidence += weight * confidence
total_weight += weight
if total_weight > 0:
combined_confidence = weighted_confidence / total_weight
else:
combined_confidence = 0
# Create combined diagnosis
combined_diagnosis = {
"disease": disease,
"confidence": combined_confidence,
"methods": [d.get("method", "unknown") for d in group],
"individual_confidences": {d.get("method", f"method_{i}"): d.get("confidence", 0)
for i, d in enumerate(group)}
}
combined_diagnoses.append(combined_diagnosis)
return combined_diagnoses
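# Worked example (hypothetical scores): if rule-based reasoning reports a disease
# at confidence 0.8 (weight 0.3) and the Bayesian method reports it at 0.6
# (weight 0.4), the combined confidence is
# (0.3 * 0.8 + 0.4 * 0.6) / (0.3 + 0.4) ≈ 0.686.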
def _generate_explanation(self, primary_diagnosis: Optional[Dict[str, Any]],
differential_diagnoses: List[Dict[str, Any]],
evidence: Dict[str, List[Dict[str, Any]]],
method: str) -> Dict[str, Any]:
"""
Generate an explanation for the diagnostic results.
Args:
primary_diagnosis: Primary diagnosis.
differential_diagnoses: List of differential diagnoses.
evidence: Supporting evidence for diagnoses.
method: Reasoning method used.
Returns:
Dictionary containing explanation components.
"""
explanation = {
"reasoning_method": method,
"summary": "",
"primary_diagnosis_explanation": "",
"differential_diagnoses_explanation": [],
"evidence_summary": {}
}
# Generate summary
if primary_diagnosis:
disease = primary_diagnosis["disease"]
confidence = primary_diagnosis["confidence"]
confidence_text = self._confidence_to_text(confidence)
explanation["summary"] = (
f"The most likely diagnosis is {disease} ({confidence_text} confidence: {confidence:.2f}). "
f"This diagnosis was determined using {method} reasoning."
)
# Generate primary diagnosis explanation
disease_evidence = evidence.get(disease, [])
if disease_evidence:
evidence_text = self._summarize_evidence(disease_evidence)
explanation["primary_diagnosis_explanation"] = (
f"The diagnosis of {disease} is supported by the following evidence: {evidence_text}"
)
else:
explanation["summary"] = "No confident diagnosis could be determined based on the available information."
# Generate differential diagnoses explanation
for diagnosis in differential_diagnoses:
disease = diagnosis["disease"]
confidence = diagnosis["confidence"]
confidence_text = self._confidence_to_text(confidence)
disease_evidence = evidence.get(disease, [])
evidence_text = self._summarize_evidence(disease_evidence)
diff_explanation = {
"disease": disease,
"confidence": confidence,
"confidence_text": confidence_text,
"explanation": f"Alternative diagnosis: {disease} ({confidence_text} confidence: {confidence:.2f}). "
f"Supporting evidence: {evidence_text}"
}
explanation["differential_diagnoses_explanation"].append(diff_explanation)
# Generate evidence summary
for disease, disease_evidence in evidence.items():
explanation["evidence_summary"][disease] = self._categorize_evidence(disease_evidence)
return explanation
def _confidence_to_text(self, confidence: float) -> str:
"""
Convert confidence score to text description.
Args:
confidence: Confidence score between 0 and 1.
Returns:
Text description of confidence.
"""
if confidence >= 0.9:
return "very high"
elif confidence >= 0.75:
return "high"
elif confidence >= 0.6:
return "moderate"
elif confidence >= 0.4:
return "low"
else:
return "very low"
def _summarize_evidence(self, evidence_list: List[Dict[str, Any]]) -> str:
"""
Summarize evidence items into a text description.
Args:
evidence_list: List of evidence items.
Returns:
Text summary of evidence.
"""
if not evidence_list:
return "No specific evidence available."
# Group evidence by type and weight
high_evidence = [e["item"] for e in evidence_list if e.get("weight") == "high"]
medium_evidence = [e["item"] for e in evidence_list if e.get("weight") == "medium"]
other_evidence = [e["item"] for e in evidence_list if e.get("weight") not in ["high", "medium"]]
evidence_parts = []
if high_evidence:
evidence_parts.append(f"Strong indicators: {', '.join(high_evidence)}")
if medium_evidence:
evidence_parts.append(f"Supporting factors: {', '.join(medium_evidence)}")
if other_evidence:
evidence_parts.append(f"Other factors: {', '.join(other_evidence)}")
return ". ".join(evidence_parts)
def _categorize_evidence(self, evidence_list: List[Dict[str, Any]]) -> Dict[str, List[str]]:
"""
Categorize evidence items by type and weight.
Args:
evidence_list: List of evidence items.
Returns:
Dictionary with categorized evidence.
"""
categorized = {
"symptoms": [],
"lab_results": [],
"imaging": [],
"patient_history": [],
"other": []
}
for evidence in evidence_list:
item = evidence["item"]
evidence_type = evidence.get("type", "")
if "symptom" in evidence_type:
categorized["symptoms"].append(item)
elif "lab" in evidence_type:
categorized["lab_results"].append(item)
elif "image" in evidence_type:
categorized["imaging"].append(item)
elif "history" in evidence_type:
categorized["patient_history"].append(item)
else:
categorized["other"].append(item)
return categorized
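# Usage sketch (illustrative only): assumes a rules file matching the schema shown
# above has been written to the (hypothetical) path below, and that its symptom
# vocabulary matches the normalized patient symptoms.
if __name__ == "__main__":
    engine = MedicalInferenceEngine()
    engine.load_rules("rules.json")
    result = engine.diagnose(
        {"demographics": {"age": 30, "gender": "female"},
         "symptoms": ["fever", "cough"]},
        method="rule_based",
    )
    print(result["primary_diagnosis"])
    print(result["explanation"]["summary"])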
modules\diagnosis_engine\integration_engine.py
"""
Diagnostic Integration Engine Module
This module provides functionality for integrating results from different analysis modules,
including medical image analysis, clinical text analysis, and lab results, to provide
comprehensive diagnostic suggestions.
"""
import os
import json
import logging
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Any, Optional, Set, Union
from pathlib import Path
from collections import defaultdict
from .knowledge_graph import MedicalKnowledgeGraph
from .inference_engine import MedicalInferenceEngine
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class DiagnosticIntegrationEngine:
"""
Diagnostic Integration Engine for combining results from different analysis modules.
This class provides methods for integrating results from medical image analysis,
clinical text analysis, and lab results to provide comprehensive diagnostic suggestions.
"""
def __init__(self, inference_engine: Optional[MedicalInferenceEngine] = None, config_path: Optional[str] = None):
"""
Initialize the Diagnostic Integration Engine.
Args:
inference_engine: Medical inference engine to use for reasoning.
If None, a new inference engine will be created.
config_path: Path to the configuration file for the integration engine.
If None, default configuration will be used.
"""
self.inference_engine = inference_engine if inference_engine else MedicalInferenceEngine()
self.config = self._load_config(config_path)
logger.info("Diagnostic Integration Engine initialized")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file or use default configuration.
Args:
config_path: Path to the configuration file.
Returns:
Dictionary containing configuration parameters.
"""
default_config = {
"image_analysis_weight": 0.5,
"text_analysis_weight": 0.5,
"lab_results_weight": 0.3,
"patient_history_weight": 0.4,
"fusion_method": "weighted_average"
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f).get("integration", {})
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return default_config
else:
# Try to load from the default location
try:
base_dir = Path(__file__).parent.parent.parent
default_path = base_dir / "config" / "diagnosis_engine_config.json"
if os.path.exists(default_path):
with open(default_path, 'r') as f:
config = json.load(f).get("integration", {})
logger.info(f"Loaded configuration from {default_path}")
return config
except Exception as e:
logger.error(f"Error loading default configuration: {e}")
logger.info("Using default integration engine configuration")
return default_config
def integrate_results(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]] = None,
text_analysis: Optional[Dict[str, Any]] = None,
lab_results: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Integrate results from different analysis modules.
Args:
patient_data: Dictionary containing patient demographic data.
image_analysis: Results from medical image analysis.
text_analysis: Results from clinical text analysis.
lab_results: Laboratory test results.
Returns:
Dictionary containing integrated diagnostic results.
"""
# Prepare integrated data for inference
integrated_data = self._prepare_integrated_data(patient_data, image_analysis, text_analysis, lab_results)
# Run diagnostic inference
diagnosis_results = self.inference_engine.diagnose(integrated_data)
# Enhance results with source information
enhanced_results = self._enhance_results_with_sources(diagnosis_results, image_analysis, text_analysis, lab_results)
return enhanced_results
def _prepare_integrated_data(self, patient_data: Dict[str, Any], image_analysis: Optional[Dict[str, Any]],
text_analysis: Optional[Dict[str, Any]],
lab_results: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""
Prepare integrated data for diagnostic inference.
Args:
patient_data: Patient demographic data.
image_analysis: Results from medical image analysis.
text_analysis: Results from clinical text analysis.
lab_results: Laboratory test results.
Returns:
Integrated data ready for diagnostic inference.
"""
integrated_data = {
"demographics": patient_data.get("demographics", {}),
"symptoms": [],
"findings": [],
"medical_history": patient_data.get("medical_history", [])
}
# Extract symptoms and findings from text analysis
if text_analysis:
# Import text analysis results to knowledge graph
if hasattr(self.inference_engine, "knowledge_graph"):
self.inference_engine.knowledge_graph.import_from_text_analysis(text_analysis)
# Extract symptoms
if "entities" in text_analysis:
symptoms = [e["text"] for e in text_analysis["entities"]
if e.get("label", "").lower() == "symptom"]
integrated_data["symptoms"].extend(symptoms)
# Extract medical history
if "entities" in text_analysis:
diseases = [e["text"] for e in text_analysis["entities"]
if e.get("label", "").lower() == "disease"]
integrated_data["medical_history"].extend(diseases)
# Add text analysis results
integrated_data["text_analysis"] = text_analysis
# Extract findings from image analysis
if image_analysis:
if "findings" in image_analysis:
integrated_data["findings"].extend(image_analysis["findings"])
# Extract abnormalities as potential symptoms
if "abnormalities" in image_analysis:
for abnormality in image_analysis["abnormalities"]:
if "description" in abnormality:
integrated_data["symptoms"].append(abnormality["description"])
# Add image analysis results
integrated_data["image_analysis"] = image_analysis
# Add lab results
if lab_results:
integrated_data["lab_results"] = lab_results
# Extract abnormal lab results as findings
if "abnormal_results" in lab_results:
for test, result in lab_results["abnormal_results"].items():
integrated_data["findings"].append(f"Abnormal {test}: {result}")
# Remove duplicates
integrated_data["symptoms"] = list(set(integrated_data["symptoms"]))
integrated_data["findings"] = list(set(integrated_data["findings"]))
integrated_data["medical_history"] = list(set(integrated_data["medical_history"]))
return integrated_data
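# Shape sketch of the returned structure (values are hypothetical; the
# text_analysis, image_analysis, and lab_results keys appear only when the
# corresponding inputs were provided):
#   {"demographics": {"age": 54, "gender": "male"},
#    "symptoms": ["cough", "nodular opacity"],
#    "findings": ["Abnormal WBC: 14.2"],
#    "medical_history": ["asthma"],
#    "text_analysis": {...}, "image_analysis": {...}, "lab_results": {...}}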
def _enhance_results_with_sources(self, diagnosis_results: Dict[str, Any],
image_analysis: Optional[Dict[str, Any]],
text_analysis: Optional[Dict[str, Any]],
lab_results: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""
Enhance diagnostic results with information about data sources.
Args:
diagnosis_results: Results from diagnostic inference.
image_analysis: Results from medical image analysis.
text_analysis: Results from clinical text analysis.
lab_results: Laboratory test results.
Returns:
Enhanced diagnostic results with source information.
"""
enhanced_results = diagnosis_results.copy()
# Add source information
sources = []
if image_analysis:
sources.append({
"type": "image_analysis",
"modality": image_analysis.get("modality", "unknown"),
"body_part": image_analysis.get("body_part", "unknown"),
"confidence": image_analysis.get("confidence", 1.0),
"weight": self.config.get("image_analysis_weight", 0.5)
})
if text_analysis:
sources.append({
"type": "text_analysis",
"document_type": text_analysis.get("document_type", "unknown"),
"confidence": text_analysis.get("confidence", 1.0),
"weight": self.config.get("text_analysis_weight", 0.5)
})
if lab_results:
sources.append({
"type": "lab_results",
"test_count": len(lab_results.get("results", {})),
"abnormal_count": len(lab_results.get("abnormal_results", {})),
"weight": self.config.get("lab_results_weight", 0.3)
})
enhanced_results["data_sources"] = sources
# Add confidence breakdown by source
if enhanced_results.get("primary_diagnosis"):
primary_diagnosis = enhanced_results["primary_diagnosis"]
confidence_breakdown = self._calculate_confidence_breakdown(
primary_diagnosis, image_analysis, text_analysis, lab_results
)
primary_diagnosis["confidence_breakdown"] = confidence_breakdown
return enhanced_results
def _calculate_confidence_breakdown(self, diagnosis: Dict[str, Any],
image_analysis: Optional[Dict[str, Any]],
text_analysis: Optional[Dict[str, Any]],
lab_results: Optional[Dict[str, Any]]) -> Dict[str, float]:
"""
Calculate confidence breakdown by data source.
Args:
diagnosis: Diagnosis to calculate confidence breakdown for.
image_analysis: Results from medical image analysis.
text_analysis: Results from clinical text analysis.
lab_results: Laboratory test results.
Returns:
Dictionary with confidence contribution from each source.
"""
disease = diagnosis["disease"]
confidence_breakdown = {}
# Image analysis contribution
if image_analysis:
# Check if the disease is mentioned in image findings
image_confidence = 0.0
if "findings" in image_analysis:
for finding in image_analysis["findings"]:
if disease.lower() in finding.lower():
image_confidence = 0.8 # High confidence if explicitly mentioned
break
# Check if related abnormalities are present
if image_confidence == 0.0 and "abnormalities" in image_analysis:
for abnormality in image_analysis["abnormalities"]:
# This is simplified; in a real system, we would use the knowledge graph
# to check if the abnormality is related to the disease
if "description" in abnormality and any(
term in abnormality["description"].lower()
for term in disease.lower().split()
):
image_confidence = 0.5 # Moderate confidence for related abnormalities
break
confidence_breakdown["image_analysis"] = image_confidence * self.config.get("image_analysis_weight", 0.5)
# Text analysis contribution
if text_analysis:
text_confidence = 0.0
# Check if the disease is mentioned in entities
if "entities" in text_analysis:
for entity in text_analysis["entities"]:
                if entity.get("label", "").lower() == "disease" and entity.get("text", "").lower() == disease.lower():
text_confidence = 0.9 # Very high confidence if explicitly mentioned
break
# Check if related symptoms are present
if text_confidence < 0.5 and "entities" in text_analysis:
symptom_count = 0
for entity in text_analysis["entities"]:
if entity.get("label", "").lower() == "symptom":
# This is simplified; in a real system, we would use the knowledge graph
# to check if the symptom is related to the disease
symptom_count += 1
if symptom_count > 0:
text_confidence = min(0.7, 0.3 + 0.1 * symptom_count) # Confidence increases with symptom count
confidence_breakdown["text_analysis"] = text_confidence * self.config.get("text_analysis_weight", 0.5)
# Lab results contribution
if lab_results:
lab_confidence = 0.0
# Check if there are abnormal results related to the disease
if "abnormal_results" in lab_results:
# This is simplified; in a real system, we would use the knowledge graph
# to check if the abnormal results are related to the disease
abnormal_count = len(lab_results["abnormal_results"])
if abnormal_count > 0:
lab_confidence = min(0.6, 0.2 + 0.1 * abnormal_count) # Confidence increases with abnormal count
confidence_breakdown["lab_results"] = lab_confidence * self.config.get("lab_results_weight", 0.3)
return confidence_breakdown
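    # Worked example (hypothetical values): if the disease is named explicitly
    # in the text entities (text_confidence = 0.9) and two abnormal lab results
    # are present (lab_confidence = min(0.6, 0.2 + 0.1 * 2) = 0.4), the default
    # weights give:
    #   text_analysis: 0.9 * 0.5 = 0.45
    #   lab_results:   0.4 * 0.3 = 0.12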
def generate_diagnostic_report(self, diagnosis_results: Dict[str, Any],
patient_data: Dict[str, Any],
format: str = "text") -> str:
"""
Generate a diagnostic report based on the integrated results.
Args:
diagnosis_results: Results from diagnostic integration.
patient_data: Patient demographic data.
format: Report format ("text", "html", or "json").
Returns:
Diagnostic report in the specified format.
"""
if format == "json":
# Return JSON representation
return json.dumps(diagnosis_results, indent=2)
# Extract patient information
demographics = patient_data.get("demographics", {})
patient_name = demographics.get("name", "Unknown")
patient_age = demographics.get("age", "Unknown")
patient_gender = demographics.get("gender", "Unknown")
# Extract diagnosis information
primary_diagnosis = diagnosis_results.get("primary_diagnosis", {})
differential_diagnoses = diagnosis_results.get("differential_diagnoses", [])
evidence = diagnosis_results.get("evidence", {})
explanation = diagnosis_results.get("explanation", {})
if format == "html":
# Generate HTML report
html_report = f"""<!DOCTYPE html>
<html>
<head>
<title>Diagnostic Report for {patient_name}</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; }}
h1, h2, h3 {{ color: #2c3e50; }}
.section {{ margin-bottom: 20px; }}
.diagnosis {{ background-color: #f8f9fa; padding: 10px; border-radius: 5px; }}
.primary {{ border-left: 5px solid #27ae60; }}
.differential {{ border-left: 5px solid #3498db; }}
.evidence {{ margin-left: 20px; }}
.confidence {{ font-weight: bold; }}
.high {{ color: #27ae60; }}
.moderate {{ color: #f39c12; }}
.low {{ color: #e74c3c; }}
table {{ border-collapse: collapse; width: 100%; }}
th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
th {{ background-color: #f2f2f2; }}
</style>
</head>
<body>
<h1>Diagnostic Report</h1>
<div class="section">
<h2>Patient Information</h2>
<table>
<tr><th>Name</th><td>{patient_name}</td></tr>
<tr><th>Age</th><td>{patient_age}</td></tr>
<tr><th>Gender</th><td>{patient_gender}</td></tr>
</table>
</div>
<div class="section">
<h2>Diagnostic Summary</h2>
<p>{explanation.get("summary", "No diagnostic summary available.")}</p>
</div>
"""
# Add primary diagnosis
if primary_diagnosis:
disease = primary_diagnosis.get("disease", "Unknown")
confidence = primary_diagnosis.get("confidence", 0)
confidence_class = "high" if confidence >= 0.75 else "moderate" if confidence >= 0.5 else "low"
html_report += f"""
<div class="section">
<h2>Primary Diagnosis</h2>
<div class="diagnosis primary">
<h3>{disease}</h3>
<p class="confidence">Confidence: <span class="{confidence_class}">{confidence:.2f}</span></p>
<p>{explanation.get("primary_diagnosis_explanation", "")}</p>
<h4>Supporting Evidence:</h4>
<div class="evidence">
"""
# Add evidence for primary diagnosis
disease_evidence = evidence.get(disease, [])
if disease_evidence:
html_report += "<ul>"
for ev in disease_evidence:
html_report += f"<li>{ev.get('item', '')}: {ev.get('type', '')}</li>"
html_report += "</ul>"
else:
html_report += "<p>No specific evidence available.</p>"
html_report += """
</div>
</div>
</div>
"""
# Add differential diagnoses
if differential_diagnoses:
html_report += """
<div class="section">
<h2>Differential Diagnoses</h2>
"""
for diagnosis in differential_diagnoses:
disease = diagnosis.get("disease", "Unknown")
confidence = diagnosis.get("confidence", 0)
confidence_class = "high" if confidence >= 0.75 else "moderate" if confidence >= 0.5 else "low"
html_report += f"""
<div class="diagnosis differential">
<h3>{disease}</h3>
<p class="confidence">Confidence: <span class="{confidence_class}">{confidence:.2f}</span></p>
"""
# Add evidence for differential diagnosis
disease_evidence = evidence.get(disease, [])
if disease_evidence:
html_report += "<h4>Supporting Evidence:</h4><div class=\"evidence\"><ul>"
for ev in disease_evidence:
html_report += f"<li>{ev.get('item', '')}: {ev.get('type', '')}</li>"
html_report += "</ul></div>"
html_report += """
</div>
"""
html_report += """
</div>
"""
# Add data sources
data_sources = diagnosis_results.get("data_sources", [])
if data_sources:
html_report += """
<div class="section">
<h2>Data Sources</h2>
<table>
<tr>
<th>Source Type</th>
<th>Details</th>
<th>Weight</th>
</tr>
"""
for source in data_sources:
source_type = source.get("type", "Unknown")
details = ", ".join([f"{k}: {v}" for k, v in source.items()
if k not in ["type", "weight"]])
weight = source.get("weight", 0)
html_report += f"""
<tr>
<td>{source_type}</td>
<td>{details}</td>
<td>{weight:.2f}</td>
</tr>
"""
html_report += """
</table>
</div>
"""
# Add footer
html_report += """
<div class="section">
<p><em>This report was generated by the Medical Diagnostic System. It is intended to assist healthcare professionals and should not replace clinical judgment.</em></p>
</div>
</body>
</html>
"""
return html_report
else: # Text format
# Generate text report
text_report = f"""DIAGNOSTIC REPORT
================
PATIENT INFORMATION
------------------
Name: {patient_name}
Age: {patient_age}
Gender: {patient_gender}
DIAGNOSTIC SUMMARY
-----------------
{explanation.get("summary", "No diagnostic summary available.")}
"""
# Add primary diagnosis
if primary_diagnosis:
disease = primary_diagnosis.get("disease", "Unknown")
confidence = primary_diagnosis.get("confidence", 0)
text_report += f"""PRIMARY DIAGNOSIS
----------------
{disease} (Confidence: {confidence:.2f})
{explanation.get("primary_diagnosis_explanation", "")}
Supporting Evidence:
"""
# Add evidence for primary diagnosis
disease_evidence = evidence.get(disease, [])
if disease_evidence:
for ev in disease_evidence:
text_report += f"- {ev.get('item', '')}: {ev.get('type', '')}\n"
else:
text_report += "No specific evidence available.\n"
text_report += "\n"
# Add differential diagnoses
if differential_diagnoses:
text_report += """DIFFERENTIAL DIAGNOSES
----------------------
"""
for diagnosis in differential_diagnoses:
disease = diagnosis.get("disease", "Unknown")
confidence = diagnosis.get("confidence", 0)
text_report += f"{disease} (Confidence: {confidence:.2f})\n"
# Add evidence for differential diagnosis
disease_evidence = evidence.get(disease, [])
if disease_evidence:
text_report += "Supporting Evidence:\n"
for ev in disease_evidence:
text_report += f"- {ev.get('item', '')}: {ev.get('type', '')}\n"
text_report += "\n"
# Add data sources
data_sources = diagnosis_results.get("data_sources", [])
if data_sources:
text_report += """DATA SOURCES
------------
"""
for source in data_sources:
source_type = source.get("type", "Unknown")
details = ", ".join([f"{k}: {v}" for k, v in source.items()
if k not in ["type", "weight"]])
weight = source.get("weight", 0)
text_report += f"{source_type} (Weight: {weight:.2f}) - {details}\n"
text_report += "\n"
# Add footer
text_report += """NOTE: This report was generated by the Medical Diagnostic System. It is intended to assist healthcare professionals and should not replace clinical judgment.
"""
return text_report
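# Usage sketch (illustrative; `integrator` stands for an instance of the
# integration class defined above, and the input dicts are hypothetical):
#
#     report = integrator.generate_diagnostic_report(diagnosis_results,
#                                                    patient_data, format="text")
#     print(report)
#     html = integrator.generate_diagnostic_report(diagnosis_results,
#                                                  patient_data, format="html")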
modules\diagnosis_engine\knowledge_graph.py
"""
Medical Knowledge Graph Module
This module provides functionality for building and querying a medical knowledge graph
that serves as the foundation for the diagnostic reasoning engine.
"""
import os
import json
import logging
from typing import Dict, List, Tuple, Any, Optional, Set, Union
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from pathlib import Path
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MedicalKnowledgeGraph:
"""
Medical Knowledge Graph for storing and querying medical knowledge.
This class provides methods for building a knowledge graph from various sources,
including medical ontologies, clinical guidelines, and extracted relations from
clinical text analysis.
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the Medical Knowledge Graph.
Args:
config_path: Path to the configuration file for the knowledge graph.
If None, default configuration will be used.
"""
self.graph = nx.DiGraph()
self.config = self._load_config(config_path)
self.node_types = {
"disease": {"color": "red", "shape": "o"},
"symptom": {"color": "blue", "shape": "s"},
"medication": {"color": "green", "shape": "^"},
"procedure": {"color": "purple", "shape": "d"},
"lab_test": {"color": "orange", "shape": "p"},
"anatomy": {"color": "brown", "shape": "h"},
"risk_factor": {"color": "pink", "shape": "*"}
}
self.relation_types = {
"causes": {"color": "red", "style": "solid"},
"treats": {"color": "green", "style": "solid"},
"diagnoses": {"color": "blue", "style": "dashed"},
"contraindicates": {"color": "orange", "style": "dotted"},
"has_symptom": {"color": "purple", "style": "solid"},
"located_in": {"color": "brown", "style": "dashed"},
"increases_risk": {"color": "pink", "style": "dotted"}
}
logger.info("Medical Knowledge Graph initialized")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file or use default configuration.
Args:
config_path: Path to the configuration file.
Returns:
Dictionary containing configuration parameters.
"""
default_config = {
"database_type": "in_memory",
"external_sources": ["umls", "snomed_ct", "icd10"],
"relation_confidence_threshold": 0.7
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
                    # Merge over defaults so a partial config file still has sane values
                    config = {**default_config, **json.load(f).get("knowledge_graph", {})}
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration: {e}")
return default_config
else:
# Try to load from the default location
try:
base_dir = Path(__file__).parent.parent.parent
default_path = base_dir / "config" / "diagnosis_engine_config.json"
if os.path.exists(default_path):
with open(default_path, 'r') as f:
                        # Merge over defaults so a partial config file still has sane values
                        config = {**default_config, **json.load(f).get("knowledge_graph", {})}
logger.info(f"Loaded configuration from {default_path}")
return config
except Exception as e:
logger.error(f"Error loading default configuration: {e}")
logger.info("Using default knowledge graph configuration")
return default_config
    def add_node(self, node_id: str, node_type: str, properties: Optional[Dict[str, Any]] = None) -> None:
"""
Add a node to the knowledge graph.
Args:
node_id: Unique identifier for the node.
node_type: Type of the node (e.g., disease, symptom, medication).
properties: Additional properties of the node.
"""
if properties is None:
properties = {}
if node_type not in self.node_types:
logger.warning(f"Unknown node type: {node_type}, using default properties")
node_visual = {"color": "gray", "shape": "o"}
else:
node_visual = self.node_types[node_type]
# Merge visual properties with other properties
node_data = {
"id": node_id,
"type": node_type,
"visual": node_visual,
**properties
}
self.graph.add_node(node_id, **node_data)
logger.debug(f"Added node: {node_id} of type {node_type}")
    def add_relation(self, source_id: str, target_id: str, relation_type: str,
                     confidence: float = 1.0, properties: Optional[Dict[str, Any]] = None) -> None:
"""
Add a relation between two nodes in the knowledge graph.
Args:
source_id: Identifier of the source node.
target_id: Identifier of the target node.
relation_type: Type of the relation (e.g., causes, treats).
confidence: Confidence score for the relation (0.0 to 1.0).
properties: Additional properties of the relation.
"""
        if properties is None:
            properties = {}
        # Check the confidence threshold before touching the graph, so that
        # rejected relations do not leave orphan placeholder nodes behind
        if confidence < self.config.get("relation_confidence_threshold", 0.0):
            logger.debug(f"Relation confidence {confidence} below threshold, not adding relation")
            return
        # Check if nodes exist; if not, create placeholder nodes
        if source_id not in self.graph:
            logger.warning(f"Source node {source_id} does not exist, creating placeholder")
            self.add_node(source_id, "unknown")
        if target_id not in self.graph:
            logger.warning(f"Target node {target_id} does not exist, creating placeholder")
            self.add_node(target_id, "unknown")
if relation_type not in self.relation_types:
logger.warning(f"Unknown relation type: {relation_type}, using default properties")
relation_visual = {"color": "gray", "style": "solid"}
else:
relation_visual = self.relation_types[relation_type]
# Merge visual properties with other properties
relation_data = {
"type": relation_type,
"confidence": confidence,
"visual": relation_visual,
**properties
}
self.graph.add_edge(source_id, target_id, **relation_data)
logger.debug(f"Added relation: {source_id} --[{relation_type}]--> {target_id} (conf: {confidence})")
def import_from_csv(self, nodes_file: str, relations_file: str) -> None:
"""
Import knowledge graph from CSV files.
Args:
nodes_file: Path to CSV file containing node data.
relations_file: Path to CSV file containing relation data.
"""
try:
# Import nodes
nodes_df = pd.read_csv(nodes_file)
for _, row in nodes_df.iterrows():
node_id = row['id']
node_type = row['type']
properties = {col: row[col] for col in nodes_df.columns if col not in ['id', 'type']}
self.add_node(node_id, node_type, properties)
# Import relations
relations_df = pd.read_csv(relations_file)
for _, row in relations_df.iterrows():
source_id = row['source_id']
target_id = row['target_id']
relation_type = row['type']
confidence = float(row.get('confidence', 1.0))
properties = {col: row[col] for col in relations_df.columns
if col not in ['source_id', 'target_id', 'type', 'confidence']}
self.add_relation(source_id, target_id, relation_type, confidence, properties)
logger.info(f"Imported {len(nodes_df)} nodes and {len(relations_df)} relations from CSV")
except Exception as e:
logger.error(f"Error importing from CSV: {e}")
def import_from_text_analysis(self, analysis_results: Dict[str, Any]) -> None:
"""
Import knowledge from clinical text analysis results.
Args:
analysis_results: Results from clinical text analysis containing
entities and relations.
"""
try:
# Extract entities
entities = analysis_results.get("entities", [])
for entity in entities:
                entity_id = entity.get("id", entity.get("text", ""))
                if not entity_id:
                    continue  # skip entities without a usable identifier
                entity_type = entity.get("label", "unknown")
properties = {
"text": entity.get("text", ""),
"source": "text_analysis",
"confidence": entity.get("confidence", 1.0)
}
self.add_node(entity_id, entity_type, properties)
# Extract relations
relations = analysis_results.get("relations", [])
for relation in relations:
                source_id = relation.get("entity1", {}).get("id", relation.get("entity1", {}).get("text", ""))
                target_id = relation.get("entity2", {}).get("id", relation.get("entity2", {}).get("text", ""))
                if not source_id or not target_id:
                    continue  # skip relations whose endpoints could not be identified
                relation_type = relation.get("relation_type", "unknown")
confidence = relation.get("confidence", 1.0)
properties = {"source": "text_analysis"}
self.add_relation(source_id, target_id, relation_type, confidence, properties)
logger.info(f"Imported {len(entities)} entities and {len(relations)} relations from text analysis")
except Exception as e:
logger.error(f"Error importing from text analysis: {e}")
def query_related_nodes(self, node_id: str, relation_type: Optional[str] = None,
direction: str = "outgoing", max_depth: int = 1) -> List[Dict[str, Any]]:
"""
Query nodes related to a given node.
Args:
node_id: Identifier of the node to query.
relation_type: Type of relation to filter by (optional).
direction: Direction of relations ('outgoing', 'incoming', or 'both').
max_depth: Maximum depth of traversal.
Returns:
List of related nodes with their properties.
"""
        if node_id not in self.graph:
            logger.warning(f"Node {node_id} not found in knowledge graph")
            return []
        if max_depth != 1:
            # Only direct neighbours are traversed below; deeper traversal is not implemented
            logger.warning("query_related_nodes currently supports max_depth=1 only; returning direct neighbours")
related_nodes = []
if direction in ["outgoing", "both"]:
# Get outgoing relations
for target_id in self.graph.successors(node_id):
edge_data = self.graph.get_edge_data(node_id, target_id)
if relation_type is None or edge_data.get("type") == relation_type:
node_data = self.graph.nodes[target_id]
related_nodes.append({
"id": target_id,
"type": node_data.get("type"),
"relation": {
"type": edge_data.get("type"),
"direction": "outgoing",
"confidence": edge_data.get("confidence", 1.0)
},
"properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]}
})
if direction in ["incoming", "both"]:
# Get incoming relations
for source_id in self.graph.predecessors(node_id):
edge_data = self.graph.get_edge_data(source_id, node_id)
if relation_type is None or edge_data.get("type") == relation_type:
node_data = self.graph.nodes[source_id]
related_nodes.append({
"id": source_id,
"type": node_data.get("type"),
"relation": {
"type": edge_data.get("type"),
"direction": "incoming",
"confidence": edge_data.get("confidence", 1.0)
},
"properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]}
})
return related_nodes
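    # Example (illustrative, continuing the subgraph sketched above):
    #
    #     kg.query_related_nodes("pneumonia", relation_type="has_symptom")
    #     # -> [{"id": "fever", "type": "symptom",
    #     #      "relation": {"type": "has_symptom", "direction": "outgoing",
    #     #                   "confidence": 0.9},
    #     #      "properties": {...}}]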
def find_path(self, source_id: str, target_id: str, max_length: int = 5) -> List[List[Dict[str, Any]]]:
"""
Find paths between two nodes in the knowledge graph.
Args:
source_id: Identifier of the source node.
target_id: Identifier of the target node.
max_length: Maximum path length to consider.
Returns:
List of paths, where each path is a list of nodes and relations.
"""
if source_id not in self.graph or target_id not in self.graph:
logger.warning(f"Source or target node not found in knowledge graph")
return []
try:
# Find all simple paths between source and target
all_paths = list(nx.all_simple_paths(self.graph, source_id, target_id, cutoff=max_length))
# Format paths with node and relation information
formatted_paths = []
for path in all_paths:
formatted_path = []
for i in range(len(path) - 1):
current_node = path[i]
next_node = path[i + 1]
# Get node data
current_node_data = self.graph.nodes[current_node]
# Get edge data
edge_data = self.graph.get_edge_data(current_node, next_node)
formatted_path.append({
"node": {
"id": current_node,
"type": current_node_data.get("type"),
"properties": {k: v for k, v in current_node_data.items()
if k not in ["id", "type"]}
},
"relation": {
"type": edge_data.get("type"),
"confidence": edge_data.get("confidence", 1.0),
"properties": {k: v for k, v in edge_data.items()
if k not in ["type", "confidence"]}
}
})
# Add the last node
last_node_data = self.graph.nodes[path[-1]]
formatted_path.append({
"node": {
"id": path[-1],
"type": last_node_data.get("type"),
"properties": {k: v for k, v in last_node_data.items()
if k not in ["id", "type"]}
},
"relation": None
})
formatted_paths.append(formatted_path)
return formatted_paths
except Exception as e:
logger.error(f"Error finding path: {e}")
return []
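    # Example (illustrative): tracing an explanation path between two nodes.
    #
    #     kg.find_path("pneumonia", "fever", max_length=3)
    #     # -> [[{"node": {"id": "pneumonia", ...},
    #     #       "relation": {"type": "has_symptom", "confidence": 0.9, ...}},
    #     #      {"node": {"id": "fever", ...}, "relation": None}]]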
def get_node_info(self, node_id: str) -> Dict[str, Any]:
"""
Get detailed information about a node.
Args:
node_id: Identifier of the node.
Returns:
Dictionary containing node information.
"""
if node_id not in self.graph:
logger.warning(f"Node {node_id} not found in knowledge graph")
return {}
node_data = self.graph.nodes[node_id]
# Get outgoing relations
outgoing_relations = []
for target_id in self.graph.successors(node_id):
edge_data = self.graph.get_edge_data(node_id, target_id)
target_data = self.graph.nodes[target_id]
outgoing_relations.append({
"target": {
"id": target_id,
"type": target_data.get("type"),
"name": target_data.get("name", target_id)
},
"type": edge_data.get("type"),
"confidence": edge_data.get("confidence", 1.0)
})
# Get incoming relations
incoming_relations = []
for source_id in self.graph.predecessors(node_id):
edge_data = self.graph.get_edge_data(source_id, node_id)
source_data = self.graph.nodes[source_id]
incoming_relations.append({
"source": {
"id": source_id,
"type": source_data.get("type"),
"name": source_data.get("name", source_id)
},
"type": edge_data.get("type"),
"confidence": edge_data.get("confidence", 1.0)
})
return {
"id": node_id,
"type": node_data.get("type"),
"properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]},
"outgoing_relations": outgoing_relations,
"incoming_relations": incoming_relations
}
def find_nodes_by_property(self, property_name: str, property_value: Any) -> List[Dict[str, Any]]:
"""
Find nodes that have a specific property value.
Args:
property_name: Name of the property to search for.
property_value: Value of the property to match.
Returns:
List of nodes matching the criteria.
"""
matching_nodes = []
for node_id, node_data in self.graph.nodes(data=True):
if property_name in node_data and node_data[property_name] == property_value:
matching_nodes.append({
"id": node_id,
"type": node_data.get("type"),
"properties": {k: v for k, v in node_data.items() if k not in ["id", "type"]}
})
return matching_nodes
def get_subgraph(self, node_ids: List[str], include_relations: bool = True) -> nx.DiGraph:
"""
Extract a subgraph containing specified nodes and their relations.
Args:
node_ids: List of node identifiers to include in the subgraph.
include_relations: Whether to include relations between the nodes.
Returns:
NetworkX DiGraph representing the subgraph.
"""
# Filter nodes that exist in the graph
valid_nodes = [node_id for node_id in node_ids if node_id in self.graph]
if not valid_nodes:
logger.warning("No valid nodes provided for subgraph extraction")
return nx.DiGraph()
if include_relations:
# Create a subgraph with the specified nodes and all edges between them
subgraph = self.graph.subgraph(valid_nodes).copy()
else:
# Create a subgraph with only the specified nodes, no edges
subgraph = nx.DiGraph()
for node_id in valid_nodes:
node_data = self.graph.nodes[node_id]
subgraph.add_node(node_id, **node_data)
return subgraph
def visualize(self, node_ids: Optional[List[str]] = None,
filename: Optional[str] = None,
show_labels: bool = True,
figsize: Tuple[int, int] = (12, 10)) -> None:
"""
Visualize the knowledge graph or a subgraph.
Args:
node_ids: List of node identifiers to visualize (optional).
filename: Path to save the visualization (optional).
show_labels: Whether to show node labels.
figsize: Figure size as (width, height) in inches.
"""
if node_ids:
# Visualize a subgraph
graph_to_viz = self.get_subgraph(node_ids)
title = f"Medical Knowledge Subgraph ({len(node_ids)} nodes)"
else:
# Visualize the entire graph
graph_to_viz = self.graph
title = f"Medical Knowledge Graph ({len(self.graph.nodes)} nodes, {len(self.graph.edges)} edges)"
if len(graph_to_viz.nodes) == 0:
logger.warning("No nodes to visualize")
return
plt.figure(figsize=figsize)
plt.title(title)
# Set up node positions using spring layout
pos = nx.spring_layout(graph_to_viz, seed=42)
# Draw nodes with different colors and shapes based on node type
for node_type, style in self.node_types.items():
node_list = [n for n, data in graph_to_viz.nodes(data=True) if data.get("type") == node_type]
if node_list:
nx.draw_networkx_nodes(
graph_to_viz, pos,
nodelist=node_list,
node_color=style["color"],
node_shape=style["shape"],
node_size=500,
alpha=0.8
)
# Draw any nodes with unknown types
unknown_nodes = [n for n, data in graph_to_viz.nodes(data=True)
if data.get("type") not in self.node_types]
if unknown_nodes:
nx.draw_networkx_nodes(
graph_to_viz, pos,
nodelist=unknown_nodes,
node_color="gray",
node_shape="o",
node_size=500,
alpha=0.8
)
# Draw edges with different colors and styles based on relation type
for rel_type, style in self.relation_types.items():
edge_list = [(u, v) for u, v, data in graph_to_viz.edges(data=True)
if data.get("type") == rel_type]
if edge_list:
nx.draw_networkx_edges(
graph_to_viz, pos,
edgelist=edge_list,
edge_color=style["color"],
style=style["style"],
alpha=0.7,
arrowsize=15
)
# Draw any edges with unknown types
unknown_edges = [(u, v) for u, v, data in graph_to_viz.edges(data=True)
if data.get("type") not in self.relation_types]
if unknown_edges:
nx.draw_networkx_edges(
graph_to_viz, pos,
edgelist=unknown_edges,
edge_color="gray",
style="solid",
alpha=0.7,
arrowsize=15
)
# Draw node labels if requested
if show_labels:
labels = {n: data.get("name", n) for n, data in graph_to_viz.nodes(data=True)}
nx.draw_networkx_labels(graph_to_viz, pos, labels, font_size=10)
# Create legend for node types
node_handles = []
node_labels = []
for node_type, style in self.node_types.items():
node_handles.append(plt.Line2D([0], [0], marker=style["shape"], color='w',
markerfacecolor=style["color"], markersize=10))
node_labels.append(node_type)
# Create legend for relation types
edge_handles = []
edge_labels = []
for rel_type, style in self.relation_types.items():
if style["style"] == "solid":
linestyle = "-"
elif style["style"] == "dashed":
linestyle = "--"
else: # dotted
linestyle = ":"
            edge_handles.append(plt.Line2D([0], [0], color=style["color"], linestyle=linestyle))
            edge_labels.append(rel_type)
        # Add legends (a second plt.legend() call replaces the first, so the
        # node-type legend must be re-attached to the axes explicitly)
        node_legend = plt.legend(node_handles, node_labels, title="Node Types", loc="upper left", bbox_to_anchor=(1, 1))
        plt.gca().add_artist(node_legend)
        plt.legend(edge_handles, edge_labels, title="Relation Types", loc="upper left", bbox_to_anchor=(1, 0.8))
plt.axis('off')
plt.tight_layout()
if filename:
plt.savefig(filename, bbox_inches='tight')
logger.info(f"Visualization saved to {filename}")
plt.show()
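    # Example (illustrative): render a subgraph around one disease node and
    # save it next to the analysis output.
    #
    #     kg.visualize(node_ids=["pneumonia", "fever"], filename="subgraph.png")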
def export_to_csv(self, nodes_file: str, relations_file: str) -> None:
"""
Export the knowledge graph to CSV files.
Args:
nodes_file: Path to save the nodes CSV file.
relations_file: Path to save the relations CSV file.
"""
try:
# Export nodes
nodes_data = []
for node_id, node_data in self.graph.nodes(data=True):
node_dict = {"id": node_id, "type": node_data.get("type", "unknown")}
for k, v in node_data.items():
if k not in ["id", "type", "visual"]:
node_dict[k] = v
nodes_data.append(node_dict)
nodes_df = pd.DataFrame(nodes_data)
nodes_df.to_csv(nodes_file, index=False)
# Export relations
relations_data = []
for source_id, target_id, edge_data in self.graph.edges(data=True):
relation_dict = {
"source_id": source_id,
"target_id": target_id,
"type": edge_data.get("type", "unknown"),
"confidence": edge_data.get("confidence", 1.0)
}
for k, v in edge_data.items():
if k not in ["type", "confidence", "visual"]:
relation_dict[k] = v
relations_data.append(relation_dict)
relations_df = pd.DataFrame(relations_data)
relations_df.to_csv(relations_file, index=False)
logger.info(f"Exported {len(nodes_data)} nodes to {nodes_file}")
logger.info(f"Exported {len(relations_data)} relations to {relations_file}")
except Exception as e:
logger.error(f"Error exporting to CSV: {e}")
def get_statistics(self) -> Dict[str, Any]:
"""
Get statistics about the knowledge graph.
Returns:
Dictionary containing various statistics about the knowledge graph.
"""
stats = {
"num_nodes": len(self.graph.nodes),
"num_edges": len(self.graph.edges),
"node_types": {},
"relation_types": {},
"avg_degree": 0,
"density": nx.density(self.graph)
}
        # Count node types
        for _, data in self.graph.nodes(data=True):
            node_type = data.get("type", "unknown")
            stats["node_types"][node_type] = stats["node_types"].get(node_type, 0) + 1
        # Count relation types
        for _, _, data in self.graph.edges(data=True):
            relation_type = data.get("type", "unknown")
            stats["relation_types"][relation_type] = stats["relation_types"].get(relation_type, 0) + 1
# Calculate average degree
if stats["num_nodes"] > 0:
stats["avg_degree"] = 2 * stats["num_edges"] / stats["num_nodes"]
return stats
def clear(self) -> None:
"""
Clear the knowledge graph.
"""
self.graph.clear()
logger.info("Knowledge graph cleared")
modules\image_analysis\image_analysis.py
"""
Medical Image Analysis Module Main Script
This script demonstrates how to use the image analysis module components
for processing, segmenting, and classifying medical images.
"""
import os
import logging
import argparse
import numpy as np
import matplotlib.pyplot as plt
import cv2
import json
from pathlib import Path
from typing import Dict, List, Any, Optional, Union, Tuple
# Import image analysis modules
from modules.image_analysis.image_processor import (
MedicalImageProcessor, XRayProcessor, CTProcessor, MRIProcessor
)
from modules.image_analysis.image_classifier import (
MedicalImageClassifier, ChestXRayClassifier, BrainMRIClassifier
)
from modules.image_analysis.image_segmentation import (
MedicalImageSegmenter, LungSegmenter, BrainSegmenter, OrganSegmenter, DeepLearningSegmenter
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("medical_image_analysis.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def analyze_medical_image(image_path: str, image_type: str, output_dir: str,
config_path: Optional[str] = None) -> Dict[str, Any]:
"""
Analyze a medical image
Args:
image_path: Path to the medical image
image_type: Type of medical image ('xray', 'ct', 'mri')
output_dir: Directory to save results
config_path: Optional path to configuration file
Returns:
Dict[str, Any]: Analysis results
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Initialize appropriate processor and segmenter based on image type
if image_type == 'xray':
processor = XRayProcessor(config_path)
segmenter = LungSegmenter(config_path)
elif image_type == 'ct':
processor = CTProcessor(config_path)
segmenter = OrganSegmenter(config_path)
elif image_type == 'mri':
processor = MRIProcessor(config_path)
segmenter = BrainSegmenter(config_path)
else:
processor = MedicalImageProcessor(config_path)
segmenter = MedicalImageSegmenter(config_path)
# Load and preprocess image
logger.info(f"Loading image: {image_path}")
image = processor.load_image(image_path)
if image is None:
logger.error(f"Failed to load image: {image_path}")
return {"success": False, "error": "Failed to load image"}
logger.info("Preprocessing image...")
preprocessed = processor.preprocess(image)
# Save preprocessed image
preprocessed_path = os.path.join(output_dir, "preprocessed.png")
if len(preprocessed.shape) == 2:
cv2.imwrite(preprocessed_path, (preprocessed * 255).astype(np.uint8))
else:
cv2.imwrite(preprocessed_path, cv2.cvtColor((preprocessed * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
# Segment image
logger.info("Segmenting image...")
if image_type == 'xray':
segmentation_mask = segmenter.segment_lungs(preprocessed)
# Enhance lung structures
enhanced = segmenter.enhance_lung_structures(preprocessed, segmentation_mask)
enhanced_path = os.path.join(output_dir, "enhanced.png")
cv2.imwrite(enhanced_path, (enhanced * 255).astype(np.uint8))
elif image_type == 'mri':
segmentation_mask = segmenter.segment_brain(preprocessed)
# Try to segment tumor
tumor_mask = segmenter.segment_tumor(preprocessed, segmentation_mask)
if np.any(tumor_mask):
tumor_path = os.path.join(output_dir, "tumor_mask.png")
cv2.imwrite(tumor_path, tumor_mask.astype(np.uint8) * 255)
elif image_type == 'ct':
liver_mask = segmenter.segment_liver(preprocessed)
kidney_mask = segmenter.segment_kidneys(preprocessed)
bone_mask = segmenter.segment_bones(preprocessed)
# Create combined mask
segmentation_mask = np.zeros_like(preprocessed[:,:,0] if len(preprocessed.shape) > 2 else preprocessed, dtype=np.uint8)
segmentation_mask[liver_mask] = 1 # Liver
segmentation_mask[kidney_mask] = 2 # Kidneys
segmentation_mask[bone_mask] = 3 # Bones
# Save individual masks
cv2.imwrite(os.path.join(output_dir, 'liver_mask.png'), liver_mask.astype(np.uint8) * 255)
cv2.imwrite(os.path.join(output_dir, 'kidney_mask.png'), kidney_mask.astype(np.uint8) * 255)
cv2.imwrite(os.path.join(output_dir, 'bone_mask.png'), bone_mask.astype(np.uint8) * 255)
else:
segmentation_mask = segmenter.threshold_segmentation(preprocessed)
# Extract region properties
regions = segmenter.extract_region_properties(segmentation_mask)
# Extract features
logger.info("Extracting features...")
features = processor.extract_features(preprocessed, segmentation_mask)
# Visualize results
logger.info("Generating visualization...")
visualization_path = os.path.join(output_dir, "segmentation_result.png")
segmenter.visualize_segmentation(preprocessed, segmentation_mask, regions, visualization_path)
    # Save segmentation mask (scale labels into the 0-255 range; the combined
    # CT mask is multi-label, so a plain *255 would overflow uint8)
    mask_path = os.path.join(output_dir, "segmentation_mask.png")
    max_label = max(1, int(segmentation_mask.max()))
    cv2.imwrite(mask_path, (segmentation_mask.astype(np.float32) / max_label * 255).astype(np.uint8))
# Save results
results = {
"success": True,
"image_path": image_path,
"image_type": image_type,
"features": {k: float(v) if isinstance(v, (np.float32, np.float64)) else v for k, v in features.items()},
"regions": [
{
"label": int(region["label"]),
"area": float(region["area"]),
"centroid": [float(c) for c in region["centroid"]],
"equivalent_diameter": float(region["equivalent_diameter"])
}
for region in regions
],
"visualization_path": visualization_path,
"mask_path": mask_path,
"preprocessed_path": preprocessed_path
}
results_path = os.path.join(output_dir, "analysis_results.json")
with open(results_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Results saved to {results_path}")
return results
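# Example (illustrative paths):
#
#     results = analyze_medical_image("data/chest.png", "xray", "./output/chest")
#     if results["success"]:
#         print(results["features"])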
def classify_medical_image(image_path: str, image_type: str, model_path: str,
output_dir: str, config_path: Optional[str] = None) -> Dict[str, Any]:
"""
Classify a medical image
Args:
image_path: Path to the medical image
image_type: Type of medical image ('xray', 'ct', 'mri')
model_path: Path to the trained model
output_dir: Directory to save results
config_path: Optional path to configuration file
Returns:
Dict[str, Any]: Classification results
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Initialize appropriate classifier based on image type
if image_type == 'xray':
classifier = ChestXRayClassifier(config_path)
elif image_type == 'mri':
classifier = BrainMRIClassifier(config_path)
else:
classifier = MedicalImageClassifier(config_path)
# Load model if it exists
if os.path.exists(model_path):
logger.info(f"Loading model from {model_path}")
success = classifier.load_model(model_path)
if not success:
logger.error(f"Failed to load model from {model_path}")
return {"success": False, "error": "Failed to load model"}
else:
logger.error(f"Model not found: {model_path}")
return {"success": False, "error": "Model not found"}
# Load and preprocess image
logger.info(f"Loading image: {image_path}")
image = cv2.imread(image_path)
if image is None:
logger.error(f"Failed to load image: {image_path}")
return {"success": False, "error": "Failed to load image"}
# Convert BGR to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Classify image
logger.info("Classifying image...")
prediction = classifier.predict(image)
if not prediction["success"]:
logger.error(f"Classification failed: {prediction.get('error', 'Unknown error')}")
return prediction
# Visualize prediction
logger.info("Generating visualization...")
visualization_path = os.path.join(output_dir, "classification_result.png")
classifier.visualize_predictions(image, prediction, visualization_path)
# Add visualization path to results
prediction["visualization_path"] = visualization_path
# Save results
results_path = os.path.join(output_dir, "classification_results.json")
with open(results_path, 'w') as f:
json.dump(prediction, f, indent=2)
logger.info(f"Results saved to {results_path}")
return prediction
def segment_with_deep_learning(image_path: str, model_path: str, output_dir: str,
config_path: Optional[str] = None) -> Dict[str, Any]:
"""
Segment a medical image using deep learning
Args:
image_path: Path to the medical image
model_path: Path to the trained segmentation model
output_dir: Directory to save results
config_path: Optional path to configuration file
Returns:
Dict[str, Any]: Segmentation results
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
# Initialize deep learning segmenter
segmenter = DeepLearningSegmenter(config_path, model_path)
# Load image
logger.info(f"Loading image: {image_path}")
image = cv2.imread(image_path)
if image is None:
logger.error(f"Failed to load image: {image_path}")
return {"success": False, "error": "Failed to load image"}
# Convert BGR to RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Segment image
logger.info("Segmenting image with deep learning model...")
segmentation_mask = segmenter.segment_with_model(image)
if segmentation_mask.size == 0:
logger.error("Segmentation failed")
return {"success": False, "error": "Segmentation failed"}
# Extract region properties
regions = segmenter.extract_region_properties(segmentation_mask)
# Visualize results
logger.info("Generating visualization...")
visualization_path = os.path.join(output_dir, "deep_segmentation_result.png")
segmenter.visualize_segmentation(image, segmentation_mask, regions, visualization_path)
    # Save segmentation mask (scale labels into 0-255 to avoid uint8 overflow
    # for multi-class masks)
    mask_path = os.path.join(output_dir, "deep_segmentation_mask.png")
    max_label = max(1, int(segmentation_mask.max()))
    cv2.imwrite(mask_path, (segmentation_mask.astype(np.float32) / max_label * 255).astype(np.uint8))
# Save results
results = {
"success": True,
"image_path": image_path,
"model_path": model_path,
"regions": [
{
"label": int(region["label"]),
"area": float(region["area"]),
"centroid": [float(c) for c in region["centroid"]],
"equivalent_diameter": float(region["equivalent_diameter"])
}
for region in regions
],
"visualization_path": visualization_path,
"mask_path": mask_path
}
results_path = os.path.join(output_dir, "deep_segmentation_results.json")
with open(results_path, 'w') as f:
json.dump(results, f, indent=2)
logger.info(f"Results saved to {results_path}")
return results
def main():
"""Main function to run the image analysis module"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='Medical Image Analysis Module')
parser.add_argument('--image', type=str, required=True,
help='Path to medical image')
parser.add_argument('--type', type=str, choices=['xray', 'ct', 'mri'], required=True,
help='Type of medical image')
parser.add_argument('--mode', type=str, choices=['analyze', 'classify', 'segment'], required=True,
help='Analysis mode')
parser.add_argument('--output', type=str, default='./output',
help='Directory to save results')
parser.add_argument('--config', type=str, default=None,
help='Path to configuration file')
parser.add_argument('--model', type=str, default=None,
help='Path to trained model (for classification or deep segmentation mode)')
args = parser.parse_args()
# Check if image exists
if not os.path.isfile(args.image):
logger.error(f"Image not found: {args.image}")
return
# Create output directory
os.makedirs(args.output, exist_ok=True)
# Run analysis based on mode
if args.mode == 'analyze':
results = analyze_medical_image(args.image, args.type, args.output, args.config)
elif args.mode == 'classify':
if args.model is None:
logger.error("Model path is required for classification mode")
return
results = classify_medical_image(args.image, args.type, args.model, args.output, args.config)
else: # segment
if args.model is None:
logger.error("Model path is required for deep segmentation mode")
return
results = segment_with_deep_learning(args.image, args.model, args.output, args.config)
# Print summary
if results["success"]:
print("\n===== Medical Image Analysis Summary =====")
print(f"Image: {args.image}")
print(f"Type: {args.type}")
print(f"Mode: {args.mode}")
if args.mode == 'analyze':
print("\nFeatures:")
for feature, value in results["features"].items():
print(f" - {feature}: {value}")
print(f"\nDetected {len(results['regions'])} regions")
elif args.mode == 'classify':
print(f"\nPredicted class: {results['predicted_class']}")
print(f"Confidence: {results['confidence']:.4f}")
print("\nClass probabilities:")
for cls, prob in results["probabilities"].items():
print(f" - {cls}: {prob:.4f}")
else: # segment
print(f"\nDetected {len(results['regions'])} regions")
print(f"\nResults saved to {args.output}")
else:
print(f"Error: {results.get('error', 'Unknown error')}")
if __name__ == "__main__":
main()
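# Example invocations (illustrative paths and model files; run from the
# repository root so the modules.* imports resolve):
#
#   python -m modules.image_analysis.image_analysis --image data/chest.png --type xray --mode analyze --output ./output
#   python -m modules.image_analysis.image_analysis --image data/brain.png --type mri --mode classify --model models/brain_mri.h5 --output ./output
#   python -m modules.image_analysis.image_analysis --image data/abdomen.png --type ct --mode segment --model models/unet_seg.h5 --output ./output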