Python项目:智能医疗辅助诊断系统开发(4)

源代码续

"""
Medical Named Entity Recognition Module

This module provides functionality for extracting medical entities from clinical text,
such as diseases, symptoms, medications, and anatomical structures.
"""

import html
import json
import logging
import os
import random
from typing import List, Dict, Any, Optional, Union, Tuple

import numpy as np
import spacy
from spacy.tokens import Doc, Span
from spacy.training import Example
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Root logging for this module: INFO level, written both to a log file
# (relative path, so it lands in the current working directory) and to the console.
# NOTE(review): logging.basicConfig is a no-op once the root logger already has
# handlers, so whichever module configures logging first decides the handlers.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.FileHandler("medical_ner.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)

class MedicalNamedEntityRecognizer:
    """
    Class for recognizing medical named entities in clinical text.

    Entities come from a spaCy pipeline, a Hugging Face token-classification
    pipeline, or both ("hybrid" mode). Each entity is a dict with ``text``,
    ``label``, ``start``/``end`` (character offsets into the input text) and
    ``source``; transformer entities additionally carry ``confidence``.
    """

    # Hugging Face checkpoint loaded when no local model path is supplied.
    DEFAULT_TRANSFORMER_MODEL = "samrawal/bert-base-uncased_clinical-ner"

    # Label groups used by the convenience extractors below.
    DISEASE_LABELS = ("DISEASE", "PROBLEM")
    SYMPTOM_LABELS = ("SYMPTOM", "FINDING")
    MEDICATION_LABELS = ("MEDICATION", "DRUG")
    TREATMENT_LABELS = ("TREATMENT", "PROCEDURE")

    # Highlight colour per entity label for visualize_entities.
    ENTITY_COLORS = {
        "DISEASE": "#ff9999",
        "SYMPTOM": "#99ff99",
        "MEDICATION": "#9999ff",
        "TREATMENT": "#ffff99",
        "ANATOMY": "#ff99ff",
        "PROCEDURE": "#99ffff",
        "TEST": "#ffcc99",
        "DOSAGE": "#ccff99",
        "FREQUENCY": "#cc99ff",
        "DURATION": "#99ccff",
        "PROBLEM": "#ff9999",
        "FINDING": "#99ff99",
        "DRUG": "#9999ff",
    }
    # Colour for labels missing from ENTITY_COLORS.
    DEFAULT_ENTITY_COLOR = "#cccccc"

    def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
        """
        Initialize the medical named entity recognizer.

        Args:
            config_path: Path to a JSON configuration file. Sections/keys it does
                not define fall back to the built-in defaults.
            model_path: Path to a local token-classification model. When it is not
                given or does not exist, ``DEFAULT_TRANSFORMER_MODEL`` is loaded.
        """
        self.config = self._load_config(config_path)
        self.entity_types = self.config["ner"]["entity_types"]
        self.confidence_threshold = self.config["ner"]["confidence_threshold"]

        # Populated by _load_models; the transformer-related attributes stay None
        # when the transformer model cannot be loaded.
        self.spacy_model = None
        self.transformer_model = None
        self.tokenizer = None
        self.ner_pipeline = None

        self._load_models(model_path)

    @staticmethod
    def _merge_with_defaults(defaults: Dict[str, Any], loaded: Dict[str, Any]) -> Dict[str, Any]:
        """
        Overlay a loaded configuration on top of the default configuration.

        Dict-valued sections present in both are merged key by key, with the loaded
        values winning; any other loaded top-level value replaces (or adds to) the
        defaults. Neither input dict is modified.

        Args:
            defaults: Default configuration.
            loaded: Configuration read from file.

        Returns:
            Dict[str, Any]: The merged configuration.
        """
        merged = {
            key: dict(value) if isinstance(value, dict) else value
            for key, value in defaults.items()
        }
        for section, values in loaded.items():
            if isinstance(values, dict) and isinstance(merged.get(section), dict):
                merged[section].update(values)
            else:
                merged[section] = values
        return merged

    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, falling back to built-in defaults.

        The file's contents are merged over the defaults (see
        ``_merge_with_defaults``), so a shared config file that lacks the ``ner``
        section, or some of its keys, does not make ``__init__`` fail.

        Args:
            config_path: Path to the JSON configuration file, or None.

        Returns:
            Dict[str, Any]: Configuration dictionary. The defaults alone are returned
            when no file is given, it does not exist, it cannot be read/parsed, or
            its top-level JSON value is not an object.
        """
        default_config = {
            "ner": {
                "model": "medical",
                "confidence_threshold": 0.7,
                "entity_types": [
                    "DISEASE", "SYMPTOM", "TREATMENT", "MEDICATION", "ANATOMY",
                    "PROCEDURE", "TEST", "DOSAGE", "FREQUENCY", "DURATION"
                ]
            }
        }

        if not (config_path and os.path.exists(config_path)):
            logger.info("Using default configuration")
            return default_config

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return default_config

        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading configuration from {config_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return default_config

        logger.info(f"Loaded configuration from {config_path}")
        return self._merge_with_defaults(default_config, loaded)

    @staticmethod
    def _load_spacy_model():
        """
        Load the first available spaCy pipeline.

        Tries ``en_core_sci_md``, then ``en_core_web_sm``, and finally falls back
        to a blank English pipeline (which has no NER component).
        """
        candidates = (
            ("en_core_sci_md", "Loaded spaCy scientific model: en_core_sci_md"),
            ("en_core_web_sm", "Loaded spaCy model: en_core_web_sm"),
        )
        for name, success_message in candidates:
            try:
                nlp = spacy.load(name)
            except OSError:
                continue
            logger.info(success_message)
            return nlp

        logger.warning("spaCy model not found. Using spaCy's blank model.")
        return spacy.blank('en')

    def _load_models(self, model_path: Optional[str]) -> None:
        """
        Load the spaCy pipeline and the transformer NER pipeline.

        The transformer attributes (tokenizer, model, pipeline) are only assigned
        once all three have loaded successfully; loading errors are logged rather
        than raised.

        Args:
            model_path: Path to a pre-trained local model, or None.
        """
        self.spacy_model = self._load_spacy_model()

        if model_path and os.path.exists(model_path):
            source = model_path
            success_message = f"Loaded transformer model from {model_path}"
            error_prefix = f"Error loading transformer model from {model_path}"
        else:
            source = self.DEFAULT_TRANSFORMER_MODEL
            success_message = f"Loaded pre-trained transformer model: {source}"
            error_prefix = "Error loading pre-trained transformer model"

        try:
            tokenizer = AutoTokenizer.from_pretrained(source)
            model = AutoModelForTokenClassification.from_pretrained(source)
            ner_pipeline = pipeline(
                "ner",
                model=model,
                tokenizer=tokenizer,
                aggregation_strategy="simple"
            )
        except Exception as e:
            logger.error(f"{error_prefix}: {e}")
            return

        self.tokenizer = tokenizer
        self.transformer_model = model
        self.ner_pipeline = ner_pipeline
        logger.info(success_message)

    def extract_entities_spacy(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract entities using the spaCy pipeline.

        Every entity in ``doc.ents`` is returned; no filtering by
        ``self.entity_types`` is applied.

        Args:
            text: Clinical text.

        Returns:
            List[Dict[str, Any]]: Entities with ``source`` set to ``"spacy"``.
        """
        if self.spacy_model is None:
            logger.error("spaCy model not available")
            return []

        doc = self.spacy_model(text)
        return [
            {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char,
                "source": "spacy"
            }
            for ent in doc.ents
        ]

    def extract_entities_transformer(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract entities using the transformer NER pipeline.

        Results scoring below ``self.confidence_threshold`` are dropped.

        Args:
            text: Clinical text.

        Returns:
            List[Dict[str, Any]]: Entities with ``source`` set to ``"transformer"``
            and a float ``confidence``. Empty if the pipeline is unavailable or fails.
        """
        if self.ner_pipeline is None:
            logger.error("Transformer model not available")
            return []

        try:
            ner_results = self.ner_pipeline(text)
        except Exception as e:
            logger.error(f"Error in transformer entity extraction: {e}")
            return []

        entities = []
        for result in ner_results:
            # Cast to a plain float (the pipeline may return NumPy scalars).
            score = float(result.get("score", 0.0))
            if score < self.confidence_threshold:
                continue

            # The pipeline is built with aggregation_strategy="simple", whose results
            # carry the label under "entity_group"; "entity" is the key used by the
            # non-aggregated output.
            label = result.get("entity_group") or result.get("entity")
            if not label:
                continue

            start, end = result.get("start"), result.get("end")
            if start is not None and end is not None:
                # Slice the original text: "word" is rebuilt from (possibly
                # lower-cased) sub-word tokens and may not match the input.
                entity_text = text[start:end]
            else:
                entity_text = result.get("word", "")

            entities.append({
                "text": entity_text,
                # Upper-case so labels line up with the upper-case label names used
                # by the filters and colour table of this class.
                "label": str(label).upper(),
                "start": start,
                "end": end,
                "confidence": score,
                "source": "transformer"
            })

        return entities

    def extract_entities(self, text: str, method: str = "hybrid") -> List[Dict[str, Any]]:
        """
        Extract medical entities from clinical text.

        Args:
            text: Clinical text.
            method: ``'spacy'``, ``'transformer'`` or ``'hybrid'``. Hybrid combines
                both sources and keeps one entity per (text, label) pair, preferring
                the entry with the higher confidence; spaCy entities have none. As a
                result, repeated mentions of the same text collapse to one entry.

        Returns:
            List[Dict[str, Any]]: Extracted entities. Empty for invalid input or an
            unknown method.
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return []

        if method == "spacy":
            return self.extract_entities_spacy(text)
        if method == "transformer":
            return self.extract_entities_transformer(text)
        if method != "hybrid":
            logger.error(f"Unknown extraction method: {method}")
            return []

        deduplicated: Dict[Tuple[str, str], Dict[str, Any]] = {}
        for entity in self.extract_entities_spacy(text) + self.extract_entities_transformer(text):
            key = (entity["text"], entity["label"])
            current = deduplicated.get(key)
            if current is None or (
                "confidence" in entity
                and ("confidence" not in current or entity["confidence"] > current["confidence"])
            ):
                deduplicated[key] = entity

        return list(deduplicated.values())

    def extract_medical_concepts(self, text: str) -> Dict[str, List[Dict[str, Any]]]:
        """
        Extract entities (hybrid method) and group them by label.

        Args:
            text: Clinical text.

        Returns:
            Dict[str, List[Dict[str, Any]]]: Label -> entities with that label.
        """
        grouped: Dict[str, List[Dict[str, Any]]] = {}
        for entity in self.extract_entities(text):
            grouped.setdefault(entity["label"], []).append(entity)
        return grouped

    def extract_diseases_and_symptoms(self, text: str) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Extract diseases and symptoms from clinical text.

        Args:
            text: Clinical text.

        Returns:
            Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: Entities labelled
            with one of ``DISEASE_LABELS``, and entities labelled with one of
            ``SYMPTOM_LABELS``.
        """
        entities = self.extract_entities(text)
        diseases = [e for e in entities if e["label"] in self.DISEASE_LABELS]
        symptoms = [e for e in entities if e["label"] in self.SYMPTOM_LABELS]
        return diseases, symptoms

    def extract_medications(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract medications (labels in ``MEDICATION_LABELS``) from clinical text.

        Args:
            text: Clinical text.

        Returns:
            List[Dict[str, Any]]: List of medications.
        """
        return [e for e in self.extract_entities(text) if e["label"] in self.MEDICATION_LABELS]

    def extract_treatments(self, text: str) -> List[Dict[str, Any]]:
        """
        Extract treatments (labels in ``TREATMENT_LABELS``) from clinical text.

        Args:
            text: Clinical text.

        Returns:
            List[Dict[str, Any]]: List of treatments.
        """
        return [e for e in self.extract_entities(text) if e["label"] in self.TREATMENT_LABELS]

    def visualize_entities(self, text: str, entities: List[Dict[str, Any]]) -> str:
        """
        Render text as HTML with entities highlighted.

        All text and labels are HTML-escaped. Entities with missing or out-of-range
        offsets are skipped. When entities overlap, the one that starts first (the
        longest on ties) is kept, so overlapping hybrid results cannot corrupt each
        other's markup.

        Args:
            text: Original text.
            entities: Entities with ``start``/``end`` character offsets and ``label``.

        Returns:
            str: HTML string wrapped in a ``<div>``; ``text`` unchanged when
            ``text`` or ``entities`` is empty.
        """
        if not text or not entities:
            return text

        valid = [
            e for e in entities
            if isinstance(e.get("start"), int) and isinstance(e.get("end"), int)
            and 0 <= e["start"] < e["end"] <= len(text)
        ]

        pieces: List[str] = []
        cursor = 0
        for entity in sorted(valid, key=lambda e: (e["start"], -e["end"])):
            start, end = entity["start"], entity["end"]
            if start < cursor:
                # Overlaps an entity that has already been highlighted.
                continue

            label = str(entity["label"])
            color = self.ENTITY_COLORS.get(label, self.DEFAULT_ENTITY_COLOR)
            pieces.append(html.escape(text[cursor:start]))
            pieces.append(
                f'<span style="background-color: {color};" title="{html.escape(label)}">'
                f'{html.escape(text[start:end])}</span>'
            )
            cursor = end
        pieces.append(html.escape(text[cursor:]))

        body = "".join(pieces)
        return f'<div style="font-family: Arial, sans-serif; line-height: 1.5;">{body}</div>'

    def train_custom_model(self, training_data: List[Dict[str, Any]],
                           output_path: str, epochs: int = 10) -> bool:
        """
        Train a blank spaCy NER model on medical data and save it to disk.

        Args:
            training_data: Examples, each a dict with a ``"text"`` string and an
                ``"entities"`` list of ``{"start", "end", "label"}`` dicts. ``start``
                and ``end`` are character offsets into ``text``, the same convention
                as the entities produced by this class. They should align with token
                boundaries.
            output_path: Directory to save the trained pipeline to.
            epochs: Number of passes over the (shuffled) training data.

        Returns:
            bool: True if training and saving succeeded, False otherwise.
        """
        if not training_data:
            logger.warning("No training data provided; skipping custom NER training")
            return False

        try:
            nlp = spacy.blank("en")
            ner = nlp.add_pipe("ner")

            train_examples = []
            for record in training_data:
                annotations = [
                    (entity["start"], entity["end"], entity["label"])
                    for entity in record["entities"]
                ]
                for _, _, label in annotations:
                    ner.add_label(label)
                # Example.from_dict takes the un-annotated predicted doc plus the gold
                # annotations; "entities" accepts (start_char, end_char, label) triples.
                train_examples.append(
                    Example.from_dict(nlp.make_doc(record["text"]), {"entities": annotations})
                )

            optimizer = nlp.initialize(lambda: train_examples)
            for epoch in range(epochs):
                random.shuffle(train_examples)
                losses: Dict[str, float] = {}
                for example in train_examples:
                    nlp.update([example], sgd=optimizer, drop=0.5, losses=losses)
                logger.info(f"Epoch {epoch + 1}/{epochs} - losses: {losses}")

            nlp.to_disk(output_path)
            logger.info(f"Custom NER model saved to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Error training custom NER model: {e}")
            return False

modules\text_analysis\relation_extractor.py

"""
Medical Relation Extraction Module

This module provides functionality for extracting relationships between medical entities
in clinical text, such as disease-symptom, medication-disease, etc.
"""

import os
import json
import logging
import re
from typing import List, Dict, Any, Optional, Union, Tuple, Set

import numpy as np
import spacy
from spacy.tokens import Doc
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    pipeline
)

# Root logging for this module: INFO level, written both to a log file
# (relative path, so it lands in the current working directory) and to the console.
# NOTE(review): logging.basicConfig is a no-op once the root logger already has
# handlers, so if another module configured logging first this file handler is
# never attached — confirm that is acceptable.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_LOG_HANDLERS = [logging.FileHandler("medical_relation_extraction.log"), logging.StreamHandler()]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_LOG_HANDLERS)
logger = logging.getLogger(__name__)

class MedicalRelationExtractor:
    """
    Class for extracting relationships between medical entities in clinical text
    """
    
    def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
        """
        Initialize the medical relation extractor
        
        Args:
            config_path: Path to the configuration file
            model_path: Path to a pre-trained model
        """
        self.config = self._load_config(config_path)
        self.relation_types = self.config["relation_extraction"]["relation_types"]
        self.confidence_threshold = self.config["relation_extraction"]["confidence_threshold"]
        
        # Initialize models
        self.spacy_model = None
        self.transformer_model = None
        self.tokenizer = None
        self.relation_pipeline = None
        
        # Load models based on configuration
        self._load_models(model_path)
        
        # Define relation patterns for rule-based extraction
        self.relation_patterns = self._define_relation_patterns()
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load the extractor configuration, falling back to built-in defaults.

        A JSON file at ``config_path`` is overlaid on the defaults:
        - Dict-valued sections present in both (e.g. ``relation_extraction``) are
          merged key by key, with the file's values winning.
        - Any other top-level key in the file replaces or adds to the defaults.

        As a result, a shared config file that omits the ``relation_extraction``
        section, or some of its keys, no longer makes ``__init__`` fail with a
        KeyError.

        Args:
            config_path: Path to the JSON configuration file, or None.

        Returns:
            Dict[str, Any]: Configuration dictionary. The defaults alone are returned
            when no file is given, the file does not exist, it cannot be
            read/parsed, or its top-level JSON value is not an object.
        """
        config: Dict[str, Any] = {
            "relation_extraction": {
                "model": "medical-bert",
                "relation_types": [
                    "TREATS", "CAUSES", "DIAGNOSES", "PREVENTS", "CONTRAINDICATES",
                    "SYMPTOM_OF", "MANIFESTATION_OF", "SUGGESTS"
                ],
                "confidence_threshold": 0.6
            }
        }

        if not (config_path and os.path.exists(config_path)):
            logger.info("Using default configuration")
            return config

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return config

        if not isinstance(loaded, dict):
            logger.error(
                f"Error loading configuration from {config_path}: "
                f"expected a JSON object, got {type(loaded).__name__}"
            )
            return config

        for section, values in loaded.items():
            if isinstance(values, dict) and isinstance(config.get(section), dict):
                config[section].update(values)
            else:
                config[section] = values

        logger.info(f"Loaded configuration from {config_path}")
        return config
    
    def _load_models(self, model_path: Optional[str]) -> None:
        """
        Load the spaCy pipeline and the transformer relation classifier.

        spaCy: ``en_core_web_sm`` when installed, otherwise a blank English
        pipeline. Transformer: the local ``model_path`` when it exists on disk,
        otherwise the PubMedBERT checkpoint named below. Transformer loading errors
        are logged, not raised.

        Args:
            model_path: Path to a pre-trained local model, or None.
        """
        try:
            self.spacy_model = spacy.load("en_core_web_sm")
        except OSError:
            logger.warning("spaCy model not found. Using spaCy's blank model.")
            self.spacy_model = spacy.blank('en')
        else:
            logger.info("Loaded spaCy model: en_core_web_sm")

        use_local = bool(model_path) and os.path.exists(model_path)
        if use_local:
            source = model_path
            loaded_message = f"Loaded transformer model from {model_path}"
            failure_prefix = f"Error loading transformer model from {model_path}"
        else:
            # NOTE(review): this checkpoint looks like a general pretrained encoder
            # rather than a relation classifier; if so, the sequence-classification
            # head is freshly initialised and its labels are not relation types —
            # confirm before trusting its predictions.
            source = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
            loaded_message = f"Loaded pre-trained transformer model: {source}"
            failure_prefix = "Error loading pre-trained transformer model"

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(source)
            self.transformer_model = AutoModelForSequenceClassification.from_pretrained(source)
            self.relation_pipeline = pipeline(
                "text-classification",
                model=self.transformer_model,
                tokenizer=self.tokenizer
            )
            logger.info(loaded_message)
        except Exception as e:
            logger.error(f"{failure_prefix}: {e}")
    
    def _define_relation_patterns(self) -> Dict[str, List[Dict[str, Any]]]:
        """
        Define patterns for rule-based relation extraction
        
        Returns:
            Dict[str, List[Dict[str, Any]]]: Dictionary of relation patterns
        """
        patterns = {
            "TREATS": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:treats|treat|treating|treated|therapy for|treatment for|used for|indicated for|prescribed for)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "MEDICATION",
                    "entity2_type": "DISEASE"
                },
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is effective for|is used to treat|helps with|alleviates|relieves|cures)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "MEDICATION",
                    "entity2_type": "DISEASE"
                }
            ],
            "CAUSES": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:causes|cause|causing|caused|leads to|results in|induces|triggers)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "DISEASE",
                    "entity2_type": "SYMPTOM"
                },
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is caused by|results from|is triggered by|is induced by|develops from)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "SYMPTOM",
                    "entity2_type": "DISEASE"
                }
            ],
            "DIAGNOSES": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:diagnoses|diagnose|diagnosing|diagnosed|confirms|confirms diagnosis of|indicates|is diagnostic of)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "TEST",
                    "entity2_type": "DISEASE"
                }
            ],
            "PREVENTS": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:prevents|prevent|preventing|prevented|protects against|reduces risk of|prophylaxis for)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "MEDICATION",
                    "entity2_type": "DISEASE"
                }
            ],
            "CONTRAINDICATES": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:contraindicates|contraindicate|contraindicating|contraindicated in|should not be used with|avoid in|not recommended for)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "MEDICATION",
                    "entity2_type": "DISEASE"
                }
            ],
            "SYMPTOM_OF": [
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is a symptom of|symptom of|sign of|manifestation of|indicates|suggests|points to)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "SYMPTOM",
                    "entity2_type": "DISEASE"
                },
                {
                    "pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:presents with|manifests as|shows|exhibits|displays|characterized by)\s+(\b\w+(?:\s+\w+){0,5})",
                    "entity1_type": "DISEASE",
                    "entity2_type": "SYMPTOM"
                }
            ]
        }
        
        return patterns
    
    def extract_relations_rule_based(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Extract relations between entities using rule-based patterns
        
        Args:
            text: Clinical text
            entities: List of entities extracted from the text
            
        Returns:
            List[Dict[str, Any]]: List of extracted relations
        """
        if not text or not entities:
            return []
        
        relations = []
        
        # Create a dictionary of entities by type
        entities_by_type = {}
        for entity in entities:
            entity_type = entity["label"]
            if entity_type not in entities_by_type:
                entities_by_type[entity_type] = []
            entities_by_type[entity_type].append(entity)
        
        # Apply relation patterns
        for relation_type, patterns in self.relation_patterns.items():
            for pattern_info in patterns:
                pattern = pattern_info["pattern"]
                entity1_type = pattern_info["entity1_type"]
                entity2_type = pattern_info["entity2_type"]
                
                # Find all matches of the pattern
                for match in re.finditer(pattern, text):
                    entity1_text = match.group(1).strip()
                    entity2_text = match.group(2).strip()
                    
                    # Find entities that match the extracted text
                    entity1_matches = []
                    entity2_matches = []
                    
                    if entity1_type in entities_by_type:
                        for entity in entities_by_type[entity1_type]:
                            if entity1_text.lower() in entity["text"].lower() or entity["text"].lower() in entity1_text.lower():
                                entity1_matches.append(entity)
                    
                    if entity2_type in entities_by_type:
                        for entity in entities_by_type[entity2_type]:
                            if entity2_text.lower() in entity["text"].lower() or entity["text"].lower() in entity2_text.lower():
                                entity2_matches.append(entity)
                    
                    # Create relations for all combinations of matching entities
                    for entity1 in entity1_matches:
                        for entity2 in entity2_matches:
                            relation = {
                                "relation_type": relation_type,
                                "entity1": entity1,
                                "entity2": entity2,
                                "confidence": 0.8,  # Default confidence for rule-based
                                "source": "rule-based",
                                "context": text[max(0, match.start() - 50):min(len(text), match.end() + 50)]
                            }
                            relations.append(relation)
        
        return relations
    
    def extract_relations_dependency(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Extract relations between entities using dependency parsing
        
        Args:
            text: Clinical text
            entities: List of entities extracted from the text
            
        Returns:
            List[Dict[str, Any]]: List of extracted relations
        """
        if not text or not entities or not self.spacy_model:
            return []
        
        relations = []
        
        # Process text with spaCy
        doc = self.spacy_model(text)
        
        # Create a mapping from character spans to entities
        span_to_entity = {}
        for entity in entities:
            start = entity["start"]
            end = entity["end"]
            span_to_entity[(start, end)] = entity
        
        # Create a mapping from tokens to entities
        token_to_entity = {}
        for token in doc:
            token_start = token.idx
            token_end = token.idx + len(token.text)
            
            for (start, end), entity in span_to_entity.items():
                # Check if token is within entity span or overlaps significantly
                if (token_start >= start and token_end <= end) or \
                   (token_start < end and token_end > start and min(token_end, end) - max(token_start, start) > 0.5 * (token_end - token_start)):
                    token_to_entity[token] = entity
                    break
        
        # Define relation mappings based on dependency patterns
        relation_mappings = {
            "TREATS": [
                {"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["treat", "cure", "heal", "alleviate", "relieve", "manage"]},
                {"dep": "dobj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["treat", "cure", "heal", "alleviate", "relieve", "manage"]}
            ],
            "CAUSES": [
                {"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["cause", "lead", "result", "induce", "trigger"]},
                {"dep": "dobj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["cause", "lead", "result", "induce", "trigger"]}
            ],
            "SYMPTOM_OF": [
                {"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["indicate", "suggest", "imply", "point", "manifest"]},
                {"dep": "prep", "head_dep": "ROOT", "head_pos": "NOUN", "head_text": ["symptom", "sign", "manifestation", "indication"]}
            ]
        }
        
        # Extract relations based on dependency patterns
        for sent in doc.sents:
            # Find the root of the sentence
            root = None
            for token in sent:
                if token.dep_ == "ROOT":
                    root = token
                    break
            
            if not root:
                continue
            
            # Check each relation mapping
            for relation_type, patterns in relation_mappings.items():
                for pattern in patterns:
                    # Find subject and object based on pattern
                    subject = None
                    object = None
                    
                    # Check if root matches the pattern
                    if root.pos_ == pattern.get("head_pos") and \
                       (not pattern.get("head_lemma") or root.lemma_ in pattern.get("head_lemma")) and \
                       (not pattern.get("head_text") or root.text.lower() in pattern.get("head_text")):
                        
                        # Find subject and object
                        for token in sent:
                            if token.dep_ == "nsubj" and token.head == root:
                                # Expand to include compound nouns
                                subject_span = self._expand_noun_phrase(token)
                                subject = subject_span
                            
                            if token.dep_ == "dobj" and token.head == root:
                                # Expand to include compound nouns
                                object_span = self._expand_noun_phrase(token)
                                object = object_span
                        
                        # If both subject and object are found and are entities
                        if subject and object and subject in token_to_entity and object in token_to_entity:
                            entity1 = token_to_entity[subject]
                            entity2 = token_to_entity[object]
                            
                            relation = {
                                "relation_type": relation_type,
                                "entity1": entity1,
                                "entity2": entity2,
                                "confidence": 0.7,  # Default confidence for dependency-based
                                "source": "dependency",
                                "context": sent.text
                            }
                            relations.append(relation)
        
        return relations
    
    def _expand_noun_phrase(self, token) -> Any:
        """
        Pick the token that represents a (possibly compound) noun phrase.

        Despite the name, no span is built and a single token is returned:
        if ``token`` has a direct child attached with the ``compound``
        dependency that precedes it in the sentence, the first such child
        (in ``token.children`` order) is returned; otherwise ``token`` itself.
        Compound children of that child are not followed.

        Args:
            token: spaCy token (the dependency extractor passes the ``nsubj``
                or ``dobj`` of the sentence root)

        Returns:
            Any: ``token`` or its leading compound modifier
        """
        leading_compound = next(
            (
                child
                for child in token.children
                if child.dep_ == "compound" and child.i < token.i
            ),
            None,
        )
        if leading_compound is None:
            return token
        return leading_compound
    
    def extract_relations_transformer(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Extract relations between entities using a transformer classifier.

        Candidate (source, target) pairs are formed from fixed entity-group
        combinations (e.g. MEDICATION -> DISEASE for TREATS). For every pair
        whose start offsets are at most 200 characters apart, a context window
        (50 characters either side of the pair) is cut from ``text``, the two
        mentions are wrapped in ``[E1] ... [/E1]`` and ``[E2] ... [/E2]`` at
        their own character offsets, and the window is sent to
        ``self.relation_pipeline``. A relation is kept when the top score is
        >= ``self.confidence_threshold``. Pairs whose mention spans overlap are
        skipped because they cannot be marked unambiguously.

        Args:
            text: Clinical text
            entities: Entities with "text", "label", "start" and "end" keys;
                "start"/"end" are used as character offsets into ``text``

        Returns:
            List[Dict[str, Any]]: Relations with relation_type, entity1,
            entity2, confidence, source ("transformer") and the marked context
        """
        if not text or not entities or not self.relation_pipeline:
            return []

        # Coarse groups and the NER labels they accept. A label may belong to
        # more than one group (PROCEDURE is both TREATMENT and TEST).
        entity_types = {
            "DISEASE": ["DISEASE", "PROBLEM"],
            "SYMPTOM": ["SYMPTOM", "FINDING"],
            "MEDICATION": ["MEDICATION", "DRUG"],
            "TREATMENT": ["TREATMENT", "PROCEDURE"],
            "TEST": ["TEST", "PROCEDURE"]
        }

        entities_by_group = {}
        for entity in entities:
            for group, group_labels in entity_types.items():
                if entity["label"] in group_labels:
                    entities_by_group.setdefault(group, []).append(entity)

        # Relation candidates based on entity groups
        relation_candidates = [
            {"type": "TREATS", "source": "MEDICATION", "target": "DISEASE"},
            {"type": "CAUSES", "source": "DISEASE", "target": "SYMPTOM"},
            {"type": "DIAGNOSES", "source": "TEST", "target": "DISEASE"},
            {"type": "PREVENTS", "source": "MEDICATION", "target": "DISEASE"},
            {"type": "SYMPTOM_OF", "source": "SYMPTOM", "target": "DISEASE"}
        ]

        def mark_pair(source_entity, target_entity, start, end):
            """Return text[start:end] with both mentions wrapped in markers, or None if they overlap."""
            spans = sorted([
                (source_entity["start"], source_entity["end"], "E1"),
                (target_entity["start"], target_entity["end"], "E2"),
            ])
            (first_start, first_end, first_tag), (second_start, second_end, second_tag) = spans
            if second_start < first_end:
                return None
            return (
                text[start:first_start]
                + f"[{first_tag}] {text[first_start:first_end]} [/{first_tag}]"
                + text[first_end:second_start]
                + f"[{second_tag}] {text[second_start:second_end]} [/{second_tag}]"
                + text[second_end:end]
            )

        relations = []
        for candidate in relation_candidates:
            relation_type = candidate["type"]
            source_entities = entities_by_group.get(candidate["source"])
            target_entities = entities_by_group.get(candidate["target"])
            if not source_entities or not target_entities:
                continue

            for source_entity in source_entities:
                for target_entity in target_entities:
                    # Skip if entities are too far apart
                    if abs(source_entity["start"] - target_entity["start"]) > 200:
                        continue

                    context_start = max(0, min(source_entity["start"], target_entity["start"]) - 50)
                    context_end = min(len(text), max(source_entity["end"], target_entity["end"]) + 50)

                    # Markers go at the entities' own offsets: replacing by
                    # string value would also tag every other occurrence of
                    # the same words, and nests/loses markers when one mention
                    # is a substring of the other.
                    context = mark_pair(source_entity, target_entity, context_start, context_end)
                    if context is None:
                        continue

                    try:
                        result = self.relation_pipeline(context)
                        score = result[0]["score"]
                    except Exception as e:
                        logger.error(f"Error in transformer relation extraction: {e}")
                        continue

                    # NOTE(review): the predicted label (result[0]["label"]) is
                    # not compared with relation_type, so `score` is the
                    # confidence of whichever label ranked first - confirm the
                    # pipeline's label scheme and filter on it.
                    if score >= self.confidence_threshold:
                        relations.append({
                            "relation_type": relation_type,
                            "entity1": source_entity,
                            "entity2": target_entity,
                            "confidence": score,
                            "source": "transformer",
                            "context": context
                        })

        return relations
    
    def extract_relations(self, text: str, entities: List[Dict[str, Any]],
                          method: str = "hybrid") -> List[Dict[str, Any]]:
        """
        Extract relations between entities in clinical text.

        Dispatches to the rule-based, dependency or transformer extractor.
        ``"hybrid"`` runs all three (in that order) and deduplicates on
        (relation type, entity1 text, entity2 text), keeping the relation with
        the highest confidence; on ties the first one seen wins.

        Args:
            text: Clinical text
            entities: List of entities extracted from the text
            method: Extraction method ('rule-based', 'dependency', 'transformer', or 'hybrid')

        Returns:
            List[Dict[str, Any]]: List of extracted relations; empty for
            invalid input or an unknown method (both are logged)
        """
        if not text or not isinstance(text, str) or not entities:
            logger.warning("Invalid input for relation extraction")
            return []

        if method == "rule-based":
            return self.extract_relations_rule_based(text, entities)
        if method == "dependency":
            return self.extract_relations_dependency(text, entities)
        if method == "transformer":
            return self.extract_relations_transformer(text, entities)
        if method != "hybrid":
            logger.error(f"Unknown relation extraction method: {method}")
            return []

        all_relations = (
            self.extract_relations_rule_based(text, entities)
            + self.extract_relations_dependency(text, entities)
            + self.extract_relations_transformer(text, entities)
        )

        # A tuple key cannot collide the way an "_"-joined string does when
        # entity texts themselves contain underscores.
        deduplicated = {}
        for relation in all_relations:
            key = (
                relation["relation_type"],
                relation["entity1"]["text"],
                relation["entity2"]["text"],
            )
            best = deduplicated.get(key)
            if best is None or relation["confidence"] > best["confidence"]:
                deduplicated[key] = relation

        return list(deduplicated.values())
    
    def build_knowledge_graph(self, relations: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Turn a list of relations into a node/edge graph structure.

        Nodes are keyed by entity text: the first time a text is seen it gets
        the next integer id and the label of that first occurrence as its
        type. Each relation contributes one edge (duplicates are not merged).

        Args:
            relations: Relations with entity1/entity2 dicts ("text", "label"),
                "relation_type" and "confidence"

        Returns:
            Dict[str, Any]: {"nodes": [...], "edges": [...]} where nodes are
            {"id", "text", "type"} in first-seen order and edges are
            {"source", "target", "relation", "confidence"}
        """
        if not relations:
            return {"nodes": [], "edges": []}

        node_index = {}
        edge_list = []

        def node_id_for(entity):
            # Register the entity on first sight; ids follow insertion order.
            node_key = entity["text"]
            node = node_index.get(node_key)
            if node is None:
                node = {"id": len(node_index), "text": node_key, "type": entity["label"]}
                node_index[node_key] = node
            return node["id"]

        for rel in relations:
            source_id = node_id_for(rel["entity1"])
            target_id = node_id_for(rel["entity2"])
            edge_list.append({
                "source": source_id,
                "target": target_id,
                "relation": rel["relation_type"],
                "confidence": rel["confidence"]
            })

        return {"nodes": list(node_index.values()), "edges": edge_list}
    
    def visualize_relations(self, text: str, relations: List[Dict[str, Any]],
                            output_path: Optional[str] = None) -> str:
        """
        Create an HTML table visualizing extracted relations.

        Every value taken from the input (text preview, entity texts and
        labels, relation types) is HTML-escaped, so clinical text containing
        characters such as ``<``, ``>`` or ``&`` renders literally instead of
        breaking the markup or injecting elements into the page.

        Args:
            text: Original text; only the first 500 characters are shown
            relations: List of extracted relations
            output_path: Optional path to save the visualization (UTF-8);
                write errors are logged, not raised

        Returns:
            str: HTML representation of relations visualization
        """
        from html import escape

        if not relations:
            return "<div>No relations found</div>"

        # Truncate before escaping so an escape sequence such as "&amp;" is
        # never cut in half.
        preview = escape(text[:500]) + ('...' if len(text) > 500 else '')

        parts = [f"""
        <div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; border: 1px solid #ddd; border-radius: 5px;">
            <h2 style="color: #333;">Medical Relations Extraction</h2>
            
            <div style="margin: 20px 0; padding: 15px; background-color: #f9f9f9; border-radius: 5px;">
                <h3 style="margin-top: 0; color: #555;">Input Text:</h3>
                <p style="line-height: 1.5;">{preview}</p>
            </div>
            
            <div style="margin: 20px 0;">
                <h3 style="color: #555;">Extracted Relations:</h3>
                <table style="width: 100%; border-collapse: collapse; margin-top: 10px;">
                    <thead>
                        <tr style="background-color: #f2f2f2;">
                            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Relation Type</th>
                            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Entity 1</th>
                            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Entity 2</th>
                            <th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Confidence</th>
                        </tr>
                    </thead>
                    <tbody>
        """]

        # One table row per relation
        for relation in relations:
            entity1_text = escape(str(relation["entity1"]["text"]))
            entity1_type = escape(str(relation["entity1"]["label"]))
            entity2_text = escape(str(relation["entity2"]["text"]))
            entity2_type = escape(str(relation["entity2"]["label"]))
            relation_type = escape(str(relation["relation_type"]))
            confidence = relation["confidence"]

            parts.append(f"""
                        <tr>
                            <td style="padding: 8px; text-align: left; border: 1px solid #ddd;">{relation_type}</td>
                            <td style="padding: 8px; text-align: left; border: 1px solid #ddd;">
                                <span style="font-weight: bold;">{entity1_text}</span>
                                <span style="color: #777; font-size: 12px;"> ({entity1_type})</span>
                            </td>
                            <td style="padding: 8px; text-align: left; border: 1px solid #ddd;">
                                <span style="font-weight: bold;">{entity2_text}</span>
                                <span style="color: #777; font-size: 12px;"> ({entity2_type})</span>
                            </td>
                            <td style="padding: 8px; text-align: left; border: 1px solid #ddd;">{confidence:.2f}</td>
                        </tr>
            """)

        parts.append("""
                    </tbody>
                </table>
            </div>
        </div>
        """)

        markup = "".join(parts)

        # Save to file if output path is provided
        if output_path:
            try:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(markup)
                logger.info(f"Visualization saved to {output_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {e}")

        return markup

modules\text_analysis\text_classifier.py

"""
Clinical Text Classifier Module

This module provides functionality for classifying clinical text data,
including disease classification, severity assessment, and risk prediction.
"""

import os
import json
import logging
import pickle
from typing import List, Dict, Any, Optional, Union, Tuple

import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import MultinomialNB
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    TrainingArguments, 
    Trainer,
    DataCollatorWithPadding
)
from datasets import Dataset

# Configure logging: records at INFO and above go to a log file in the
# current working directory and to the console (StreamHandler -> stderr).
# Note that logging.basicConfig is a no-op when the root logger already has
# handlers (e.g. another module configured logging first in this process).
logging.basicConfig(
    handlers=[
        logging.FileHandler("clinical_text_classification.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class ClinicalTextClassifier:
    """
    Class for classifying clinical text data
    """
    
    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the clinical text classifier
        
        Args:
            config_path: Path to the configuration file
        """
        self.config = self._load_config(config_path)
        self.model_type = self.config["classification"]["model_type"]
        self.algorithm = self.config["classification"]["algorithm"]
        
        # Initialize models
        self.traditional_model = None
        self.vectorizer = None
        self.transformer_model = None
        self.tokenizer = None
        self.label_map = {}
        self.inverse_label_map = {}
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, filling any gaps from the defaults.

        The JSON file is merged recursively over the built-in defaults, so a
        file that overrides only some keys (e.g. just
        ``classification.algorithm``) still provides every key read by
        ``__init__`` and the training/prediction methods. Keys not present in
        the defaults are kept as-is. When no path is given, the file does not
        exist, cannot be parsed, or its root is not a JSON object, the
        defaults are returned.

        Args:
            config_path: Path to the configuration file

        Returns:
            Dict[str, Any]: Configuration dictionary
        """
        default_config = {
            "classification": {
                "model_type": "traditional",
                "algorithm": "svm",
                "pretrained_model": "bert-base-uncased",
                "batch_size": 16,
                "learning_rate": 2e-5,
                "epochs": 3,
                "max_length": 512
            }
        }

        def merge(base, overrides):
            # Overlay `overrides` onto `base` in place, recursing into nested dicts.
            for key, value in overrides.items():
                if isinstance(value, dict) and isinstance(base.get(key), dict):
                    merge(base[key], value)
                else:
                    base[key] = value
            return base

        if not config_path or not os.path.exists(config_path):
            logger.info("Using default configuration")
            return default_config

        try:
            # utf-8-sig also accepts files saved with a BOM (common on Windows).
            with open(config_path, 'r', encoding='utf-8-sig') as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                raise ValueError("configuration root must be a JSON object")
            config = merge(default_config, loaded)
        except Exception as e:
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return default_config

        logger.info(f"Loaded configuration from {config_path}")
        return config
    
    def _create_traditional_model(self) -> Any:
        """
        Build an untrained scikit-learn classifier for ``self.algorithm``.

        "random_forest" -> RandomForestClassifier(100 trees, random_state=42),
        "naive_bayes" -> MultinomialNB, "svm" -> linear SVC with probability
        estimates. Any other value logs a warning and falls back to the SVC.

        Returns:
            Any: Initialized (unfitted) model
        """
        algorithm = self.algorithm
        if algorithm == "random_forest":
            return RandomForestClassifier(n_estimators=100, random_state=42)
        if algorithm == "naive_bayes":
            return MultinomialNB()
        if algorithm != "svm":
            logger.warning(f"Unknown algorithm: {self.algorithm}, using SVM as default")
        return SVC(probability=True, kernel='linear')
    
    def train_traditional_model(self, texts: List[str], labels: List[str],
                                test_size: float = 0.2) -> Dict[str, Any]:
        """
        Train a TF-IDF + scikit-learn classifier for text classification.

        Labels are mapped to integer ids in sorted order, the data is split
        (stratified when possible, otherwise a plain random split), a
        TF-IDF vectorizer (5000 features) and the configured classifier are
        fitted, and the held-out split is scored. Instance state (vectorizer,
        model, label maps, ``model_type``) is replaced only if everything
        succeeds, so a failed run keeps any previously trained model.

        Args:
            texts: List of clinical texts
            labels: List of corresponding labels
            test_size: Proportion of data to use for testing

        Returns:
            Dict[str, Any]: Training results; ``{"success": False, "error": ...}``
            on failure
        """
        if len(texts) != len(labels):
            logger.error("Number of texts and labels must be equal")
            return {"success": False, "error": "Number of texts and labels must be equal"}

        if len(texts) == 0:
            logger.error("Empty training data")
            return {"success": False, "error": "Empty training data"}

        try:
            # Create label mapping
            unique_labels = sorted(set(labels))
            label_map = {label: i for i, label in enumerate(unique_labels)}
            numeric_labels = [label_map[label] for label in labels]

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    texts, numeric_labels, test_size=test_size, random_state=42, stratify=numeric_labels
                )
            except ValueError as split_error:
                # Stratification is impossible e.g. when a class has a single
                # sample; fall back rather than failing the whole training run.
                logger.warning(f"Stratified split failed ({split_error}); using an unstratified split")
                X_train, X_test, y_train, y_test = train_test_split(
                    texts, numeric_labels, test_size=test_size, random_state=42
                )

            vectorizer = TfidfVectorizer(max_features=5000)
            X_train_vec = vectorizer.fit_transform(X_train)
            X_test_vec = vectorizer.transform(X_test)

            model = self._create_traditional_model()
            model.fit(X_train_vec, y_train)

            # Evaluate model
            y_pred = model.predict(X_test_vec)
            accuracy = accuracy_score(y_test, y_pred)
            f1 = f1_score(y_test, y_pred, average='weighted')

            # Pass the full label list: without it classification_report
            # raises when a class is missing from both y_test and y_pred,
            # because target_names would be longer than the observed labels.
            report = classification_report(
                y_test,
                y_pred,
                labels=list(range(len(unique_labels))),
                target_names=unique_labels,
                output_dict=True,
            )

            # Commit state only after training and evaluation succeeded
            self.vectorizer = vectorizer
            self.traditional_model = model
            self.label_map = label_map
            self.inverse_label_map = {i: label for label, i in label_map.items()}
            self.model_type = "traditional"

            logger.info(f"Model trained with accuracy: {accuracy:.4f}, F1 score: {f1:.4f}")

            return {
                "success": True,
                "accuracy": accuracy,
                "f1_score": f1,
                "report": report,
                "num_classes": len(unique_labels),
                "classes": unique_labels,
                "num_samples": len(texts),
                "num_train": len(X_train),
                "num_test": len(X_test)
            }
        except Exception as e:
            logger.error(f"Error training traditional model: {e}")
            return {"success": False, "error": str(e)}
    
    def train_transformer_model(self, texts: List[str], labels: List[str],
                                output_dir: str, test_size: float = 0.2) -> Dict[str, Any]:
        """
        Fine-tune a transformer model for text classification.

        Labels are mapped to integer ids in sorted order, the data is split
        (stratified when possible), ``classification.pretrained_model`` is
        fine-tuned with Hugging Face ``Trainer`` (one evaluation and checkpoint
        per epoch, best checkpoint restored at the end), and the model,
        tokenizer and ``label_map.json`` are written to ``output_dir``.
        Instance state (tokenizer, model, label maps, ``model_type``) is only
        replaced after training, evaluation and saving all succeed.

        Args:
            texts: List of clinical texts
            labels: List of corresponding labels
            output_dir: Directory for checkpoints and the final model
            test_size: Proportion of data to use for evaluation

        Returns:
            Dict[str, Any]: Training results; ``{"success": False, "error": ...}``
            on failure
        """
        if len(texts) != len(labels):
            logger.error("Number of texts and labels must be equal")
            return {"success": False, "error": "Number of texts and labels must be equal"}

        if len(texts) == 0:
            logger.error("Empty training data")
            return {"success": False, "error": "Empty training data"}

        try:
            import inspect

            settings = self.config["classification"]

            # Create label mapping
            unique_labels = sorted(set(labels))
            label_map = {label: i for i, label in enumerate(unique_labels)}
            numeric_labels = [label_map[label] for label in labels]

            try:
                X_train, X_test, y_train, y_test = train_test_split(
                    texts, numeric_labels, test_size=test_size, random_state=42, stratify=numeric_labels
                )
            except ValueError as split_error:
                # Stratification is impossible e.g. when a class has a single sample.
                logger.warning(f"Stratified split failed ({split_error}); using an unstratified split")
                X_train, X_test, y_train, y_test = train_test_split(
                    texts, numeric_labels, test_size=test_size, random_state=42
                )

            # Load tokenizer and model
            model_name = settings["pretrained_model"]
            max_length = settings["max_length"]
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=len(unique_labels)
            )

            def tokenize_function(examples):
                # Truncate only: DataCollatorWithPadding pads each batch to its
                # longest sequence instead of always padding to max_length.
                return tokenizer(examples["text"], truncation=True, max_length=max_length)

            train_dataset = Dataset.from_dict({"text": X_train, "label": y_train}).map(
                tokenize_function, batched=True
            )
            test_dataset = Dataset.from_dict({"text": X_test, "label": y_test}).map(
                tokenize_function, batched=True
            )
            data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

            # Newer transformers releases renamed `evaluation_strategy` to
            # `eval_strategy` and later removed the old name; use whichever
            # this installation accepts.
            training_params = inspect.signature(TrainingArguments).parameters
            eval_strategy_kw = "eval_strategy" if "eval_strategy" in training_params else "evaluation_strategy"
            training_args = TrainingArguments(
                output_dir=output_dir,
                num_train_epochs=settings["epochs"],
                per_device_train_batch_size=settings["batch_size"],
                per_device_eval_batch_size=settings["batch_size"],
                learning_rate=settings["learning_rate"],
                weight_decay=0.01,
                save_strategy="epoch",
                load_best_model_at_end=True,
                push_to_hub=False,
                **{eval_strategy_kw: "epoch"},
            )

            # Likewise `Trainer(tokenizer=...)` was superseded by `processing_class`.
            trainer_params = inspect.signature(Trainer).parameters
            tokenizer_kw = "processing_class" if "processing_class" in trainer_params else "tokenizer"
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=test_dataset,
                data_collator=data_collator,
                **{tokenizer_kw: tokenizer},
            )

            trainer.train()
            eval_results = trainer.evaluate()

            # Save model, tokenizer and label mapping
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            with open(os.path.join(output_dir, "label_map.json"), "w", encoding="utf-8") as f:
                json.dump(label_map, f)

            # Commit state only after everything succeeded
            self.tokenizer = tokenizer
            self.transformer_model = model
            self.label_map = label_map
            self.inverse_label_map = {i: label for label, i in label_map.items()}
            self.model_type = "transformer"

            logger.info(f"Transformer model trained and saved to {output_dir}")

            return {
                "success": True,
                "eval_loss": eval_results["eval_loss"],
                "num_classes": len(unique_labels),
                "classes": unique_labels,
                "num_samples": len(texts),
                "num_train": len(X_train),
                "num_test": len(X_test),
                "model_path": output_dir
            }
        except Exception as e:
            logger.error(f"Error training transformer model: {e}")
            return {"success": False, "error": str(e)}
    
    def load_model(self, model_path: str) -> bool:
        """
        Load a trained model.

        A directory is loaded as a Hugging Face transformer model; any other
        path is unpickled as a traditional model saved by
        ``save_traditional_model``. Everything is loaded into locals first,
        so a failed load leaves the current model and label maps untouched.

        Args:
            model_path: Path to the trained model (directory or pickle file)

        Returns:
            bool: Success status (errors are logged, not raised)
        """
        try:
            if os.path.isdir(model_path):
                try:
                    tokenizer = AutoTokenizer.from_pretrained(model_path)
                    model = AutoModelForSequenceClassification.from_pretrained(model_path)

                    label_map_path = os.path.join(model_path, "label_map.json")
                    if os.path.exists(label_map_path):
                        with open(label_map_path, "r", encoding="utf-8") as f:
                            label_map = json.load(f)
                        inverse_label_map = {int(i): label for label, i in label_map.items()}
                    else:
                        # No label_map.json (model not produced by
                        # train_transformer_model): use the config's id2label
                        # rather than keeping the previous model's labels.
                        inverse_label_map = {int(i): label for i, label in model.config.id2label.items()}
                        label_map = {label: i for i, label in inverse_label_map.items()}
                except Exception as e:
                    logger.error(f"Error loading transformer model: {e}")
                    return False

                self.tokenizer = tokenizer
                self.transformer_model = model
                self.label_map = label_map
                self.inverse_label_map = inverse_label_map
                self.model_type = "transformer"
                logger.info(f"Loaded transformer model from {model_path}")
                return True

            # SECURITY: pickle.load can execute arbitrary code embedded in the
            # file - only load model files from trusted sources.
            with open(model_path, "rb") as f:
                model_data = pickle.load(f)

            traditional_model = model_data["model"]
            vectorizer = model_data["vectorizer"]
            label_map = model_data["label_map"]
            inverse_label_map = model_data["inverse_label_map"]

            self.traditional_model = traditional_model
            self.vectorizer = vectorizer
            self.label_map = label_map
            self.inverse_label_map = inverse_label_map
            self.model_type = "traditional"
            self.algorithm = model_data.get("algorithm", "unknown")

            logger.info(f"Loaded traditional model from {model_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False
    
    def save_traditional_model(self, output_path: str) -> bool:
        """
        Pickle the trained traditional model, vectorizer and label maps.

        Args:
            output_path: File path to write; a missing parent directory is
                created

        Returns:
            bool: Success status (errors are logged, not raised)
        """
        # Explicit None checks: truthiness would call __len__ on estimators
        # that define it (e.g. RandomForestClassifier), which raises when the
        # model has not been fitted.
        if self.traditional_model is None or self.vectorizer is None:
            logger.error("No trained model to save")
            return False

        try:
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            model_data = {
                "model": self.traditional_model,
                "vectorizer": self.vectorizer,
                "label_map": self.label_map,
                "inverse_label_map": self.inverse_label_map,
                "algorithm": self.algorithm
            }

            with open(output_path, "wb") as f:
                pickle.dump(model_data, f)

            logger.info(f"Model saved to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving model: {e}")
            return False
    
    def predict(self, text: str) -> Dict[str, Any]:
        """
        Predict class for clinical text
        
        Args:
            text: Clinical text
            
        Returns:
            Dict[str, Any]: Prediction results
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return {"success": False, "error": "Invalid input text"}
        
        try:
            if self.model_type == "transformer" and self.transformer_model and self.tokenizer:
                # Tokenize input
                inputs = self.tokenizer(
                    text, 
                    padding="max_length", 
                    truncation=True, 
                    max_length=self.config["classification"]["max_length"],
                    return_tensors="pt"
                )
                
                # Get prediction
                with torch.no_grad():
                    outputs = self.transformer_model(**inputs)
                    logits = outputs.logits
                    probabilities = torch.nn.functional.softmax(logits, dim=1)
                
                # Get predicted class and probabilities
                predicted_class_id = torch.argmax(probabilities, dim=1).item()
                predicted_class = self.inverse_label_map.get(predicted_class_id, "Unknown")
                
                # Convert probabilities to dictionary
                probs_dict = {
                    self.inverse_label_map.get(i, f"Class_{i}"): prob.item()
                    for i, prob in enumerate(probabilities[0])
                }
                
                return {
                    "success": True,
                    "predicted_class": predicted_class,
                    "confidence": probabilities[0][predicted_class_id].item(),
                    "probabilities": probs_dict
                }
            
            elif self.model_type == "traditional" and self.traditional_model and self.vectorizer:
                # Vectorize input
                X_vec = self.vectorizer.transform([text])
                
                # Get prediction
                predicted_class_id = self.traditional_model.predict(X_vec)[0]
                predicted_class = self.inverse_label_map.get(predicted_class_id, "Unknown")
                
                # Get probabilities if available
                probabilities = {}
                if hasattr(self.traditional_model, "predict_proba"):
                    probs = self.traditional_model.predict_proba(X_vec)[0]
                    probabilities = {
                        self.inverse_label_map.get(i, f"Class_{i}"): prob
                        for i, prob in enumerate(probs)
                    }
                    confidence = probs[predicted_class_id]
                else:
                    confidence = 1.0
                
                return {
                    "success": True,
                    "predicted_class": predicted_class,
                    "confidence": confidence,
                    "probabilities": probabilities
                }
            
            else:
                logger.error("No trained model available")
                return {"success": False, "error": "No trained model available"}
        
        except Exception as e:
            logger.error(f"Error in prediction: {e}")
            return {"success": False, "error": str(e)}
    
    def batch_predict(self, texts: List[str]) -> List[Dict[str, Any]]:
        """
        Classify several clinical texts by calling ``predict`` on each one.

        Args:
            texts: Clinical texts to classify

        Returns:
            List[Dict[str, Any]]: One ``predict`` result per input text, in the
            same order; an empty list (after logging a warning) when ``texts``
            is empty
        """
        if texts:
            return [self.predict(clinical_text) for clinical_text in texts]

        logger.warning("Empty input list")
        return []
    
    def evaluate(self, texts: List[str], labels: List[str]) -> Dict[str, Any]:
        """
        Evaluate the model on labelled test data.

        Predictions and gold labels are compared as class-name strings. A
        failed prediction counts as the class "Unknown", and gold labels never
        seen in training simply can never be predicted correctly.

        Args:
            texts: List of clinical texts
            labels: List of corresponding gold class names (same length as
                ``texts``)

        Returns:
            Dict[str, Any]: On success ``{"success": True, "accuracy": float,
            "report": <classification_report output_dict>, "num_samples": int}``;
            otherwise ``{"success": False, "error": str}``
        """
        if len(texts) != len(labels):
            logger.error("Number of texts and labels must be equal")
            return {"success": False, "error": "Number of texts and labels must be equal"}
        
        if len(texts) == 0:
            logger.error("Empty evaluation data")
            return {"success": False, "error": "Empty evaluation data"}
        
        try:
            predictions = []
            for text in texts:
                result = self.predict(text)
                if result["success"]:
                    predictions.append(result["predicted_class"])
                else:
                    logger.warning(f"Prediction failed for text: {text[:50]}...")
                    predictions.append("Unknown")
            
            # Compare in label-name space. The previous numeric encoding used a
            # locally built map for the gold labels but ``self.label_map`` for the
            # predictions whenever any gold label was unknown, so the two id spaces
            # disagreed. It also passed ``target_names`` covering only the gold
            # classes, which makes classification_report raise as soon as a
            # prediction (e.g. "Unknown") falls outside that set. Without
            # ``target_names`` the report covers every class seen on either side.
            accuracy = accuracy_score(labels, predictions)
            report = classification_report(labels, predictions,
                                           output_dict=True, zero_division=0)
            
            logger.info(f"Evaluation completed with accuracy: {accuracy:.4f}")
            
            return {
                "success": True,
                "accuracy": accuracy,
                "report": report,
                "num_samples": len(texts)
            }
        except Exception as e:
            logger.error(f"Error in evaluation: {e}")
            return {"success": False, "error": str(e)}
    
    def visualize_predictions(self, text: str, prediction: Dict[str, Any], 
                             output_path: Optional[str] = None) -> str:
        """
        Create an HTML visualization of a prediction result.

        Every caller-supplied string is HTML-escaped before it is embedded: the
        input text, class names and error message. Clinical notes containing
        ``<``, ``>`` or ``&`` therefore render literally instead of being
        interpreted as markup.

        Args:
            text: Original text that was classified (only the first 500
                characters are shown)
            prediction: Result dict as returned by ``predict``. The keys read are
                "success", "error", "predicted_class", "confidence" and, when
                present, "probabilities" (class name -> probability)
            output_path: Optional path to save the visualization to (UTF-8)

        Returns:
            str: HTML representation of the prediction visualization
        """
        from html import escape

        if not prediction.get("success"):
            error_message = escape(str(prediction.get('error', 'Unknown error')))
            return f"<div>Error: {error_message}</div>"
        
        predicted_class = prediction["predicted_class"]
        confidence = prediction["confidence"]
        shown_text = escape(text[:500]) + ('...' if len(text) > 500 else '')
        
        # Create HTML visualization
        html_out = f"""
        <div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; border: 1px solid #ddd; border-radius: 5px;">
            <h2 style="color: #333;">Text Classification Result</h2>
            
            <div style="margin: 20px 0; padding: 15px; background-color: #f9f9f9; border-radius: 5px;">
                <h3 style="margin-top: 0; color: #555;">Input Text:</h3>
                <p style="line-height: 1.5;">{shown_text}</p>
            </div>
            
            <div style="margin: 20px 0;">
                <h3 style="color: #555;">Prediction:</h3>
                <div style="display: flex; align-items: center;">
                    <div style="font-size: 18px; font-weight: bold; color: #2c3e50; flex: 1;">
                        {escape(str(predicted_class))}
                    </div>
                    <div style="background-color: #3498db; color: white; padding: 5px 10px; border-radius: 20px; font-weight: bold;">
                        Confidence: {confidence:.2%}
                    </div>
                </div>
            </div>
        """
        
        # Add probability distribution if available
        probabilities = prediction.get("probabilities")
        if probabilities:
            html_out += """
            <div style="margin: 20px 0;">
                <h3 style="color: #555;">Class Probabilities:</h3>
                <div style="display: flex; flex-direction: column; gap: 10px;">
            """
            
            # Highest probability first
            sorted_probs = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
            
            for class_name, prob in sorted_probs:
                is_predicted = class_name == predicted_class
                # Bar width in percent, clamped so an out-of-range value cannot overflow the bar
                width = max(0, min(100, int(prob * 100)))
                color = "#3498db" if is_predicted else "#95a5a6"
                weight = 'bold' if is_predicted else 'normal'
                
                html_out += f"""
                <div style="display: flex; align-items: center; gap: 10px;">
                    <div style="width: 150px; font-weight: {weight};">{escape(str(class_name))}</div>
                    <div style="flex: 1; background-color: #eee; border-radius: 5px; height: 20px;">
                        <div style="width: {width}%; background-color: {color}; height: 100%; border-radius: 5px;"></div>
                    </div>
                    <div style="width: 60px; text-align: right;">{prob:.2%}</div>
                </div>
                """
            
            html_out += """
                </div>
            </div>
            """
        
        html_out += """
        </div>
        """
        
        # Save to file if output path is provided (best effort: failures are logged)
        if output_path:
            try:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(html_out)
                logger.info(f"Visualization saved to {output_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {e}")
        
        return html_out

modules\text_analysis\text_preprocessor.py

"""
Clinical Text Preprocessor Module

This module provides functionality for preprocessing clinical text data,
including tokenization, normalization, and feature extraction.
"""

import re
import json
import logging
import string
import os
from typing import List, Dict, Any, Optional, Union, Set

import nltk
import numpy as np
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

# Download required NLTK resources that are not installed yet. Each resource is
# probed separately, so only the missing ones are fetched.
# "punkt_tab" is the tokenizer data that word_tokenize/sent_tokenize load on
# NLTK >= 3.8.2; "punkt" is what older releases use. On an older release the
# punkt_tab download simply fails, and it is re-attempted on each import.
for _resource_path, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/stopwords', 'stopwords'),
    ('corpora/wordnet', 'wordnet'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)

# Configure logging. The log file is written as UTF-8 so that clinical text with
# non-ASCII characters can be logged regardless of the platform's default
# encoding.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so only the first module that calls it decides the configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("clinical_text_processing.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class ClinicalTextPreprocessor:
    """
    Class for preprocessing clinical text data.

    Provides:
    - configurable token normalization: lower-casing, punctuation and number
      removal, stop-word filtering, stemming and lemmatization;
    - sentence splitting and spaCy noun-phrase extraction;
    - bag-of-words feature extraction;
    - rule-based helpers for section headers, demographics and concept
      normalization.
    """

    # Character class covering every character in string.punctuation.
    # re.escape is required: the set contains "\\", "]", "^" and "-". Built
    # unescaped, the class never matched the backslash.
    _PUNCTUATION_RE = re.compile(f"[{re.escape(string.punctuation)}]")
    _NUMBER_RE = re.compile(r"\d+")

    # Keywords for common clinical-note section headers. Each pattern is
    # anchored at the start of a line (optionally indented) and must end on a
    # word boundary, so e.g. "Planned surgery" is not taken for a PLAN header.
    # Trailing colons/whitespace are consumed as part of the header.
    _SECTION_HEADER_KEYWORDS = (
        r"CHIEF\s+COMPLAINT",
        r"HISTORY\s+(?:OF\s+)?(?:PRESENT\s+)?ILLNESS",
        r"PAST\s+MEDICAL\s+HISTORY",
        r"MEDICATIONS",
        r"ALLERGIES",
        r"FAMILY\s+HISTORY",
        r"SOCIAL\s+HISTORY",
        r"PHYSICAL\s+EXAMINATION",
        r"VITAL\s+SIGNS",
        r"LABORATORY\s+(?:DATA|RESULTS)",
        r"ASSESSMENT",
        r"PLAN",
        r"IMPRESSION",
        r"DIAGNOSIS",
    )
    _SECTION_HEADER_PATTERNS = tuple(
        re.compile(rf"^[^\S\n]*(?:{kw})\b[\s:]*", re.IGNORECASE | re.MULTILINE)
        for kw in _SECTION_HEADER_KEYWORDS
    )

    # Age expressions, tried in order: "age: 45" / "aged 45",
    # "45-year-old" / "45 years old", and "45 yo" / "45 y/o" / "45 y.o.".
    # The leading \b stops "age" from matching inside words such as "dosage".
    _AGE_PATTERNS = (
        re.compile(r"\b(?:age|aged)[\s:]+(\d{1,3})\b", re.IGNORECASE),
        re.compile(r"\b(\d{1,3})[\s-]*(?:years?|yrs?|y)\.?[\s-]+old\b", re.IGNORECASE),
        re.compile(r"\b(\d{1,3})[\s-]*(?:y/o|y\.o\.?|yo)\b", re.IGNORECASE),
    )

    # Gender cues from most to least explicit. Single letters ("M"/"F") are only
    # trusted in an explicit gender/sex field or right after an age ("60 yo M").
    # A bare "m" elsewhere is usually a unit ("walked 5 m").
    _GENDER_PATTERNS = (
        re.compile(r"\b(?:gender|sex)[\s:]+(male|female|m|f)\b", re.IGNORECASE),
        re.compile(r"\b(male|female)\b", re.IGNORECASE),
        re.compile(
            r"\b\d{1,3}[\s-]*(?:y/o|y\.o\.|yo|years?[\s-]+old|yrs?[\s-]+old)[\s,]*(m|f)\b",
            re.IGNORECASE,
        ),
    )

    # Race keywords (simplified), checked in this order; the first hit wins.
    # NOTE(review): colour words are ambiguous in clinical text ("white blood
    # cells", "black stools") and can produce false positives — confirm whether
    # a narrower context is required.
    _RACE_PATTERNS = (
        ("caucasian", re.compile(r"\b(?:caucasian|white)\b", re.IGNORECASE)),
        ("african american", re.compile(r"\b(?:african american|black)\b", re.IGNORECASE)),
        ("asian", re.compile(r"\b(?:asian)\b", re.IGNORECASE)),
        ("hispanic", re.compile(r"\b(?:hispanic|latino)\b", re.IGNORECASE)),
        ("other", re.compile(r"\b(?:other race|mixed race)\b", re.IGNORECASE)),
    )

    # Lay term -> standard term (lower-case keys). Simplified stand-in for a
    # UMLS / SNOMED CT lookup.
    _CONCEPT_NORMALIZATION_MAP = {
        "heart attack": "myocardial infarction",
        "heart failure": "cardiac failure",
        "high blood pressure": "hypertension",
        "low blood pressure": "hypotension",
        "stroke": "cerebrovascular accident",
        "sugar": "glucose",
        "cancer": "malignant neoplasm",
        # NOTE(review): "chest pain" is a symptom while "angina pectoris" is a
        # specific diagnosis; this mapping may over-interpret notes — confirm
        # with a clinician.
        "chest pain": "angina pectoris",
    }

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the clinical text preprocessor.

        Args:
            config_path: Path to a JSON configuration file. Values it provides
                override the defaults from ``_load_config``.
        """
        self.config = self._load_config(config_path)
        self.stop_words = set(stopwords.words('english'))
        self.stemmer = PorterStemmer()
        self.lemmatizer = WordNetLemmatizer()
        # Vectorizer fitted by the most recent extract_features() call, kept so
        # callers can transform new texts into the same feature space.
        self.vectorizer = None
        
        # Domain words treated as stop words in addition to NLTK's English list
        self.medical_stop_words = {
            'patient', 'doctor', 'hospital', 'clinic', 'medical', 'medicine',
            'treatment', 'therapy', 'diagnosis', 'prognosis', 'report', 'date',
            'name', 'id', 'mr', 'mrs', 'dr', 'test', 'result', 'normal', 'abnormal'
        }
        self.stop_words.update(self.medical_stop_words)
        
        # Load spaCy model for advanced NLP tasks
        try:
            self.nlp = spacy.load('en_core_web_sm')
            logger.info("Loaded spaCy model: en_core_web_sm")
        except OSError:
            # spacy.blank() yields a tokenizer-only pipeline (no parser), so
            # noun-chunk extraction is unavailable in this fallback.
            logger.warning("spaCy model not found. Using spaCy's default model.")
            self.nlp = spacy.blank('en')
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, filling in defaults for missing keys.

        Args:
            config_path: Path to a JSON configuration file

        Returns:
            Dict[str, Any]: Configuration dictionary. Its "preprocessing" section
            always contains every default key, overridden by any values the
            file supplies. The defaults are returned when the file is absent or
            unreadable.
        """
        default_config = {
            "preprocessing": {
                "remove_stopwords": True,
                "stemming": True,
                "lemmatization": True,
                "lowercase": True,
                "remove_punctuation": True,
                "remove_numbers": False,
                "min_token_length": 2
            }
        }
        
        if not config_path or not os.path.exists(config_path):
            if config_path:
                logger.warning(f"Configuration file not found: {config_path}")
            logger.info("Using default configuration")
            return default_config
        
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            # Overlay the file's values on the defaults. A partial config
            # (e.g. only {"preprocessing": {"stemming": false}}) then still
            # provides every key preprocess_text() reads.
            merged = dict(config)
            merged["preprocessing"] = {
                **default_config["preprocessing"],
                **config.get("preprocessing", {}),
            }
        except Exception as e:
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return default_config
        
        logger.info(f"Loaded configuration from {config_path}")
        return merged
    
    def preprocess_text(self, text: str) -> str:
        """
        Normalize clinical text according to the "preprocessing" config.

        Steps, each enabled by its config flag, in order: lower-casing,
        punctuation -> space, digit removal, tokenization, stop-word removal,
        stemming, lemmatization. The last step is always applied: drop tokens
        shorter than "min_token_length".

        Args:
            text: Input clinical text

        Returns:
            str: Space-joined remaining tokens; "" for empty/non-string input
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return ""
        
        options = self.config["preprocessing"]
        
        if options["lowercase"]:
            text = text.lower()
        
        if options["remove_punctuation"]:
            text = self._PUNCTUATION_RE.sub(' ', text)
        
        if options["remove_numbers"]:
            text = self._NUMBER_RE.sub('', text)
        
        tokens = word_tokenize(text)
        
        if options["remove_stopwords"]:
            tokens = [t for t in tokens if t not in self.stop_words]
        
        if options["stemming"]:
            tokens = [self.stemmer.stem(t) for t in tokens]
        
        if options["lemmatization"]:
            # NOTE: with stemming also enabled this lemmatizes already-stemmed tokens
            tokens = [self.lemmatizer.lemmatize(t) for t in tokens]
        
        min_length = options["min_token_length"]
        return ' '.join(t for t in tokens if len(t) >= min_length)
    
    def extract_sentences(self, text: str) -> List[str]:
        """
        Split clinical text into sentences with NLTK's ``sent_tokenize``.

        Args:
            text: Input clinical text

        Returns:
            List[str]: Sentences in order; [] for empty/non-string input
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return []
        
        return sent_tokenize(text)
    
    def extract_medical_terms(self, text: str) -> List[str]:
        """
        Extract candidate medical terms (spaCy noun chunks that are not stop words).

        Noun chunks need a dependency parse. With the ``spacy.blank`` fallback
        model there is none, so a warning is logged and [] is returned instead
        of raising.

        Args:
            text: Input clinical text

        Returns:
            List[str]: Noun-chunk texts in document order
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return []
        
        doc = self.nlp(text)
        
        try:
            chunks = list(doc.noun_chunks)
        except (ValueError, NotImplementedError) as e:
            logger.warning(f"Noun-chunk extraction unavailable with the loaded spaCy pipeline: {e}")
            return []
        
        # Filter out chunks that are themselves stop words
        return [chunk.text for chunk in chunks if chunk.text.lower() not in self.stop_words]
    
    def extract_features(self, texts: List[str], method: str = 'tfidf', 
                        max_features: int = 1000) -> np.ndarray:
        """
        Extract bag-of-words features from clinical texts.

        Texts are run through ``preprocess_text`` first. On success the fitted
        vectorizer is stored in ``self.vectorizer`` so new texts can be mapped
        into the same feature space.

        Args:
            texts: List of clinical texts
            method: Feature extraction method ('tfidf', 'count', 'binary')
            max_features: Maximum number of features

        Returns:
            np.ndarray: On success, the matrix returned by the vectorizer's
            ``fit_transform`` (a scipy sparse matrix, one row per text). On
            empty input, an unknown method or an error, an empty ``np.array``.
        """
        if not texts:
            logger.warning("Empty text list provided for feature extraction")
            return np.array([])
        
        # Validate the method before doing the (comparatively expensive) preprocessing
        if method == 'tfidf':
            vectorizer = TfidfVectorizer(max_features=max_features)
        elif method == 'count':
            vectorizer = CountVectorizer(max_features=max_features)
        elif method == 'binary':
            vectorizer = CountVectorizer(max_features=max_features, binary=True)
        else:
            logger.error(f"Unknown feature extraction method: {method}")
            return np.array([])
        
        preprocessed_texts = [self.preprocess_text(text) for text in texts]
        
        try:
            feature_matrix = vectorizer.fit_transform(preprocessed_texts)
            feature_names = vectorizer.get_feature_names_out()
        except Exception as e:
            # e.g. an empty vocabulary when every text preprocesses to ""
            logger.error(f"Error in feature extraction: {e}")
            return np.array([])
        
        self.vectorizer = vectorizer
        logger.info(f"Extracted {len(feature_names)} features using {method} method")
        return feature_matrix
    
    def normalize_medical_concepts(self, terms: List[str]) -> Dict[str, str]:
        """
        Normalize medical concepts to standard terminology.

        Simplified implementation backed by a small built-in map; a real system
        would query UMLS or SNOMED CT. Lookup is case-insensitive and ignores
        surrounding whitespace.

        Args:
            terms: List of medical terms

        Returns:
            Dict[str, str]: Original term -> normalized term (the term itself
            when it has no mapping)
        """
        if not terms:
            return {}
        
        return {
            term: self._CONCEPT_NORMALIZATION_MAP.get(term.lower().strip(), term)
            for term in terms
        }
    
    def detect_section_headers(self, text: str) -> List[Dict[str, Any]]:
        """
        Detect clinical-note section headers and the text under each.

        Args:
            text: Clinical note text

        Returns:
            List[Dict[str, Any]]: Sections in document order, each with:
            - "name" / "text": the header as written, stripped (e.g. "PLAN:");
            - "start": offset of the header's first character;
            - "content": stripped text between the end of this header and the
              start of the next one (or the end of the note).
        """
        if not text or not isinstance(text, str):
            return []
        
        found = []  # (header_start, header_end, header_text)
        for pattern in self._SECTION_HEADER_PATTERNS:
            for match in pattern.finditer(text):
                raw = match.group(0)
                leading = len(raw) - len(raw.lstrip())
                found.append((match.start() + leading, match.end(), raw.strip()))
        
        found.sort(key=lambda item: item[0])
        
        sections = []
        for i, (start, header_end, header) in enumerate(found):
            next_start = found[i + 1][0] if i + 1 < len(found) else len(text)
            sections.append({
                "name": header,
                "start": start,
                "text": header,
                # Content begins after the full matched header (incl. colon and
                # whitespace), not at start + len(stripped header).
                "content": text[header_end:next_start].strip(),
            })
        
        return sections
    
    def extract_demographics(self, text: str) -> Dict[str, Any]:
        """
        Extract patient demographics (age, gender, race) with simple regex rules.

        Args:
            text: Clinical text

        Returns:
            Dict[str, Any]: {"age": int | None, "gender": "male" | "female" | None,
            "race": str | None}
        """
        demographics = {
            "age": None,
            "gender": None,
            "race": None
        }
        if not text or not isinstance(text, str):
            return demographics
        
        for pattern in self._AGE_PATTERNS:
            match = pattern.search(text)
            if match:
                demographics["age"] = int(match.group(1))
                break
        
        for pattern in self._GENDER_PATTERNS:
            match = pattern.search(text)
            if match:
                demographics["gender"] = "female" if match.group(1).lower().startswith("f") else "male"
                break
        
        for race, pattern in self._RACE_PATTERNS:
            if pattern.search(text):
                demographics["race"] = race
                break
        
        return demographics

modules\text_analysis\text_summarizer.py

"""
Clinical Text Summarizer Module

This module provides functionality for summarizing clinical text data,
including extractive and abstractive summarization methods.
"""

import os
import json
import logging
import re
from typing import List, Dict, Any, Optional, Union, Tuple

import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM, 
    pipeline
)

# Download required NLTK resources that are not installed yet. Each resource is
# probed separately, so only the missing ones are fetched.
# "punkt_tab" is the tokenizer data that sent_tokenize loads on NLTK >= 3.8.2;
# "punkt" is what older releases use. On an older release the punkt_tab
# download simply fails, and it is re-attempted on each import.
for _resource_path, _package in (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/stopwords', 'stopwords'),
):
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        nltk.download(_package)

# Configure logging. The log file is written as UTF-8 so that clinical text with
# non-ASCII characters can be logged regardless of the platform's default
# encoding.
# NOTE: logging.basicConfig() does nothing if the root logger already has
# handlers, so only the first module that calls it decides the configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("clinical_text_summarization.log", encoding="utf-8"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class ClinicalTextSummarizer:
    """
    Class for summarizing clinical text data
    """
    
    def __init__(self, config_path: Optional[str] = None):
        """
        Set up the summarizer: configuration, stop words and (optionally) a
        transformer summarization model.

        Args:
            config_path: Path to a JSON configuration file. When
                config["summarization"]["model"] is truthy, that model is
                loaded immediately.
        """
        self.config = self._load_config(config_path)

        # NLTK English stop words extended with domain words that carry little
        # signal in clinical notes
        self.medical_stop_words = {
            'patient', 'doctor', 'hospital', 'clinic', 'medical', 'medicine',
            'treatment', 'therapy', 'diagnosis', 'prognosis', 'report', 'date',
            'name', 'id', 'mr', 'mrs', 'dr', 'test', 'result', 'normal', 'abnormal'
        }
        self.stop_words = set(stopwords.words('english')) | self.medical_stop_words

        # Transformer components stay None unless a model loads successfully
        self.tokenizer = self.model = self.summarization_pipeline = None

        requested_model = self.config["summarization"]["model"]
        if requested_model:
            self._load_transformer_model(requested_model)
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, filling in defaults for missing keys.

        Args:
            config_path: Path to a JSON configuration file

        Returns:
            Dict[str, Any]: Configuration dictionary. Its "summarization"
            section always contains every default key (max_length, min_length,
            model, early_stopping), overridden by any values the file supplies.
            The defaults are returned when the file is absent or unreadable.
        """
        default_config = {
            "summarization": {
                "max_length": 150,
                "min_length": 50,
                "model": "t5-small",
                "early_stopping": True
            }
        }
        
        if not config_path or not os.path.exists(config_path):
            if config_path:
                logger.warning(f"Configuration file not found: {config_path}")
            logger.info("Using default configuration")
            return default_config
        
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            # Overlay the file's values on the defaults. A partial config (or a
            # file shared with other modules) then still provides every key read
            # by __init__ and abstractive_summarize().
            merged = dict(config)
            merged["summarization"] = {
                **default_config["summarization"],
                **config.get("summarization", {}),
            }
        except Exception as e:
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return default_config
        
        logger.info(f"Loaded configuration from {config_path}")
        return merged
    
    def _load_transformer_model(self, model_name: str) -> None:
        """
        Load a seq2seq model and wrap it in a Hugging Face summarization pipeline.

        The aliases "medical-t5" and "clinical-t5" both resolve to the general
        purpose "google/flan-t5-base" checkpoint; any other name is passed
        through unchanged. On any failure the error is logged and tokenizer,
        model and pipeline are all reset to None.

        Args:
            model_name: Model alias, hub name or local path
        """
        try:
            resolved_name = {
                "medical-t5": "google/flan-t5-base",
                "clinical-t5": "google/flan-t5-base",
            }.get(model_name, model_name)

            self.tokenizer = AutoTokenizer.from_pretrained(resolved_name)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(resolved_name)
            self.summarization_pipeline = pipeline(
                "summarization",
                model=self.model,
                tokenizer=self.tokenizer,
                framework="pt",
            )

            logger.info(f"Loaded transformer model: {resolved_name}")
        except Exception as e:
            logger.error(f"Error loading transformer model: {e}")
            self.tokenizer = self.model = self.summarization_pipeline = None
    
    def extractive_summarize(self, text: str, ratio: float = 0.3, 
                            min_length: int = 3, max_length: int = 10) -> str:
        """
        Generate an extractive summary by picking the highest-scoring sentences.

        Each sentence is scored by the sum of its TF-IDF weights, computed with
        the instance's stop words excluded. The top sentences are returned in
        their original order. This scoring favours sentences with many distinct
        content terms.

        Args:
            text: Clinical text to summarize
            ratio: Proportion of sentences to include in the summary
            min_length: Minimum number of sentences in the summary
            max_length: Maximum number of sentences in the summary

        Returns:
            str: The summary. ``text`` unchanged when it has at most
            ``min_length`` sentences. On error, the first 500 characters
            (plus "...").
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return ""
        
        try:
            sentences = sent_tokenize(text)
            
            if len(sentences) <= min_length:
                return text
            
            num_sentences = max(min_length, min(max_length, int(len(sentences) * ratio)))
            
            # scikit-learn's parameter validation accepts stop_words only as
            # "english", a list, or None. Passing the set directly is rejected
            # by recent versions and sent every call into the fallback below.
            vectorizer = TfidfVectorizer(stop_words=sorted(self.stop_words))
            sentence_vectors = vectorizer.fit_transform(sentences)
            
            # Row sums of the sparse TF-IDF matrix = one score per sentence
            scores = np.asarray(sentence_vectors.sum(axis=1)).ravel()
            
            # Highest scores first; the stable sort keeps earlier sentences first on ties
            ranked = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
            selected = sorted(ranked[:num_sentences])
            
            return " ".join(sentences[i] for i in selected)
        except Exception as e:
            logger.error(f"Error in extractive summarization: {e}")
            return text[:500] + "..." if len(text) > 500 else text
    
    def abstractive_summarize(self, text: str) -> str:
        """
        Generate an abstractive summary of clinical text using transformer models
        
        Args:
            text: Clinical text to summarize
            
        Returns:
            str: Abstractive summary
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return ""
        
        if not self.summarization_pipeline:
            logger.warning("Transformer model not available, falling back to extractive summarization")
            return self.extractive_summarize(text)
        
        try:
            # Truncate text if it's too long
            max_tokens = 1024  # Typical limit for transformer models
            if len(text.split()) > max_tokens:
                logger.info(f"Text too long ({len(text.split())} tokens), truncating to {max_tokens} tokens")
                text = " ".join(text.split()[:max_tokens])
            
            # Generate summary
            summary = self.summarization_pipeline(
                text,
                max_length=self.config["summarization"]["max_length"],
                min_length=self.config["summarization"]["min_length"],
                do_sample=False,
                early_stopping=self.config["summarization"]["early_stopping"]
            )
            
            # Extract summary text
            summary_text = summary[0]["summary_text"]
            
            return summary_text
        except Exception as e:
            logger.error(f"Error in abstractive summarization: {e}")
            return self.extractive_summarize(text)
    
    def summarize(self, text: str, method: str = "hybrid") -> Dict[str, Any]:
        """
        Generate a summary of clinical text
        
        Args:
            text: Clinical text to summarize
            method: Summarization method ('extractive', 'abstractive', or 'hybrid')
            
        Returns:
            Dict[str, Any]: Summarization results
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return {"success": False, "error": "Invalid input text"}
        
        try:
            if method == "extractive":
                summary = self.extractive_summarize(text)
                summary_type = "extractive"
            elif method == "abstractive":
                if self.summarization_pipeline:
                    summary = self.abstractive_summarize(text)
                    summary_type = "abstractive"
                else:
                    logger.warning("Transformer model not available, falling back to extractive summarization")
                    summary = self.extractive_summarize(text)
                    summary_type = "extractive (fallback)"
            elif method == "hybrid":
                # For hybrid, first do extractive to reduce size, then abstractive
                if self.summarization_pipeline:
                    extractive_summary = self.extractive_summarize(text, ratio=0.5)
                    summary = self.abstractive_summarize(extractive_summary)
                    summary_type = "hybrid"
                else:
                    summary = self.extractive_summarize(text)
                    summary_type = "extractive (fallback)"
            else:
                logger.error(f"Unknown summarization method: {method}")
                return {"success": False, "error": f"Unknown summarization method: {method}"}
            
            # Calculate compression ratio
            original_length = len(text.split())
            summary_length = len(summary.split())
            compression_ratio = summary_length / original_length if original_length > 0 else 0
            
            return {
                "success": True,
                "summary": summary,
                "method": summary_type,
                "original_length": original_length,
                "summary_length": summary_length,
                "compression_ratio": compression_ratio
            }
        except Exception as e:
            logger.error(f"Error in summarization: {e}")
            return {"success": False, "error": str(e)}
    
    def section_based_summarize(self, text: str, sections: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate summaries for each section of a clinical note.

        Sections whose content is empty or shorter than 10 words are copied
        through verbatim. Longer sections are passed to
        ``self.summarize(content, method="hybrid")``. When that call reports
        failure, the original content is kept.

        Args:
            text: Full clinical text. It is only validated here; each section
                is summarized from its own ``content``.
            sections: List of section dictionaries; each must provide "name"
                and "content" keys.

        Returns:
            Dict[str, Any]: On success, returns
            ``{"success": True, "combined_summary": str, "section_summaries": {name: summary}}``.
            The combined summary renders each section as ``"<name>\\n<summary>\\n\\n"``.
            Sections that share a name overwrite each other (the last one wins).
            On invalid input or an exception, returns ``{"success": False, "error": str}``.
        """
        if not text or not isinstance(text, str) or not sections:
            logger.warning("Invalid input for section-based summarization")
            return {"success": False, "error": "Invalid input"}

        try:
            section_summaries = {}

            for section in sections:
                section_name = section["name"]
                section_content = section["content"]

                # Empty or very short sections are not worth summarizing; keep them as-is
                if not section_content or len(section_content.split()) < 10:
                    section_summaries[section_name] = section_content
                    continue

                summary_result = self.summarize(section_content, method="hybrid")
                # Fall back to the raw section text when summarization fails
                section_summaries[section_name] = (
                    summary_result["summary"] if summary_result["success"] else section_content
                )

            # join() instead of repeated += to avoid quadratic string building
            combined_summary = "".join(
                f"{name}\n{summary}\n\n" for name, summary in section_summaries.items()
            )

            return {
                "success": True,
                "combined_summary": combined_summary,
                "section_summaries": section_summaries
            }
        except Exception as e:
            logger.error(f"Error in section-based summarization: {e}")
            return {"success": False, "error": str(e)}
    
    def key_findings_extract(self, text: str) -> Dict[str, Any]:
        """
        Extract key findings from clinical text using keyword regexes.

        Each category (diagnosis, medications, allergies, vital signs, lab
        results, procedures) has a list of case-insensitive patterns. The text
        that follows a keyword (and any spaces/colons after it) is captured.
        Findings are deduplicated per category, and first-seen order is kept.

        Args:
            text: Clinical text

        Returns:
            Dict[str, Any]: On success, returns
            ``{"success": True, "findings": {category: [str, ...]}, "formatted_findings": str}``.
            ``formatted_findings`` lists every non-empty category as a titled bullet list.
            On invalid input or an exception, returns ``{"success": False, "error": str}``.
        """
        import re  # local import: keeps this method self-contained

        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return {"success": False, "error": "Invalid input text"}

        try:
            # A '.' only ends a value when it is NOT followed by a digit, so decimals
            # like "38.5" or codes like "E11.9" are captured whole instead of being
            # cut at the dot. Sentence-ending ". " still terminates the value.
            until_sentence_end = r"((?:[^\n.]|\.(?=\d))+)"
            until_value_end = r"((?:[^\n.,]|\.(?=\d))+)"

            # \b keeps keywords from matching inside longer words
            # (e.g. "undertaking" -> "taking", "collabs" -> "labs").
            patterns = {
                "diagnosis": [
                    r"(?i)\bdiagnosis[\s:]+" + until_sentence_end,
                    r"(?i)\bimpression[\s:]+" + until_sentence_end,
                    r"(?i)\bassessment[\s:]+" + until_sentence_end
                ],
                "medications": [
                    r"(?i)\bmedications?[\s:]+([^\n]+)",
                    r"(?i)\bprescribed[\s:]+([^\n]+)",
                    r"(?i)\btaking[\s:]+([^\n]+)"
                ],
                "allergies": [
                    r"(?i)\ballergies[\s:]+([^\n]+)",
                    r"(?i)\ballergic to[\s:]+([^\n]+)"
                ],
                "vital_signs": [
                    r"(?i)\bvital signs[\s:]+([^\n]+)",
                    r"(?i)\btemperature[\s:]+" + until_value_end,
                    r"(?i)\bblood pressure[\s:]+" + until_value_end,
                    r"(?i)\bpulse[\s:]+" + until_value_end,
                    r"(?i)\brespiratory rate[\s:]+" + until_value_end
                ],
                "lab_results": [
                    r"(?i)\blab(?:oratory)? results?[\s:]+([^\n]+)",
                    r"(?i)\blabs?[\s:]+([^\n]+)"
                ],
                "procedures": [
                    r"(?i)\bprocedures?[\s:]+([^\n]+)",
                    r"(?i)\bsurgery[\s:]+([^\n]+)",
                    r"(?i)\boperation[\s:]+([^\n]+)"
                ]
            }

            findings = {}
            for category, category_patterns in patterns.items():
                category_findings = []
                seen = set()  # O(1) duplicate check; the list keeps first-seen order
                for pattern in category_patterns:
                    for match in re.finditer(pattern, text):
                        finding = match.group(1).strip()
                        if finding and finding not in seen:
                            seen.add(finding)
                            category_findings.append(finding)
                findings[category] = category_findings

            # Render non-empty categories as "Title:\n- item\n...\n\n"
            formatted_parts = []
            for category, category_findings in findings.items():
                if category_findings:
                    formatted_parts.append(f"{category.replace('_', ' ').title()}:\n")
                    formatted_parts.extend(f"- {finding}\n" for finding in category_findings)
                    formatted_parts.append("\n")
            formatted_findings = "".join(formatted_parts)

            return {
                "success": True,
                "findings": findings,
                "formatted_findings": formatted_findings
            }
        except Exception as e:
            logger.error(f"Error extracting key findings: {e}")
            return {"success": False, "error": str(e)}
    
    def generate_discharge_summary(self, text: str) -> Dict[str, Any]:
        """
        Build a plain-text discharge summary for a clinical note.

        The note is first run through ``self.key_findings_extract``. It is then
        summarized with ``self.summarize(text, method="hybrid")``. The result
        has three parts:
        a "DISCHARGE SUMMARY" title; bullet lists for the non-empty diagnosis,
        medications, allergies and procedures findings; and a
        "Clinical Summary:" section holding the generated summary.

        Args:
            text: Clinical text

        Returns:
            Dict[str, Any]: On success, contains ``success``, ``discharge_summary``,
            ``findings`` and ``clinical_summary``. If either sub-step reports
            failure, that sub-step's result dict is returned unchanged. Invalid
            input or an exception yields ``{"success": False, "error": ...}``.
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return {"success": False, "error": "Invalid input text"}

        try:
            findings_result = self.key_findings_extract(text)
            if not findings_result["success"]:
                return findings_result
            findings = findings_result["findings"]

            summary_result = self.summarize(text, method="hybrid")
            if not summary_result["success"]:
                return summary_result
            clinical_summary = summary_result["summary"]

            # Only these categories go into the discharge summary (vital signs and labs are omitted)
            parts = ["DISCHARGE SUMMARY\n\n"]
            for category in ("diagnosis", "medications", "allergies", "procedures"):
                items = findings[category] if category in findings else None
                if not items:
                    continue
                parts.append(f"{category.replace('_', ' ').title()}:\n")
                parts.extend(f"- {item}\n" for item in items)
                parts.append("\n")
            parts.append("Clinical Summary:\n")

            discharge_summary = "".join(parts) + clinical_summary

            return {
                "success": True,
                "discharge_summary": discharge_summary,
                "findings": findings,
                "clinical_summary": clinical_summary
            }
        except Exception as e:
            logger.error(f"Error generating discharge summary: {e}")
            return {"success": False, "error": str(e)}
    
    def visualize_summary(self, original_text: str, summary: str, 
                         output_path: Optional[str] = None) -> str:
        """
        Create an HTML page comparing the original text with its summary.

        The original text is shown truncated to its first 2000 characters
        (with "..." appended when longer), side by side with the summary.
        Both panels show word counts, and the summary panel also shows the
        summary/original word-count ratio. Both texts are HTML-escaped, so
        clinical text containing ``<``, ``>`` or ``&`` cannot inject markup or
        script into the page.

        Args:
            original_text: Original clinical text
            summary: Generated summary
            output_path: Optional path to save visualization. Write errors are
                logged, not raised.

        Returns:
            str: HTML representation of summary visualization
        """
        import html as html_lib  # local import; aliased so it cannot clash with local names

        max_preview_chars = 2000
        original_words = len(original_text.split())
        summary_words = len(summary.split())
        # Same convention as summarize(): an empty original gives ratio 0 instead of ZeroDivisionError
        compression_ratio = summary_words / original_words if original_words > 0 else 0

        # Truncate BEFORE escaping so an entity such as "&amp;" is never cut in half
        preview = html_lib.escape(original_text[:max_preview_chars])
        ellipsis = '...' if len(original_text) > max_preview_chars else ''
        safe_summary = html_lib.escape(summary)

        page = f"""
        <div style="font-family: Arial, sans-serif; max-width: 1000px; margin: 0 auto; padding: 20px;">
            <h2 style="color: #333; text-align: center;">Clinical Text Summarization</h2>
            
            <div style="display: flex; gap: 20px; margin-top: 30px;">
                <div style="flex: 1; padding: 15px; border: 1px solid #ddd; border-radius: 5px;">
                    <h3 style="color: #555; margin-top: 0;">Original Text</h3>
                    <p style="line-height: 1.5; white-space: pre-wrap;">{preview}{ellipsis}</p>
                    <div style="color: #777; font-size: 14px; margin-top: 10px;">
                        Word count: {original_words}
                    </div>
                </div>
                
                <div style="flex: 1; padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">
                    <h3 style="color: #555; margin-top: 0;">Summary</h3>
                    <p style="line-height: 1.5; white-space: pre-wrap;">{safe_summary}</p>
                    <div style="color: #777; font-size: 14px; margin-top: 10px;">
                        Word count: {summary_words}
                        <br>
                        Compression ratio: {compression_ratio:.2%}
                    </div>
                </div>
            </div>
        </div>
        """

        # Save to file if output path is provided (best effort: failures are only logged)
        if output_path:
            try:
                with open(output_path, "w", encoding="utf-8") as f:
                    f.write(page)
                logger.info(f"Visualization saved to {output_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {e}")

        return page

modules\user_interface\app.py

"""
User Interface Module - Main Application

This module provides the main Flask application for the intelligent medical diagnostic system,
including routes and views for both doctor and patient interfaces.
"""

import os
import json
import logging
from datetime import datetime
from pathlib import Path
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, session
from werkzeug.utils import secure_filename
import sys

# Add parent directory to path to import modules
sys.path.append(str(Path(__file__).parent.parent.parent))

from modules.diagnosis_engine.diagnosis_engine import DiagnosisEngine
from modules.user_interface.auth import auth_blueprint, login_required, is_doctor

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Directory of this module: static files, templates, config.py and uploads/ all live beside it
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Initialize Flask application
app = Flask(__name__,
            static_folder=os.path.join(_APP_DIR, "static"),
            template_folder=os.path.join(_APP_DIR, "templates"))

# Load configuration from config.py next to this module
app.config.from_pyfile(os.path.join(_APP_DIR, 'config.py'))

# Register blueprints (login/register/profile routes come from the auth module)
app.register_blueprint(auth_blueprint)

# One diagnosis engine instance, shared by every request handler in this module
diagnosis_engine = DiagnosisEngine()

# File upload settings: accepted extensions (lowercase, no dot) and the upload directory
ALLOWED_EXTENSIONS = set('png jpg jpeg gif bmp dcm dicom txt pdf csv json'.split())
UPLOAD_FOLDER = os.path.join(_APP_DIR, 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER

def allowed_file(filename, allowed_extensions=None):
    """Return True if *filename* has an allowed extension.

    The extension is the text after the last '.', compared case-insensitively.
    A name without a dot is rejected, and so is an empty or None name (e.g.
    from an upload field left blank); neither raises.

    Args:
        filename: Client-supplied file name; may be None or empty.
        allowed_extensions: Optional collection of lowercase extensions without
            the dot. Defaults to the module-level ALLOWED_EXTENSIONS.

    Returns:
        bool: Whether the upload may be accepted.
    """
    if not filename:
        return False
    extensions = ALLOWED_EXTENSIONS if allowed_extensions is None else allowed_extensions
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in extensions

@app.route('/')
def index():
    """Public landing page; no login required."""
    template_name = 'index.html'
    return render_template(template_name)

@app.route('/dashboard')
@login_required
def dashboard():
    """Redirect a logged-in user to the dashboard for their role (doctor or patient)."""
    endpoint = 'doctor_dashboard' if is_doctor() else 'patient_dashboard'
    return redirect(url_for(endpoint))

@app.route('/doctor/dashboard')
@login_required
def doctor_dashboard():
    """Doctor dashboard: recent cases plus headline statistics (demo data).

    Non-doctors get an error flash and are sent back to the generic dashboard.
    """
    if not is_doctor():
        flash('您没有权限访问医生控制面板', 'error')
        return redirect(url_for('dashboard'))

    # Demo rows; each tuple lines up with case_fields
    case_fields = ('id', 'patient_name', 'age', 'gender', 'primary_diagnosis', 'confidence', 'date')
    case_rows = (
        ('CASE001', '张三', 65, '男', '肺炎', 0.85, '2023-05-15'),
        ('CASE002', '李四', 42, '女', '冠心病', 0.78, '2023-05-14'),
        ('CASE003', '王五', 58, '男', '2型糖尿病', 0.92, '2023-05-13'),
    )
    recent_cases = [dict(zip(case_fields, row)) for row in case_rows]

    statistics = dict(
        total_cases=152,
        pending_review=8,
        completed_today=12,
        system_accuracy=0.87,
    )

    return render_template('doctor/dashboard.html',
                           recent_cases=recent_cases,
                           statistics=statistics)

@app.route('/patient/dashboard')
@login_required
def patient_dashboard():
    """Patient dashboard: recent medical records and upcoming appointments (demo data)."""
    record_fields = ('id', 'date', 'doctor', 'diagnosis', 'treatment')
    record_rows = (
        ('REC001', '2023-05-15', '张医生', '上呼吸道感染', '抗生素治疗'),
        ('REC002', '2023-04-10', '李医生', '胃炎', '质子泵抑制剂'),
    )
    medical_records = [dict(zip(record_fields, row)) for row in record_rows]

    appointment_fields = ('id', 'date', 'time', 'doctor', 'department')
    appointment_rows = (
        ('APT001', '2023-05-20', '10:30', '王医生', '内科'),
    )
    upcoming_appointments = [dict(zip(appointment_fields, row)) for row in appointment_rows]

    return render_template('patient/dashboard.html',
                           medical_records=medical_records,
                           upcoming_appointments=upcoming_appointments)

def _split_form_list(raw, sep=','):
    """Split a delimited form value into stripped, non-empty items ([] for blank/missing)."""
    if not raw:
        return []
    return [item.strip() for item in raw.split(sep) if item.strip()]


def _optional_float(raw):
    """Parse an optional numeric form field; a blank or missing value gives None.

    Raises:
        ValueError: if a non-blank value is not a valid number.
    """
    if raw is None or not raw.strip():
        return None
    return float(raw)


def _save_upload(file_storage):
    """Save an uploaded file into UPLOAD_FOLDER under a unique, sanitized name.

    Returns:
        The saved file path. Returns None when no usable file was supplied:
        no file, a disallowed extension, or a name that sanitizes to nothing.
    """
    import uuid  # local import keeps this block self-contained

    if not file_storage or not allowed_file(file_storage.filename):
        return None
    safe_name = secure_filename(file_storage.filename)
    if not safe_name:
        return None
    # Random prefix so two uploads with the same original name never overwrite each other
    file_path = os.path.join(app.config['UPLOAD_FOLDER'], f"{uuid.uuid4().hex}_{safe_name}")
    file_storage.save(file_path)
    return file_path


@app.route('/doctor/new_case', methods=['GET', 'POST'])
@login_required
def new_case():
    """Create a new case from the submitted form and generate its diagnosis report.

    GET renders the case form. POST does the following:
    - parses the patient demographics;
    - stores any uploaded image, clinical-text or lab-result file;
    - runs the diagnosis engine;
    - writes the generated report to UPLOAD_FOLDER;
    - redirects to the new case view.

    Invalid age/height/weight or an invalid report format re-renders the form
    with a flashed error instead of failing with a server error.
    """
    if not is_doctor():
        flash('您没有权限创建新病例', 'error')
        return redirect(url_for('dashboard'))

    if request.method == 'POST':
        form = request.form

        # Validate numeric fields up front: bad input is a user error, not a 500
        try:
            age = int(form.get('patient_age', ''))
            if age < 0:
                raise ValueError("negative age")
            height = _optional_float(form.get('patient_height'))
            weight = _optional_float(form.get('patient_weight'))
        except ValueError:
            flash('患者年龄、身高或体重格式不正确', 'error')
            return render_template('doctor/new_case.html')

        patient_data = {
            'demographics': {
                'name': form.get('patient_name'),
                'age': age,
                'gender': form.get('patient_gender'),
                'height': height,
                'weight': weight
            },
            # Empty fields give [] instead of [''], and items are stripped
            'medical_history': _split_form_list(form.get('medical_history', '')),
            'allergies': _split_form_list(form.get('allergies', ''))
        }

        # Process file uploads
        image_analysis = None
        text_analysis = None
        lab_results = None

        if _save_upload(request.files.get('medical_image')):
            # In a real application, this would call the image analysis module
            # For demo, we'll use mock data
            image_analysis = {
                "modality": form.get('image_modality', ''),
                "body_part": form.get('body_part', ''),
                "findings": _split_form_list(form.get('image_findings', ''), '\n'),
                "confidence": 0.85
            }

        if _save_upload(request.files.get('clinical_text')):
            # In a real application, this would call the text analysis module
            # For demo, we'll use mock data based on form input
            text_analysis = {
                "document_type": "临床病历",
                "content": form.get('clinical_notes', ''),
                "entities": [],  # Would be populated by NLP module
                "confidence": 0.9
            }

        if _save_upload(request.files.get('lab_results')):
            # In a real application, this would parse the lab results file
            # For demo, we'll use mock data
            lab_results = {
                "results": {},  # Would be populated by parsing the file
                "abnormal_results": {},
                "confidence": 0.95
            }

        try:
            # The format becomes the report's file extension, so only plain ASCII
            # alphanumerics are accepted (no path separators or dots).
            report_format = form.get('report_format', 'html')
            if not (report_format.isascii() and report_format.isalnum()):
                raise ValueError(f"Unsupported report format: {report_format!r}")

            diagnosis_results = diagnosis_engine.diagnose(
                patient_data, image_analysis, text_analysis, lab_results
            )

            report = diagnosis_engine.generate_report(
                diagnosis_results, patient_data, report_format
            )

            # Save report
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            report_filename = f"report_{timestamp}.{report_format}"
            report_path = os.path.join(app.config['UPLOAD_FOLDER'], report_filename)

            with open(report_path, 'w', encoding='utf-8') as f:
                f.write(report)

            # Store case ID in session so view_case can recognise the new case
            case_id = f"CASE{timestamp}"
            session['last_case_id'] = case_id

            flash('诊断报告已生成', 'success')
            return redirect(url_for('view_case', case_id=case_id))

        except Exception as e:
            logger.exception(f"Error generating diagnosis: {e}")
            flash(f'生成诊断报告时出错: {str(e)}', 'error')

    return render_template('doctor/new_case.html')

@app.route('/doctor/case/<case_id>')
@login_required
def view_case(case_id):
    """Show one case to a doctor (demo data).

    If case_id equals session['last_case_id'] (set by new_case), the pneumonia
    sample is shown, dated today. Any other ID shows the coronary heart disease
    sample.
    """
    if not is_doctor():
        flash('您没有权限查看病例', 'error')
        return redirect(url_for('dashboard'))

    # In a real application, this would fetch the case from a database
    if case_id != session.get('last_case_id'):
        details = {
            'patient_name': '李四',
            'age': 42,
            'gender': '女',
            'primary_diagnosis': '冠心病',
            'confidence': 0.78,
            'differential_diagnoses': [
                {'disease': '心肌炎', 'confidence': 0.55},
                {'disease': '胸壁疼痛', 'confidence': 0.40}
            ],
            'evidence': {
                '冠心病': [
                    {'item': 'ST段压低', 'type': 'ECG发现'},
                    {'item': '胸痛', 'type': '症状'},
                    {'item': '家族史', 'type': '病史'}
                ]
            },
            'date': '2023-05-14'
        }
    else:
        # The case that was just created in this session
        details = {
            'patient_name': '张三',
            'age': 65,
            'gender': '男',
            'primary_diagnosis': '肺炎',
            'confidence': 0.85,
            'differential_diagnoses': [
                {'disease': '慢性支气管炎', 'confidence': 0.65},
                {'disease': '肺结核', 'confidence': 0.45}
            ],
            'evidence': {
                '肺炎': [
                    {'item': '右肺下叶有不规则阴影', 'type': '影像发现'},
                    {'item': '体温38.2°C', 'type': '体征'},
                    {'item': '白细胞计数升高', 'type': '实验室检查'}
                ]
            },
            'date': datetime.now().strftime("%Y-%m-%d")
        }

    case = {'id': case_id, **details}
    return render_template('doctor/view_case.html', case=case)

@app.route('/doctor/patients')
@login_required
def patient_list():
    """List the doctor's patients (demo data); doctors only."""
    if not is_doctor():
        flash('您没有权限查看患者列表', 'error')
        return redirect(url_for('dashboard'))

    patient_fields = ('id', 'name', 'age', 'gender', 'last_visit', 'primary_diagnosis')
    patient_rows = (
        ('PAT001', '张三', 65, '男', '2023-05-15', '肺炎'),
        ('PAT002', '李四', 42, '女', '2023-05-14', '冠心病'),
        ('PAT003', '王五', 58, '男', '2023-05-13', '2型糖尿病'),
    )
    patients = [dict(zip(patient_fields, row)) for row in patient_rows]

    return render_template('doctor/patients.html', patients=patients)

@app.route('/patient/medical_records')
@login_required
def medical_records():
    """Show the logged-in patient's medical records (demo data)."""
    record_fields = ('id', 'date', 'doctor', 'diagnosis', 'treatment', 'notes')
    record_rows = (
        ('REC001', '2023-05-15', '张医生', '上呼吸道感染', '抗生素治疗', '症状有所改善,继续观察'),
        ('REC002', '2023-04-10', '李医生', '胃炎', '质子泵抑制剂', '建议清淡饮食,避免辛辣刺激'),
        ('REC003', '2023-02-22', '王医生', '季节性过敏', '抗组胺药物', '建议避免过敏原,保持室内通风'),
    )
    records = [dict(zip(record_fields, row)) for row in record_rows]

    return render_template('patient/medical_records.html', records=records)

@app.route('/patient/appointments')
@login_required
def appointments():
    """Show the patient's upcoming and past appointments (demo data)."""
    base_fields = ('id', 'date', 'time', 'doctor', 'department', 'reason')
    upcoming = [
        dict(zip(base_fields, ('APT001', '2023-05-20', '10:30', '王医生', '内科', '复诊'))),
    ]

    # Past appointments carry an extra completion status column
    past_fields = base_fields + ('status',)
    past_rows = (
        ('APT002', '2023-05-15', '14:00', '张医生', '呼吸科', '咳嗽、发热', '已完成'),
        ('APT003', '2023-04-10', '09:15', '李医生', '消化科', '腹痛', '已完成'),
    )
    past = [dict(zip(past_fields, row)) for row in past_rows]

    return render_template('patient/appointments.html',
                           upcoming_appointments=upcoming,
                           past_appointments=past)

@app.route('/api/diagnose', methods=['POST'])
@login_required
def api_diagnose():
    """JSON API endpoint: run the diagnosis engine on a posted payload.

    The request body must be a JSON object. Its optional keys are
    ``patient_data`` (a missing or null value becomes ``{}``),
    ``image_analysis``, ``text_analysis`` and ``lab_results``.

    Returns:
        The engine's result as JSON. Returns 400 when the request is not JSON,
        is malformed JSON, or is not a JSON object. Returns 500 when diagnosis
        raises.
    """
    if not request.is_json:
        return jsonify({'error': 'Request must be JSON'}), 400

    # silent=True: malformed JSON yields None and is answered with the JSON 400
    # below instead of Flask's HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    # Extract data from request
    patient_data = data.get('patient_data') or {}
    image_analysis = data.get('image_analysis')
    text_analysis = data.get('text_analysis')
    lab_results = data.get('lab_results')

    try:
        diagnosis_results = diagnosis_engine.diagnose(
            patient_data, image_analysis, text_analysis, lab_results
        )

        return jsonify(diagnosis_results)

    except Exception as e:
        logger.exception(f"API error: {e}")
        return jsonify({'error': str(e)}), 500

@app.errorhandler(404)
def page_not_found(e):
    """Render the custom "not found" page with HTTP status 404."""
    body = render_template('errors/404.html')
    return body, 404

@app.errorhandler(500)
def server_error(e):
    """Log the internal error, then render the custom error page with HTTP status 500."""
    logger.error(f"Server error: {e}")
    body = render_template('errors/500.html')
    return body, 500

if __name__ == '__main__':
    # The Werkzeug interactive debugger can execute arbitrary code, so it must
    # never be reachable from other machines. Debug mode and the bind address
    # are therefore opt-in via environment variables; the defaults are safe
    # for a local run.
    debug_mode = os.environ.get('FLASK_DEBUG', '0').lower() in ('1', 'true', 'yes')
    host = os.environ.get('FLASK_RUN_HOST', '127.0.0.1')
    port = int(os.environ.get('FLASK_RUN_PORT', '5000'))
    app.run(debug=debug_mode, host=host, port=port)

modules\user_interface\auth.py

"""
User Interface Module - Authentication

This module provides authentication functionality for the intelligent medical diagnostic system,
including user registration, login, and role-based access control.
"""

import os
import json
import logging
import functools
from datetime import datetime, timedelta
from flask import Blueprint, render_template, request, redirect, url_for, flash, session, g
from werkzeug.security import generate_password_hash, check_password_hash

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Initialize blueprint (registered by the Flask app; endpoints are addressed as 'auth.<view>')
auth_blueprint = Blueprint('auth', __name__)

# Mock user database: a JSON file in a data/ directory next to this module
# (in a real application, this would be a database)
_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')
USERS_FILE = os.path.join(_DATA_DIR, 'users.json')
os.makedirs(_DATA_DIR, exist_ok=True)

def _default_users():
    """Return the demo seed accounts (two doctors, two patients), keyed by username.

    NOTE(review): every seed account uses the well-known password
    "password123"; change or remove them before any real deployment.
    """
    return {
        "doctor1": {
            "username": "doctor1",
            "password": generate_password_hash("password123"),
            "name": "张医生",
            "role": "doctor",
            "department": "内科",
            "title": "主任医师",
            "email": "doctor1@example.com"
        },
        "doctor2": {
            "username": "doctor2",
            "password": generate_password_hash("password123"),
            "name": "李医生",
            "role": "doctor",
            "department": "外科",
            "title": "副主任医师",
            "email": "doctor2@example.com"
        },
        "patient1": {
            "username": "patient1",
            "password": generate_password_hash("password123"),
            "name": "张三",
            "role": "patient",
            "age": 65,
            "gender": "男",
            "email": "patient1@example.com"
        },
        "patient2": {
            "username": "patient2",
            "password": generate_password_hash("password123"),
            "name": "李四",
            "role": "patient",
            "age": 42,
            "gender": "女",
            "email": "patient2@example.com"
        }
    }


def get_users():
    """Load the user table (username -> user record) from USERS_FILE.

    Returns:
        dict: The stored users. The demo seed accounts are returned only when
        USERS_FILE does not exist yet. An empty dict is returned (fail closed)
        when the file exists but cannot be read, is not valid JSON, or does not
        hold a JSON object. Falling back to the seed accounts in that case
        would silently re-enable their known password.
    """
    if not os.path.exists(USERS_FILE):
        return _default_users()

    try:
        with open(USERS_FILE, 'r', encoding='utf-8') as f:
            users = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # NOTE(review): a later save_users() call (e.g. from register) would
        # replace the unreadable file; keep a backup if the data matters.
        logger.error(f"Error loading users: {e}")
        return {}

    if not isinstance(users, dict):
        logger.error(f"Error loading users: expected a JSON object in {USERS_FILE}, "
                     f"got {type(users).__name__}")
        return {}
    return users

def save_users(users, path=None):
    """Persist the user table as JSON, atomically.

    The data is written to a temporary file in the target's directory and
    then moved over the target with os.replace. A crash or serialization
    error mid-write therefore never leaves a truncated users file behind.
    Because the temporary file comes from tempfile.mkstemp, the resulting
    file is readable and writable only by its owner.

    Args:
        users: Mapping of username -> user record; must be JSON-serializable.
        path: Target file; defaults to USERS_FILE.

    Returns:
        bool: True on success, False if the write failed (the error is logged).
    """
    import tempfile  # local import keeps this block self-contained

    target = USERS_FILE if path is None else path
    directory = os.path.dirname(os.path.abspath(target))
    tmp_path = None
    try:
        os.makedirs(directory, exist_ok=True)
        fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.users-', suffix='.tmp')
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(users, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, target)
        tmp_path = None  # now owned by the target path; nothing left to clean up
        return True
    except Exception as e:
        logger.error(f"Error saving users: {e}")
        return False
    finally:
        # Best-effort removal of a partial temp file if anything failed before the replace
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass

def _ensure_users_file():
    """Write the user table to USERS_FILE if it is missing (on first start, the seed accounts)."""
    if os.path.exists(USERS_FILE):
        return
    save_users(get_users())


# Initialize users file if it doesn't exist (runs once, at import time)
_ensure_users_file()

def login_required(view):
    """Decorator that redirects to the login page unless a user is logged in.

    "Logged in" means the session contains a 'user_id'. Otherwise an error is
    flashed and the request is redirected to 'auth.login'. Positional and
    keyword arguments (e.g. URL converters) are forwarded to the view unchanged.
    """
    @functools.wraps(view)
    def wrapped_view(*args, **kwargs):
        if 'user_id' not in session:
            flash('请先登录', 'error')
            return redirect(url_for('auth.login'))
        return view(*args, **kwargs)
    return wrapped_view

def is_doctor():
    """Return True only if the session's user exists in the user table with role 'doctor'."""
    if 'user_id' in session:
        record = get_users().get(session['user_id'])
        return bool(record) and record.get('role') == 'doctor'
    return False

@auth_blueprint.before_app_request
def load_logged_in_user():
    """Before every request, set g.user to the logged-in user's record.

    Runs app-wide. g.user is None when nobody is logged in or the user is unknown.
    """
    user_id = session.get('user_id')
    g.user = None if user_id is None else get_users().get(user_id)

@auth_blueprint.route('/login', methods=['GET', 'POST'])
def login():
    """Handle user login.

    GET renders the login form. POST checks the submitted password against
    the stored hash. On success the session is reset and bound to the
    username, and the user is redirected to their role's dashboard. On
    failure a single generic error is flashed and the form is shown again.
    """
    if request.method == 'POST':
        # Default to '' so a missing field never reaches check_password_hash as None
        username = request.form.get('username', '')
        password = request.form.get('password', '')

        user = get_users().get(username)

        # One message for both "unknown user" and "wrong password", so the
        # form cannot be used to discover which usernames exist.
        if user is not None and check_password_hash(user.get('password', ''), password):
            # Clear session so no pre-login data carries over, then bind the user
            session.clear()
            session['user_id'] = username

            logger.info(f"User {username} logged in")

            # Redirect based on role
            if user.get('role') == 'doctor':
                return redirect(url_for('doctor_dashboard'))
            return redirect(url_for('patient_dashboard'))

        flash('用户名或密码不正确', 'error')

    return render_template('auth/login.html')

@auth_blueprint.route('/register', methods=['GET', 'POST'])
def register():
    """Handle user registration.

    GET renders the form. POST validates the submission and creates the
    account on success. Validation requires a username and a password, a
    matching confirmation, an unused username, a role of 'doctor' or
    'patient', and (for patients) a non-negative integer age, where blank
    means 0. The new account is then saved and the user is redirected to the
    login page. The first validation error is flashed and the form re-rendered.
    """
    if request.method == 'POST':
        form = request.form
        username = form.get('username')
        password = form.get('password')
        confirm_password = form.get('confirm_password')
        name = form.get('name')
        email = form.get('email')
        role = form.get('role')

        users = get_users()
        error = None
        age = 0
        age_raw = (form.get('age') or '').strip()

        if not username:
            error = '用户名不能为空'
        elif not password:
            error = '密码不能为空'
        elif password != confirm_password:
            error = '两次输入的密码不一致'
        elif username in users:
            error = f'用户名 {username} 已被注册'
        elif role not in ('doctor', 'patient'):
            # NOTE(review): anyone can still self-register as 'doctor' and gain
            # access to doctor-only pages; consider an approval step.
            error = '无效的用户角色'
        elif role == 'patient' and age_raw:
            # isdecimal() rejects signs, decimals and garbage, so int() cannot fail
            if age_raw.isdecimal():
                age = int(age_raw)
            else:
                error = '年龄格式不正确'

        if error is None:
            # Create new user
            new_user = {
                'username': username,
                'password': generate_password_hash(password),
                'name': name,
                'email': email,
                'role': role,
                'created_at': datetime.now().isoformat()
            }

            # Add role-specific fields
            if role == 'doctor':
                new_user['department'] = form.get('department', '')
                new_user['title'] = form.get('title', '')
            else:
                new_user['age'] = age
                new_user['gender'] = form.get('gender', '')

            # Save user
            users[username] = new_user
            if save_users(users):
                logger.info(f"User {username} registered")
                flash('注册成功,请登录', 'success')
                return redirect(url_for('auth.login'))
            error = '注册失败,请稍后再试'

        flash(error, 'error')

    return render_template('auth/register.html')

@auth_blueprint.route('/logout')
def logout():
    """Log out by wiping the entire session, then return to the home page."""
    session.clear()
    flash('您已成功退出登录', 'success')
    home_url = url_for('index')
    return redirect(home_url)

@auth_blueprint.route('/profile')
@login_required
def profile():
    """Render the logged-in user's profile page."""
    template_name = 'auth/profile.html'
    return render_template(template_name)

@auth_blueprint.route('/profile/edit', methods=['GET', 'POST'])
@login_required
def edit_profile():
    """Edit the logged-in user's profile.

    GET renders the edit form. POST first validates the submission: new
    passwords must match, and a patient's age must be a non-negative integer
    (blank keeps the current value). Nothing is changed if validation fails.
    It then updates name/email, the role-specific fields and, optionally, the
    password hash, and saves the user table.
    """
    if request.method == 'POST':
        users = get_users()
        user = users.get(session['user_id'])

        if not user:
            flash('用户不存在', 'error')
            return redirect(url_for('auth.profile'))

        form = request.form
        role = user.get('role')
        new_password = form.get('new_password')
        confirm_password = form.get('confirm_password')

        # Validate everything before mutating the record
        if new_password and new_password != confirm_password:
            flash('两次输入的密码不一致', 'error')
            return redirect(url_for('auth.edit_profile'))

        age = user.get('age', 0)
        if role == 'patient':
            age_raw = (form.get('age') or '').strip()
            if age_raw:
                # isdecimal() rejects signs, decimals and garbage, so int() cannot fail
                if not age_raw.isdecimal():
                    flash('年龄格式不正确', 'error')
                    return redirect(url_for('auth.edit_profile'))
                age = int(age_raw)

        # Update user information (.get: older records may lack these keys)
        user['name'] = form.get('name', user.get('name'))
        user['email'] = form.get('email', user.get('email'))

        # Update role-specific fields
        if role == 'doctor':
            user['department'] = form.get('department', user.get('department', ''))
            user['title'] = form.get('title', user.get('title', ''))
        elif role == 'patient':
            user['age'] = age
            user['gender'] = form.get('gender', user.get('gender', ''))

        if new_password:
            user['password'] = generate_password_hash(new_password)

        # Save updated user information
        if save_users(users):
            flash('个人信息已更新', 'success')
        else:
            flash('更新失败,请稍后再试', 'error')

        return redirect(url_for('auth.profile'))

    return render_template('auth/edit_profile.html')

@auth_blueprint.route('/reset_password', methods=['GET', 'POST'])
def reset_password():
    """Reset a user's password after matching username and email (demo flow).

    On POST, if ``username`` names an existing user whose stored email equals
    the submitted, non-empty ``email``, the password is replaced with a freshly
    generated random temporary password. That password is shown in a flash
    message, and the browser is redirected to the login page. On GET, a
    mismatch, or a failed save, the reset form is rendered (with an error
    flash in the failure cases).
    """
    if request.method == 'POST':
        # Local import keeps this view self-contained.
        import secrets

        username = request.form.get('username')
        email = request.form.get('email')

        users = get_users()
        user = users.get(username)

        # Require a non-empty email. An account stored with a blank/None email
        # would otherwise match a blank or missing form field.
        if user and email and user.get('email') == email:
            # NOTE(review): demo only. A real deployment should email a
            # single-use reset link instead of revealing a password on screen.
            # The temporary password is random; it used to be the fixed,
            # publicly known string 'resetpassword123'.
            new_password = secrets.token_urlsafe(9)
            user['password'] = generate_password_hash(new_password)

            if save_users(users):
                flash(f'密码已重置为: {new_password},请登录后立即修改密码', 'success')
                return redirect(url_for('auth.login'))
            else:
                flash('密码重置失败,请稍后再试', 'error')
        else:
            flash('用户名或邮箱不正确', 'error')

    return render_template('auth/reset_password.html')

modules\user_interface\config.py

"""
User Interface Module - Configuration

This module provides configuration settings for the Flask application.
"""

import os
import secrets
from datetime import timedelta

# Application settings
# Flask signs session cookies with SECRET_KEY. A key generated at import time
# changes on every restart and differs between worker processes, which
# silently invalidates sessions. Prefer a stable key from the environment and
# fall back to a random per-process key (the previous behaviour) otherwise.
SECRET_KEY = os.environ.get('SECRET_KEY') or secrets.token_hex(16)
SESSION_TYPE = 'filesystem'
PERMANENT_SESSION_LIFETIME = timedelta(days=1)
MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB max upload size

# Directory holding this file and the project root two levels above it.
# Configured paths are anchored here rather than on the working directory.
# Lower-case private names: if this module is loaded with Flask's
# config.from_object(), only upper-case names are copied into app.config.
_module_dir = os.path.dirname(os.path.abspath(__file__))
_project_root = os.path.dirname(os.path.dirname(_module_dir))

# Upload settings
UPLOAD_FOLDER = os.path.join(_module_dir, 'uploads')
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'dcm', 'dicom', 'txt', 'pdf', 'csv', 'json'}

# Diagnosis engine settings
DIAGNOSIS_ENGINE_CONFIG = os.path.join(_project_root, 'config', 'diagnosis_engine_config.json')

# Debug settings
# Defaults to on (previous behaviour). Set FLASK_DEBUG=0/false/no/off for any
# non-development deployment: Flask's interactive debugger allows arbitrary
# code execution.
DEBUG = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in {'0', 'false', 'no', 'off'}

modules\user_interface\templates\index.html

{% extends "layout.html" %}
{#
    Landing page. Logged-out visitors see login/register buttons; logged-in
    users see shortcuts chosen by g.user.role ('doctor' vs. everyone else).
    Decorative Font Awesome icons carry aria-hidden="true" so screen readers
    skip them; the adjacent headings already convey their meaning.
#}

{% block title %}首页 - 智能医疗辅助诊断系统{% endblock %}

{% block content %}
<div class="row align-items-center">
    <div class="col-lg-6">
        <h1 class="display-4 fw-bold mb-4">智能医疗辅助诊断系统</h1>
        <p class="lead mb-4">
            结合人工智能和医学专业知识,为医生提供辅助诊断支持,提高诊断准确率和医疗效率。
        </p>
        <div class="d-grid gap-2 d-md-flex">
            {% if not g.user %}
                <a href="{{ url_for('auth.login') }}" class="btn btn-primary btn-lg px-4 me-md-2">登录系统</a>
                <a href="{{ url_for('auth.register') }}" class="btn btn-outline-secondary btn-lg px-4">注册账号</a>
            {% else %}
                {% if g.user.role == 'doctor' %}
                    <a href="{{ url_for('doctor_dashboard') }}" class="btn btn-primary btn-lg px-4 me-md-2">进入控制面板</a>
                    <a href="{{ url_for('new_case') }}" class="btn btn-outline-secondary btn-lg px-4">新建病例</a>
                {% else %}
                    <a href="{{ url_for('patient_dashboard') }}" class="btn btn-primary btn-lg px-4 me-md-2">进入我的主页</a>
                    <a href="{{ url_for('medical_records') }}" class="btn btn-outline-secondary btn-lg px-4">查看病历记录</a>
                {% endif %}
            {% endif %}
        </div>
    </div>
    <div class="col-lg-6 d-none d-lg-block">
        <img src="{{ url_for('static', filename='img/medical_ai.jpg') }}" class="img-fluid rounded shadow" alt="医疗AI">
    </div>
</div>

<hr class="my-5">

<div class="row g-4 py-4">
    <div class="col-md-4">
        <div class="card h-100 shadow-sm">
            <div class="card-body text-center">
                <i class="fas fa-brain fa-3x text-primary mb-3" aria-hidden="true"></i>
                <h3 class="card-title">AI辅助诊断</h3>
                <p class="card-text">利用先进的机器学习算法分析医学影像和临床数据,为医生提供诊断建议。</p>
            </div>
        </div>
    </div>
    <div class="col-md-4">
        <div class="card h-100 shadow-sm">
            <div class="card-body text-center">
                <i class="fas fa-x-ray fa-3x text-primary mb-3" aria-hidden="true"></i>
                <h3 class="card-title">医学影像分析</h3>
                <p class="card-text">自动分析X光、CT和MRI等医学影像,检测异常并提供详细的分析报告。</p>
            </div>
        </div>
    </div>
    <div class="col-md-4">
        <div class="card h-100 shadow-sm">
            <div class="card-body text-center">
                <i class="fas fa-file-medical-alt fa-3x text-primary mb-3" aria-hidden="true"></i>
                <h3 class="card-title">临床文本分析</h3>
                <p class="card-text">从病历、检查报告等临床文本中提取关键信息,辅助医生进行诊断决策。</p>
            </div>
        </div>
    </div>
</div>

<div class="row mt-5">
    <div class="col-12">
        <div class="card bg-light">
            <div class="card-body p-4">
                <h2 class="text-center mb-4">系统特点</h2>
                <div class="row g-4">
                    <div class="col-md-6">
                        <div class="d-flex">
                            <div class="flex-shrink-0">
                                <i class="fas fa-check-circle text-success me-2" aria-hidden="true"></i>
                            </div>
                            <div>
                                <h5>多模态数据整合</h5>
                                <p>结合医学影像、临床文本和实验室检查结果,提供全面的诊断分析。</p>
                            </div>
                        </div>
                    </div>
                    <div class="col-md-6">
                        <div class="d-flex">
                            <div class="flex-shrink-0">
                                <i class="fas fa-check-circle text-success me-2" aria-hidden="true"></i>
                            </div>
                            <div>
                                <h5>知识图谱推理</h5>
                                <p>基于医学知识图谱的推理系统,提供可解释的诊断建议。</p>
                            </div>
                        </div>
                    </div>
                    <div class="col-md-6">
                        <div class="d-flex">
                            <div class="flex-shrink-0">
                                <i class="fas fa-check-circle text-success me-2" aria-hidden="true"></i>
                            </div>
                            <div>
                                <h5>深度学习模型</h5>
                                <p>采用先进的深度学习模型进行医学影像分割和分类。</p>
                            </div>
                        </div>
                    </div>
                    <div class="col-md-6">
                        <div class="d-flex">
                            <div class="flex-shrink-0">
                                <i class="fas fa-check-circle text-success me-2" aria-hidden="true"></i>
                            </div>
                            <div>
                                <h5>用户友好界面</h5>
                                <p>直观易用的界面设计,适合医生和患者使用。</p>
                            </div>
                        </div>
                    </div>
                </div>
            </div>
        </div>
    </div>
</div>
{% endblock %}

modules\user_interface\templates\layout.html

<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{% block title %}智能医疗辅助诊断系统{% endblock %}</title>
    {#
        Base layout shared by every page. Child templates may override:
          title   - text of the <title> element
          styles  - extra stylesheets inside <head>
          content - the main page body
          scripts - extra scripts loaded after the shared JS
        Bootstrap is pinned to a stable release (previously the pre-release
        5.3.0-alpha1); keep the CSS and JS bundle versions in sync.
    #}
    <!-- Bootstrap CSS -->
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/css/bootstrap.min.css" rel="stylesheet">
    <!-- Font Awesome -->
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
    <!-- Custom CSS -->
    <link rel="stylesheet" href="{{ url_for('static', filename='css/style.css') }}">
    {% block styles %}{% endblock %}
</head>
{# Flex column + full viewport height so the footer's mt-auto actually pins it to the bottom on short pages. #}
<body class="d-flex flex-column min-vh-100">
    <!-- Navigation -->
    <nav class="navbar navbar-expand-lg navbar-dark bg-primary">
        <div class="container">
            <a class="navbar-brand" href="{{ url_for('index') }}">
                <i class="fas fa-heartbeat me-2" aria-hidden="true"></i>智能医疗辅助诊断系统
            </a>
            <button class="navbar-toggler" type="button" data-bs-toggle="collapse" data-bs-target="#navbarNav"
                    aria-controls="navbarNav" aria-expanded="false" aria-label="切换导航">
                <span class="navbar-toggler-icon"></span>
            </button>
            <div class="collapse navbar-collapse" id="navbarNav">
                <ul class="navbar-nav me-auto">
                    <li class="nav-item">
                        <a class="nav-link" href="{{ url_for('index') }}">首页</a>
                    </li>
                    {% if g.user %}
                        {% if g.user.role == 'doctor' %}
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('doctor_dashboard') }}">控制面板</a>
                            </li>
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('new_case') }}">新建病例</a>
                            </li>
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('patient_list') }}">患者管理</a>
                            </li>
                        {% else %}
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('patient_dashboard') }}">我的主页</a>
                            </li>
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('medical_records') }}">病历记录</a>
                            </li>
                            <li class="nav-item">
                                <a class="nav-link" href="{{ url_for('appointments') }}">预约管理</a>
                            </li>
                        {% endif %}
                    {% endif %}
                </ul>
                <ul class="navbar-nav">
                    {% if g.user %}
                        <li class="nav-item dropdown">
                            <a class="nav-link dropdown-toggle" href="#" id="userDropdown" role="button" data-bs-toggle="dropdown" aria-expanded="false">
                                <i class="fas fa-user-circle me-1" aria-hidden="true"></i>{{ g.user.name }}
                            </a>
                            <ul class="dropdown-menu dropdown-menu-end" aria-labelledby="userDropdown">
                                <li><a class="dropdown-item" href="{{ url_for('auth.profile') }}">个人信息</a></li>
                                <li><hr class="dropdown-divider"></li>
                                <li><a class="dropdown-item" href="{{ url_for('auth.logout') }}">退出登录</a></li>
                            </ul>
                        </li>
                    {% else %}
                        <li class="nav-item">
                            <a class="nav-link" href="{{ url_for('auth.login') }}">登录</a>
                        </li>
                        <li class="nav-item">
                            <a class="nav-link" href="{{ url_for('auth.register') }}">注册</a>
                        </li>
                    {% endif %}
                </ul>
            </div>
        </div>
    </nav>

    <!-- Flash Messages -->
    {% with messages = get_flashed_messages(with_categories=true) %}
        {% if messages %}
            <div class="container mt-3">
                {% for category, message in messages %}
                    {# Map flash categories to Bootstrap alert variants: 'error' -> 'danger',
                       Flask's default category 'message' -> 'info'; others pass through. #}
                    {% set alert_variant = {'error': 'danger', 'message': 'info'}.get(category, category) %}
                    <div class="alert alert-{{ alert_variant }} alert-dismissible fade show" role="alert">
                        {{ message }}
                        <button type="button" class="btn-close" data-bs-dismiss="alert" aria-label="关闭"></button>
                    </div>
                {% endfor %}
            </div>
        {% endif %}
    {% endwith %}

    <!-- Main Content -->
    <main class="container py-4">
        {% block content %}{% endblock %}
    </main>

    <!-- Footer -->
    <footer class="bg-light py-4 mt-auto">
        <div class="container">
            <div class="row">
                <div class="col-md-6">
                    <p class="mb-0">&copy; 2023 智能医疗辅助诊断系统</p>
                </div>
                <div class="col-md-6 text-md-end">
                    <a href="#" class="text-decoration-none me-3">关于我们</a>
                    <a href="#" class="text-decoration-none me-3">隐私政策</a>
                    <a href="#" class="text-decoration-none">联系我们</a>
                </div>
            </div>
        </div>
    </footer>

    <!-- Bootstrap JS (version must match the CSS above) -->
    <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.3/dist/js/bootstrap.bundle.min.js"></script>
    <!-- jQuery -->
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <!-- Custom JS -->
    <script src="{{ url_for('static', filename='js/main.js') }}"></script>
    {% block scripts %}{% endblock %}
</body>
</html>

modules\user_interface\templates\auth\login.html

{% extends "layout.html" %}
{#
    Login page: posts username/password/remember to auth.login and links to
    password reset and registration. The demo-account card is rendered only
    when config.DEBUG is true, so demo credentials are never advertised on a
    production login page.
#}

{% block title %}登录 - 智能医疗辅助诊断系统{% endblock %}

{% block content %}
<div class="row justify-content-center">
    <div class="col-md-6 col-lg-5">
        <div class="card shadow">
            <div class="card-header bg-primary text-white text-center py-3">
                <h4 class="mb-0"><i class="fas fa-sign-in-alt me-2" aria-hidden="true"></i>用户登录</h4>
            </div>
            <div class="card-body p-4">
                <form method="post" action="{{ url_for('auth.login') }}">
                    <div class="mb-3">
                        <label for="username" class="form-label">用户名</label>
                        <div class="input-group">
                            <span class="input-group-text"><i class="fas fa-user" aria-hidden="true"></i></span>
                            <input type="text" class="form-control" id="username" name="username"
                                   autocomplete="username" autofocus required>
                        </div>
                    </div>
                    <div class="mb-3">
                        <label for="password" class="form-label">密码</label>
                        <div class="input-group">
                            <span class="input-group-text"><i class="fas fa-lock" aria-hidden="true"></i></span>
                            <input type="password" class="form-control" id="password" name="password"
                                   autocomplete="current-password" required>
                        </div>
                    </div>
                    <div class="mb-3 form-check">
                        <input type="checkbox" class="form-check-input" id="remember" name="remember">
                        <label class="form-check-label" for="remember">记住我</label>
                    </div>
                    <div class="d-grid gap-2">
                        <button type="submit" class="btn btn-primary btn-lg">登录</button>
                    </div>
                </form>
            </div>
            <div class="card-footer bg-light text-center py-3">
                <div class="mb-2">
                    <a href="{{ url_for('auth.reset_password') }}" class="text-decoration-none">忘记密码?</a>
                </div>
                <div>
                    还没有账号? <a href="{{ url_for('auth.register') }}" class="text-decoration-none">立即注册</a>
                </div>
            </div>
        </div>

        {% if config.DEBUG %}
        <!-- Demo account information (debug mode only) -->
        <div class="card mt-4 shadow-sm">
            <div class="card-header bg-light">
                <h5 class="mb-0">演示账号</h5>
            </div>
            <div class="card-body">
                <div class="row">
                    <div class="col-md-6">
                        <h6>医生账号</h6>
                        <p class="mb-1"><strong>用户名:</strong> doctor1</p>
                        <p><strong>密码:</strong> password123</p>
                    </div>
                    <div class="col-md-6">
                        <h6>患者账号</h6>
                        <p class="mb-1"><strong>用户名:</strong> patient1</p>
                        <p><strong>密码:</strong> password123</p>
                    </div>
                </div>
            </div>
        </div>
        {% endif %}
    </div>
</div>
{% endblock %}

modules\user_interface\templates\auth\register.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天天进步2015

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值