"""
Medical Named Entity Recognition Module
This module provides functionality for extracting medical entities from clinical text,
such as diseases, symptoms, medications, and anatomical structures.
"""
import os
import json
import logging
from typing import List, Dict, Any, Optional, Union, Tuple
import numpy as np
import spacy
from spacy.tokens import Doc, Span
from spacy.training import Example
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("medical_ner.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class MedicalNamedEntityRecognizer:
"""
Class for recognizing medical named entities in clinical text
"""
def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
"""
Initialize the medical named entity recognizer
Args:
config_path: Path to the configuration file
model_path: Path to a pre-trained model
"""
self.config = self._load_config(config_path)
self.entity_types = self.config["ner"]["entity_types"]
self.confidence_threshold = self.config["ner"]["confidence_threshold"]
# Initialize models
self.spacy_model = None
self.transformer_model = None
self.tokenizer = None
self.ner_pipeline = None
# Load models based on configuration
self._load_models(model_path)
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file
Args:
config_path: Path to the configuration file
Returns:
Dict[str, Any]: Configuration dictionary
"""
default_config = {
"ner": {
"model": "medical",
"confidence_threshold": 0.7,
"entity_types": [
"DISEASE", "SYMPTOM", "TREATMENT", "MEDICATION", "ANATOMY",
"PROCEDURE", "TEST", "DOSAGE", "FREQUENCY", "DURATION"
]
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration from {config_path}: {e}")
return default_config
else:
logger.info("Using default configuration")
return default_config
def _load_models(self, model_path: Optional[str]) -> None:
"""
Load NER models
Args:
model_path: Path to a pre-trained model
"""
# Try to load spaCy model
try:
self.spacy_model = spacy.load("en_core_sci_md")
logger.info("Loaded spaCy scientific model: en_core_sci_md")
except OSError:
try:
self.spacy_model = spacy.load("en_core_web_sm")
logger.info("Loaded spaCy model: en_core_web_sm")
except OSError:
logger.warning("spaCy model not found. Using spaCy's blank model.")
self.spacy_model = spacy.blank('en')
# Load transformer model if available
if model_path and os.path.exists(model_path):
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.transformer_model = AutoModelForTokenClassification.from_pretrained(model_path)
self.ner_pipeline = pipeline(
"ner",
model=self.transformer_model,
tokenizer=self.tokenizer,
aggregation_strategy="simple"
)
logger.info(f"Loaded transformer model from {model_path}")
except Exception as e:
logger.error(f"Error loading transformer model from {model_path}: {e}")
else:
# Use a pre-trained medical NER model from Hugging Face
try:
model_name = "samrawal/bert-base-uncased_clinical-ner"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.transformer_model = AutoModelForTokenClassification.from_pretrained(model_name)
self.ner_pipeline = pipeline(
"ner",
model=self.transformer_model,
tokenizer=self.tokenizer,
aggregation_strategy="simple"
)
logger.info(f"Loaded pre-trained transformer model: {model_name}")
except Exception as e:
logger.error(f"Error loading pre-trained transformer model: {e}")
def extract_entities_spacy(self, text: str) -> List[Dict[str, Any]]:
"""
Extract medical entities using spaCy
Args:
text: Clinical text
Returns:
List[Dict[str, Any]]: List of extracted entities
"""
if not self.spacy_model:
logger.error("spaCy model not available")
return []
doc = self.spacy_model(text)
entities = []
for ent in doc.ents:
# Filter by entity type if needed
entity = {
"text": ent.text,
"label": ent.label_,
"start": ent.start_char,
"end": ent.end_char,
"source": "spacy"
}
entities.append(entity)
return entities
def extract_entities_transformer(self, text: str) -> List[Dict[str, Any]]:
"""
Extract medical entities using transformer model
Args:
text: Clinical text
Returns:
List[Dict[str, Any]]: List of extracted entities
"""
if not self.ner_pipeline:
logger.error("Transformer model not available")
return []
try:
# Process text with transformer NER pipeline
ner_results = self.ner_pipeline(text)
entities = []
for result in ner_results:
# Filter by confidence threshold
if result["score"] >= self.confidence_threshold:
entity = {
"text": result["word"],
"label": result["entity"],
"start": result["start"],
"end": result["end"],
"confidence": result["score"],
"source": "transformer"
}
entities.append(entity)
return entities
except Exception as e:
logger.error(f"Error in transformer entity extraction: {e}")
return []
def extract_entities(self, text: str, method: str = "hybrid") -> List[Dict[str, Any]]:
"""
Extract medical entities from clinical text
Args:
text: Clinical text
method: Extraction method ('spacy', 'transformer', or 'hybrid')
Returns:
List[Dict[str, Any]]: List of extracted entities
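
        Example (illustrative; the exact labels depend on which backing models
        were successfully loaded):
            >>> recognizer = MedicalNamedEntityRecognizer()
            >>> ents = recognizer.extract_entities(
            ...     "Metformin 500 mg was started for type 2 diabetes.")
            >>> [(e["text"], e["label"]) for e in ents]  # doctest: +SKIP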
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return []
if method == "spacy":
return self.extract_entities_spacy(text)
elif method == "transformer":
return self.extract_entities_transformer(text)
elif method == "hybrid":
# Get entities from both methods
spacy_entities = self.extract_entities_spacy(text)
transformer_entities = self.extract_entities_transformer(text)
# Combine and deduplicate entities
all_entities = spacy_entities + transformer_entities
# Simple deduplication based on text and label
deduplicated = {}
for entity in all_entities:
key = f"{entity['text']}_{entity['label']}"
if key not in deduplicated or (
'confidence' in entity and
('confidence' not in deduplicated[key] or
entity['confidence'] > deduplicated[key]['confidence'])
):
deduplicated[key] = entity
return list(deduplicated.values())
else:
logger.error(f"Unknown extraction method: {method}")
return []
def extract_medical_concepts(self, text: str) -> Dict[str, List[Dict[str, Any]]]:
"""
Extract medical concepts categorized by type
Args:
text: Clinical text
Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary of entities by type
"""
# Extract all entities
entities = self.extract_entities(text)
# Group by entity type
grouped_entities = {}
for entity in entities:
entity_type = entity["label"]
if entity_type not in grouped_entities:
grouped_entities[entity_type] = []
grouped_entities[entity_type].append(entity)
return grouped_entities
def extract_diseases_and_symptoms(self, text: str) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""
Extract diseases and symptoms from clinical text
Args:
text: Clinical text
Returns:
Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: Diseases and symptoms
"""
entities = self.extract_entities(text)
diseases = []
symptoms = []
for entity in entities:
if entity["label"] in ["DISEASE", "PROBLEM"]:
diseases.append(entity)
elif entity["label"] in ["SYMPTOM", "FINDING"]:
symptoms.append(entity)
return diseases, symptoms
def extract_medications(self, text: str) -> List[Dict[str, Any]]:
"""
Extract medications from clinical text
Args:
text: Clinical text
Returns:
List[Dict[str, Any]]: List of medications
"""
entities = self.extract_entities(text)
medications = []
for entity in entities:
if entity["label"] in ["MEDICATION", "DRUG"]:
medications.append(entity)
return medications
def extract_treatments(self, text: str) -> List[Dict[str, Any]]:
"""
Extract treatments from clinical text
Args:
text: Clinical text
Returns:
List[Dict[str, Any]]: List of treatments
"""
entities = self.extract_entities(text)
treatments = []
for entity in entities:
if entity["label"] in ["TREATMENT", "PROCEDURE"]:
treatments.append(entity)
return treatments
def visualize_entities(self, text: str, entities: List[Dict[str, Any]]) -> str:
"""
Create a visualization of entities in text
Args:
text: Original text
entities: List of entities
Returns:
str: HTML representation of text with highlighted entities
"""
if not text or not entities:
return text
# Sort entities by start position (descending) to avoid index issues when inserting HTML tags
sorted_entities = sorted(entities, key=lambda x: x["start"], reverse=True)
# Create a copy of the text to modify
html_text = text
# Define colors for different entity types
colors = {
"DISEASE": "#ff9999",
"SYMPTOM": "#99ff99",
"MEDICATION": "#9999ff",
"TREATMENT": "#ffff99",
"ANATOMY": "#ff99ff",
"PROCEDURE": "#99ffff",
"TEST": "#ffcc99",
"DOSAGE": "#ccff99",
"FREQUENCY": "#cc99ff",
"DURATION": "#99ccff",
"PROBLEM": "#ff9999",
"FINDING": "#99ff99",
"DRUG": "#9999ff"
}
# Insert HTML tags for each entity
for entity in sorted_entities:
start = entity["start"]
end = entity["end"]
label = entity["label"]
# Get color for entity type, default to gray
color = colors.get(label, "#cccccc")
# Create span with background color and tooltip
span_html = f'<span style="background-color: {color};" title="{label}">{html_text[start:end]}</span>'
# Replace the entity text with the HTML span
html_text = html_text[:start] + span_html + html_text[end:]
# Wrap in a div for styling
html_text = f'<div style="font-family: Arial, sans-serif; line-height: 1.5;">{html_text}</div>'
return html_text
def train_custom_model(self, training_data: List[Dict[str, Any]],
output_path: str, epochs: int = 10) -> bool:
"""
Train a custom NER model on medical data
Args:
training_data: List of training examples
output_path: Path to save the trained model
epochs: Number of training epochs
Returns:
bool: Success status
"""
# This is a simplified implementation
# In a real system, this would be more complex
try:
# Create a blank spaCy model
nlp = spacy.blank("en")
# Add NER component
ner = nlp.add_pipe("ner")
# Add entity labels
for example in training_data:
for entity in example["entities"]:
ner.add_label(entity["label"])
# Prepare training examples
train_examples = []
for example in training_data:
text = example["text"]
entities = example["entities"]
                # Create Doc object from the raw text
                doc = nlp.make_doc(text)
                # Example.from_dict expects entity annotations as
                # (start_char, end_char, label) tuples, so convert the dicts here
                annotations = {
                    "entities": [(ent["start"], ent["end"], ent["label"]) for ent in entities]
                }
                train_examples.append(Example.from_dict(doc, annotations))
            # Train the model
            optimizer = nlp.initialize(get_examples=lambda: train_examples)
            for _ in range(epochs):
                losses = {}
                for example in train_examples:
                    nlp.update([example], drop=0.5, sgd=optimizer, losses=losses)
# Save the model
nlp.to_disk(output_path)
logger.info(f"Custom NER model saved to {output_path}")
return True
except Exception as e:
logger.error(f"Error training custom NER model: {e}")
return False
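

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes the spaCy and/or
# Hugging Face models referenced above were downloaded successfully; the
# sample note text below is invented for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    recognizer = MedicalNamedEntityRecognizer()
    note = (
        "Patient presents with chest pain and shortness of breath. "
        "Started aspirin 81 mg daily for suspected angina."
    )
    for entity in recognizer.extract_entities(note, method="hybrid"):
        print(f"{entity['label']:<12} {entity['text']}")
    concepts = recognizer.extract_medical_concepts(note)
    print("Entity types found:", sorted(concepts.keys()))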
modules\text_analysis\relation_extractor.py
"""
Medical Relation Extraction Module
This module provides functionality for extracting relationships between medical entities
in clinical text, such as disease-symptom, medication-disease, etc.
"""
import os
import json
import logging
import re
from typing import List, Dict, Any, Optional, Union, Tuple, Set
import numpy as np
import spacy
from spacy.tokens import Doc
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
pipeline
)
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("medical_relation_extraction.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class MedicalRelationExtractor:
"""
Class for extracting relationships between medical entities in clinical text
"""
def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
"""
Initialize the medical relation extractor
Args:
config_path: Path to the configuration file
model_path: Path to a pre-trained model
"""
self.config = self._load_config(config_path)
self.relation_types = self.config["relation_extraction"]["relation_types"]
self.confidence_threshold = self.config["relation_extraction"]["confidence_threshold"]
# Initialize models
self.spacy_model = None
self.transformer_model = None
self.tokenizer = None
self.relation_pipeline = None
# Load models based on configuration
self._load_models(model_path)
# Define relation patterns for rule-based extraction
self.relation_patterns = self._define_relation_patterns()
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file
Args:
config_path: Path to the configuration file
Returns:
Dict[str, Any]: Configuration dictionary
"""
default_config = {
"relation_extraction": {
"model": "medical-bert",
"relation_types": [
"TREATS", "CAUSES", "DIAGNOSES", "PREVENTS", "CONTRAINDICATES",
"SYMPTOM_OF", "MANIFESTATION_OF", "SUGGESTS"
],
"confidence_threshold": 0.6
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration from {config_path}: {e}")
return default_config
else:
logger.info("Using default configuration")
return default_config
def _load_models(self, model_path: Optional[str]) -> None:
"""
Load relation extraction models
Args:
model_path: Path to a pre-trained model
"""
# Try to load spaCy model
try:
self.spacy_model = spacy.load("en_core_web_sm")
logger.info("Loaded spaCy model: en_core_web_sm")
except OSError:
logger.warning("spaCy model not found. Using spaCy's blank model.")
self.spacy_model = spacy.blank('en')
# Load transformer model if available
if model_path and os.path.exists(model_path):
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.transformer_model = AutoModelForSequenceClassification.from_pretrained(model_path)
self.relation_pipeline = pipeline(
"text-classification",
model=self.transformer_model,
tokenizer=self.tokenizer
)
logger.info(f"Loaded transformer model from {model_path}")
except Exception as e:
logger.error(f"Error loading transformer model from {model_path}: {e}")
else:
# Use a pre-trained model from Hugging Face
try:
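                # Note: this is a general PubMedBERT checkpoint without a
                # fine-tuned relation-classification head, so the pipeline's
                # labels and scores are only meaningful after task-specific
                # fine-tuning.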
model_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.transformer_model = AutoModelForSequenceClassification.from_pretrained(model_name)
self.relation_pipeline = pipeline(
"text-classification",
model=self.transformer_model,
tokenizer=self.tokenizer
)
logger.info(f"Loaded pre-trained transformer model: {model_name}")
except Exception as e:
logger.error(f"Error loading pre-trained transformer model: {e}")
def _define_relation_patterns(self) -> Dict[str, List[Dict[str, Any]]]:
"""
Define patterns for rule-based relation extraction
Returns:
Dict[str, List[Dict[str, Any]]]: Dictionary of relation patterns
"""
patterns = {
"TREATS": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:treats|treat|treating|treated|therapy for|treatment for|used for|indicated for|prescribed for)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "MEDICATION",
"entity2_type": "DISEASE"
},
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is effective for|is used to treat|helps with|alleviates|relieves|cures)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "MEDICATION",
"entity2_type": "DISEASE"
}
],
"CAUSES": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:causes|cause|causing|caused|leads to|results in|induces|triggers)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "DISEASE",
"entity2_type": "SYMPTOM"
},
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is caused by|results from|is triggered by|is induced by|develops from)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "SYMPTOM",
"entity2_type": "DISEASE"
}
],
"DIAGNOSES": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:diagnoses|diagnose|diagnosing|diagnosed|confirms|confirms diagnosis of|indicates|is diagnostic of)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "TEST",
"entity2_type": "DISEASE"
}
],
"PREVENTS": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:prevents|prevent|preventing|prevented|protects against|reduces risk of|prophylaxis for)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "MEDICATION",
"entity2_type": "DISEASE"
}
],
"CONTRAINDICATES": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:contraindicates|contraindicate|contraindicating|contraindicated in|should not be used with|avoid in|not recommended for)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "MEDICATION",
"entity2_type": "DISEASE"
}
],
"SYMPTOM_OF": [
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:is a symptom of|symptom of|sign of|manifestation of|indicates|suggests|points to)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "SYMPTOM",
"entity2_type": "DISEASE"
},
{
"pattern": r"(?i)(\b\w+(?:\s+\w+){0,5})\s+(?:presents with|manifests as|shows|exhibits|displays|characterized by)\s+(\b\w+(?:\s+\w+){0,5})",
"entity1_type": "DISEASE",
"entity2_type": "SYMPTOM"
}
]
}
return patterns
def extract_relations_rule_based(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Extract relations between entities using rule-based patterns
Args:
text: Clinical text
entities: List of entities extracted from the text
Returns:
List[Dict[str, Any]]: List of extracted relations
"""
if not text or not entities:
return []
relations = []
# Create a dictionary of entities by type
entities_by_type = {}
for entity in entities:
entity_type = entity["label"]
if entity_type not in entities_by_type:
entities_by_type[entity_type] = []
entities_by_type[entity_type].append(entity)
# Apply relation patterns
for relation_type, patterns in self.relation_patterns.items():
for pattern_info in patterns:
pattern = pattern_info["pattern"]
entity1_type = pattern_info["entity1_type"]
entity2_type = pattern_info["entity2_type"]
# Find all matches of the pattern
for match in re.finditer(pattern, text):
entity1_text = match.group(1).strip()
entity2_text = match.group(2).strip()
# Find entities that match the extracted text
entity1_matches = []
entity2_matches = []
if entity1_type in entities_by_type:
for entity in entities_by_type[entity1_type]:
if entity1_text.lower() in entity["text"].lower() or entity["text"].lower() in entity1_text.lower():
entity1_matches.append(entity)
if entity2_type in entities_by_type:
for entity in entities_by_type[entity2_type]:
if entity2_text.lower() in entity["text"].lower() or entity["text"].lower() in entity2_text.lower():
entity2_matches.append(entity)
# Create relations for all combinations of matching entities
for entity1 in entity1_matches:
for entity2 in entity2_matches:
relation = {
"relation_type": relation_type,
"entity1": entity1,
"entity2": entity2,
"confidence": 0.8, # Default confidence for rule-based
"source": "rule-based",
"context": text[max(0, match.start() - 50):min(len(text), match.end() + 50)]
}
relations.append(relation)
return relations
def extract_relations_dependency(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Extract relations between entities using dependency parsing
Args:
text: Clinical text
entities: List of entities extracted from the text
Returns:
List[Dict[str, Any]]: List of extracted relations
"""
if not text or not entities or not self.spacy_model:
return []
relations = []
# Process text with spaCy
doc = self.spacy_model(text)
# Create a mapping from character spans to entities
span_to_entity = {}
for entity in entities:
start = entity["start"]
end = entity["end"]
span_to_entity[(start, end)] = entity
# Create a mapping from tokens to entities
token_to_entity = {}
for token in doc:
token_start = token.idx
token_end = token.idx + len(token.text)
for (start, end), entity in span_to_entity.items():
# Check if token is within entity span or overlaps significantly
if (token_start >= start and token_end <= end) or \
(token_start < end and token_end > start and min(token_end, end) - max(token_start, start) > 0.5 * (token_end - token_start)):
token_to_entity[token] = entity
break
# Define relation mappings based on dependency patterns
relation_mappings = {
"TREATS": [
{"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["treat", "cure", "heal", "alleviate", "relieve", "manage"]},
{"dep": "dobj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["treat", "cure", "heal", "alleviate", "relieve", "manage"]}
],
"CAUSES": [
{"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["cause", "lead", "result", "induce", "trigger"]},
{"dep": "dobj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["cause", "lead", "result", "induce", "trigger"]}
],
"SYMPTOM_OF": [
{"dep": "nsubj", "head_dep": "ROOT", "head_pos": "VERB", "head_lemma": ["indicate", "suggest", "imply", "point", "manifest"]},
{"dep": "prep", "head_dep": "ROOT", "head_pos": "NOUN", "head_text": ["symptom", "sign", "manifestation", "indication"]}
]
}
# Extract relations based on dependency patterns
for sent in doc.sents:
# Find the root of the sentence
root = None
for token in sent:
if token.dep_ == "ROOT":
root = token
break
if not root:
continue
# Check each relation mapping
for relation_type, patterns in relation_mappings.items():
for pattern in patterns:
# Find subject and object based on pattern
subject = None
object = None
# Check if root matches the pattern
if root.pos_ == pattern.get("head_pos") and \
(not pattern.get("head_lemma") or root.lemma_ in pattern.get("head_lemma")) and \
(not pattern.get("head_text") or root.text.lower() in pattern.get("head_text")):
# Find subject and object
for token in sent:
if token.dep_ == "nsubj" and token.head == root:
# Expand to include compound nouns
subject_span = self._expand_noun_phrase(token)
subject = subject_span
if token.dep_ == "dobj" and token.head == root:
# Expand to include compound nouns
object_span = self._expand_noun_phrase(token)
object = object_span
# If both subject and object are found and are entities
if subject and object and subject in token_to_entity and object in token_to_entity:
entity1 = token_to_entity[subject]
entity2 = token_to_entity[object]
relation = {
"relation_type": relation_type,
"entity1": entity1,
"entity2": entity2,
"confidence": 0.7, # Default confidence for dependency-based
"source": "dependency",
"context": sent.text
}
relations.append(relation)
return relations
def _expand_noun_phrase(self, token) -> Any:
"""
Expand a token to include its full noun phrase
Args:
token: spaCy token
Returns:
            Any: The leftmost compound modifier of the token if one exists, otherwise the token itself
"""
# Start with the token itself
head_token = token
# Check for compound nouns
for child in token.children:
if child.dep_ == "compound" and child.i < token.i:
head_token = child
break
return head_token
def extract_relations_transformer(self, text: str, entities: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Extract relations between entities using transformer models
Args:
text: Clinical text
entities: List of entities extracted from the text
Returns:
List[Dict[str, Any]]: List of extracted relations
"""
if not text or not entities or not self.relation_pipeline:
return []
relations = []
# Group entities by type
entity_types = {
"DISEASE": ["DISEASE", "PROBLEM"],
"SYMPTOM": ["SYMPTOM", "FINDING"],
"MEDICATION": ["MEDICATION", "DRUG"],
"TREATMENT": ["TREATMENT", "PROCEDURE"],
"TEST": ["TEST", "PROCEDURE"]
}
entities_by_group = {}
for entity in entities:
for group, types in entity_types.items():
if entity["label"] in types:
if group not in entities_by_group:
entities_by_group[group] = []
entities_by_group[group].append(entity)
# Define relation candidates based on entity types
relation_candidates = [
{"type": "TREATS", "source": "MEDICATION", "target": "DISEASE"},
{"type": "CAUSES", "source": "DISEASE", "target": "SYMPTOM"},
{"type": "DIAGNOSES", "source": "TEST", "target": "DISEASE"},
{"type": "PREVENTS", "source": "MEDICATION", "target": "DISEASE"},
{"type": "SYMPTOM_OF", "source": "SYMPTOM", "target": "DISEASE"}
]
# Generate relation candidates
for candidate in relation_candidates:
source_type = candidate["source"]
target_type = candidate["target"]
relation_type = candidate["type"]
if source_type in entities_by_group and target_type in entities_by_group:
source_entities = entities_by_group[source_type]
target_entities = entities_by_group[target_type]
for source_entity in source_entities:
for target_entity in target_entities:
# Skip if entities are too far apart
if abs(source_entity["start"] - target_entity["start"]) > 200:
continue
# Create input for transformer
source_text = source_entity["text"]
target_text = target_entity["text"]
# Get context around entities
context_start = max(0, min(source_entity["start"], target_entity["start"]) - 50)
context_end = min(len(text), max(source_entity["end"], target_entity["end"]) + 50)
context = text[context_start:context_end]
# Replace entity mentions with special tokens
context = context.replace(source_text, f"[E1] {source_text} [/E1]")
context = context.replace(target_text, f"[E2] {target_text} [/E2]")
# Classify relation
try:
result = self.relation_pipeline(context)
label = result[0]["label"]
score = result[0]["score"]
# Check if score is above threshold
if score >= self.confidence_threshold:
relation = {
"relation_type": relation_type,
"entity1": source_entity,
"entity2": target_entity,
"confidence": score,
"source": "transformer",
"context": context
}
relations.append(relation)
except Exception as e:
logger.error(f"Error in transformer relation extraction: {e}")
return relations
def extract_relations(self, text: str, entities: List[Dict[str, Any]],
method: str = "hybrid") -> List[Dict[str, Any]]:
"""
Extract relations between entities in clinical text
Args:
text: Clinical text
entities: List of entities extracted from the text
method: Extraction method ('rule-based', 'dependency', 'transformer', or 'hybrid')
Returns:
List[Dict[str, Any]]: List of extracted relations
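
        Example (illustrative; the entity dictionaries would normally come from
        MedicalNamedEntityRecognizer.extract_entities):
            >>> extractor = MedicalRelationExtractor()
            >>> ents = [
            ...     {"text": "aspirin", "label": "MEDICATION", "start": 0, "end": 7},
            ...     {"text": "angina", "label": "DISEASE", "start": 20, "end": 26},
            ... ]
            >>> extractor.extract_relations("aspirin is used for angina", ents,
            ...                             method="rule-based")  # doctest: +SKIP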
"""
if not text or not isinstance(text, str) or not entities:
logger.warning(f"Invalid input for relation extraction")
return []
if method == "rule-based":
return self.extract_relations_rule_based(text, entities)
elif method == "dependency":
return self.extract_relations_dependency(text, entities)
elif method == "transformer":
return self.extract_relations_transformer(text, entities)
elif method == "hybrid":
# Get relations from all methods
rule_based_relations = self.extract_relations_rule_based(text, entities)
dependency_relations = self.extract_relations_dependency(text, entities)
transformer_relations = self.extract_relations_transformer(text, entities)
# Combine and deduplicate relations
all_relations = rule_based_relations + dependency_relations + transformer_relations
# Simple deduplication based on relation type and entities
deduplicated = {}
for relation in all_relations:
entity1_text = relation["entity1"]["text"]
entity2_text = relation["entity2"]["text"]
relation_type = relation["relation_type"]
key = f"{relation_type}_{entity1_text}_{entity2_text}"
if key not in deduplicated or relation["confidence"] > deduplicated[key]["confidence"]:
deduplicated[key] = relation
return list(deduplicated.values())
else:
logger.error(f"Unknown relation extraction method: {method}")
return []
def build_knowledge_graph(self, relations: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Build a knowledge graph from extracted relations
Args:
relations: List of extracted relations
Returns:
Dict[str, Any]: Knowledge graph representation
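
        Example (illustrative): a single TREATS relation yields two nodes and
        one edge, e.g. {"nodes": [{"id": 0, "text": "aspirin", ...},
        {"id": 1, "text": "angina", ...}], "edges": [{"source": 0, "target": 1,
        "relation": "TREATS", "confidence": 0.8}]}.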
"""
if not relations:
return {"nodes": [], "edges": []}
nodes = {}
edges = []
# Create nodes from entities in relations
for relation in relations:
entity1 = relation["entity1"]
entity2 = relation["entity2"]
# Add entity1 as node if not already added
if entity1["text"] not in nodes:
nodes[entity1["text"]] = {
"id": len(nodes),
"text": entity1["text"],
"type": entity1["label"]
}
# Add entity2 as node if not already added
if entity2["text"] not in nodes:
nodes[entity2["text"]] = {
"id": len(nodes),
"text": entity2["text"],
"type": entity2["label"]
}
# Add edge between entities
edge = {
"source": nodes[entity1["text"]]["id"],
"target": nodes[entity2["text"]]["id"],
"relation": relation["relation_type"],
"confidence": relation["confidence"]
}
edges.append(edge)
# Convert nodes dictionary to list
nodes_list = list(nodes.values())
return {
"nodes": nodes_list,
"edges": edges
}
def visualize_relations(self, text: str, relations: List[Dict[str, Any]],
output_path: Optional[str] = None) -> str:
"""
Create a visualization of extracted relations
Args:
text: Original text
relations: List of extracted relations
output_path: Optional path to save visualization
Returns:
str: HTML representation of relations visualization
"""
if not relations:
return "<div>No relations found</div>"
# Create HTML visualization
html = f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; border: 1px solid #ddd; border-radius: 5px;">
<h2 style="color: #333;">Medical Relations Extraction</h2>
<div style="margin: 20px 0; padding: 15px; background-color: #f9f9f9; border-radius: 5px;">
<h3 style="margin-top: 0; color: #555;">Input Text:</h3>
<p style="line-height: 1.5;">{text[:500]}{'...' if len(text) > 500 else ''}</p>
</div>
<div style="margin: 20px 0;">
<h3 style="color: #555;">Extracted Relations:</h3>
<table style="width: 100%; border-collapse: collapse; margin-top: 10px;">
<thead>
<tr style="background-color: #f2f2f2;">
<th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Relation Type</th>
<th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Entity 1</th>
<th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Entity 2</th>
<th style="padding: 8px; text-align: left; border: 1px solid #ddd;">Confidence</th>
</tr>
</thead>
<tbody>
"""
# Add rows for each relation
for relation in relations:
entity1_text = relation["entity1"]["text"]
entity1_type = relation["entity1"]["label"]
entity2_text = relation["entity2"]["text"]
entity2_type = relation["entity2"]["label"]
relation_type = relation["relation_type"]
confidence = relation["confidence"]
html += f"""
<tr>
<td style="padding: 8px; text-align: left; border: 1px solid #ddd;">{relation_type}</td>
<td style="padding: 8px; text-align: left; border: 1px solid #ddd;">
<span style="font-weight: bold;">{entity1_text}</span>
<span style="color: #777; font-size: 12px;"> ({entity1_type})</span>
</td>
<td style="padding: 8px; text-align: left; border: 1px solid #ddd;">
<span style="font-weight: bold;">{entity2_text}</span>
<span style="color: #777; font-size: 12px;"> ({entity2_type})</span>
</td>
<td style="padding: 8px; text-align: left; border: 1px solid #ddd;">{confidence:.2f}</td>
</tr>
"""
html += """
</tbody>
</table>
</div>
</div>
"""
# Save to file if output path is provided
if output_path:
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(html)
logger.info(f"Visualization saved to {output_path}")
except Exception as e:
logger.error(f"Error saving visualization: {e}")
return html
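

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes entities have already
# been extracted, e.g. with MedicalNamedEntityRecognizer from medical_ner.py;
# the inline entity list below is invented for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    extractor = MedicalRelationExtractor()
    text = "Metformin is used to treat type 2 diabetes, which causes fatigue."
    entities = [
        {"text": "Metformin", "label": "MEDICATION", "start": 0, "end": 9},
        {"text": "type 2 diabetes", "label": "DISEASE", "start": 27, "end": 42},
        {"text": "fatigue", "label": "SYMPTOM", "start": 57, "end": 64},
    ]
    relations = extractor.extract_relations(text, entities, method="rule-based")
    for rel in relations:
        print(rel["relation_type"], rel["entity1"]["text"], "->", rel["entity2"]["text"])
    graph = extractor.build_knowledge_graph(relations)
    print(f"{len(graph['nodes'])} nodes, {len(graph['edges'])} edges")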
modules\text_analysis\text_classifier.py
"""
Clinical Text Classifier Module
This module provides functionality for classifying clinical text data,
including disease classification, severity assessment, and risk prediction.
"""
import os
import json
import logging
import pickle
from typing import List, Dict, Any, Optional, Union, Tuple
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.naive_bayes import MultinomialNB
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding
)
from datasets import Dataset
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("clinical_text_classification.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class ClinicalTextClassifier:
"""
Class for classifying clinical text data
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the clinical text classifier
Args:
config_path: Path to the configuration file
"""
self.config = self._load_config(config_path)
self.model_type = self.config["classification"]["model_type"]
self.algorithm = self.config["classification"]["algorithm"]
# Initialize models
self.traditional_model = None
self.vectorizer = None
self.transformer_model = None
self.tokenizer = None
self.label_map = {}
self.inverse_label_map = {}
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file
Args:
config_path: Path to the configuration file
Returns:
Dict[str, Any]: Configuration dictionary
"""
default_config = {
"classification": {
"model_type": "traditional",
"algorithm": "svm",
"pretrained_model": "bert-base-uncased",
"batch_size": 16,
"learning_rate": 2e-5,
"epochs": 3,
"max_length": 512
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration from {config_path}: {e}")
return default_config
else:
logger.info("Using default configuration")
return default_config
def _create_traditional_model(self) -> Any:
"""
Create a traditional machine learning model based on configuration
Returns:
Any: Initialized model
"""
if self.algorithm == "svm":
return SVC(probability=True, kernel='linear')
elif self.algorithm == "random_forest":
return RandomForestClassifier(n_estimators=100, random_state=42)
elif self.algorithm == "naive_bayes":
return MultinomialNB()
else:
logger.warning(f"Unknown algorithm: {self.algorithm}, using SVM as default")
return SVC(probability=True, kernel='linear')
def train_traditional_model(self, texts: List[str], labels: List[str],
test_size: float = 0.2) -> Dict[str, Any]:
"""
Train a traditional machine learning model for text classification
Args:
texts: List of clinical texts
labels: List of corresponding labels
test_size: Proportion of data to use for testing
Returns:
Dict[str, Any]: Training results
"""
if len(texts) != len(labels):
logger.error("Number of texts and labels must be equal")
return {"success": False, "error": "Number of texts and labels must be equal"}
if len(texts) == 0:
logger.error("Empty training data")
return {"success": False, "error": "Empty training data"}
try:
# Create label mapping
unique_labels = sorted(set(labels))
self.label_map = {label: i for i, label in enumerate(unique_labels)}
self.inverse_label_map = {i: label for label, i in self.label_map.items()}
# Convert labels to numeric
numeric_labels = [self.label_map[label] for label in labels]
# Split data
X_train, X_test, y_train, y_test = train_test_split(
texts, numeric_labels, test_size=test_size, random_state=42, stratify=numeric_labels
)
# Create and fit vectorizer
self.vectorizer = TfidfVectorizer(max_features=5000)
X_train_vec = self.vectorizer.fit_transform(X_train)
X_test_vec = self.vectorizer.transform(X_test)
# Create and train model
self.traditional_model = self._create_traditional_model()
self.traditional_model.fit(X_train_vec, y_train)
# Evaluate model
y_pred = self.traditional_model.predict(X_test_vec)
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='weighted')
# Generate classification report
report = classification_report(y_test, y_pred, target_names=unique_labels, output_dict=True)
logger.info(f"Model trained with accuracy: {accuracy:.4f}, F1 score: {f1:.4f}")
return {
"success": True,
"accuracy": accuracy,
"f1_score": f1,
"report": report,
"num_classes": len(unique_labels),
"classes": unique_labels,
"num_samples": len(texts),
"num_train": len(X_train),
"num_test": len(X_test)
}
except Exception as e:
logger.error(f"Error training traditional model: {e}")
return {"success": False, "error": str(e)}
def train_transformer_model(self, texts: List[str], labels: List[str],
output_dir: str, test_size: float = 0.2) -> Dict[str, Any]:
"""
Train a transformer model for text classification
Args:
texts: List of clinical texts
labels: List of corresponding labels
output_dir: Directory to save the trained model
test_size: Proportion of data to use for testing
Returns:
Dict[str, Any]: Training results
"""
if len(texts) != len(labels):
logger.error("Number of texts and labels must be equal")
return {"success": False, "error": "Number of texts and labels must be equal"}
if len(texts) == 0:
logger.error("Empty training data")
return {"success": False, "error": "Empty training data"}
try:
# Create label mapping
unique_labels = sorted(set(labels))
self.label_map = {label: i for i, label in enumerate(unique_labels)}
self.inverse_label_map = {i: label for label, i in self.label_map.items()}
# Convert labels to numeric
numeric_labels = [self.label_map[label] for label in labels]
# Split data
X_train, X_test, y_train, y_test = train_test_split(
texts, numeric_labels, test_size=test_size, random_state=42, stratify=numeric_labels
)
# Load tokenizer and model
model_name = self.config["classification"]["pretrained_model"]
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=len(unique_labels)
)
# Tokenize data
def tokenize_function(examples):
return self.tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=self.config["classification"]["max_length"]
)
# Create datasets
train_dataset = Dataset.from_dict({
"text": X_train,
"label": y_train
})
test_dataset = Dataset.from_dict({
"text": X_test,
"label": y_test
})
# Tokenize datasets
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)
# Data collator
data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
# Training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=self.config["classification"]["epochs"],
per_device_train_batch_size=self.config["classification"]["batch_size"],
per_device_eval_batch_size=self.config["classification"]["batch_size"],
learning_rate=self.config["classification"]["learning_rate"],
weight_decay=0.01,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=False,
)
# Initialize trainer
trainer = Trainer(
model=self.transformer_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
tokenizer=self.tokenizer,
data_collator=data_collator,
)
# Train model
trainer.train()
# Evaluate model
eval_results = trainer.evaluate()
# Save model and tokenizer
self.transformer_model.save_pretrained(output_dir)
self.tokenizer.save_pretrained(output_dir)
# Save label mapping
with open(os.path.join(output_dir, "label_map.json"), "w") as f:
json.dump(self.label_map, f)
logger.info(f"Transformer model trained and saved to {output_dir}")
return {
"success": True,
"eval_loss": eval_results["eval_loss"],
"num_classes": len(unique_labels),
"classes": unique_labels,
"num_samples": len(texts),
"num_train": len(X_train),
"num_test": len(X_test),
"model_path": output_dir
}
except Exception as e:
logger.error(f"Error training transformer model: {e}")
return {"success": False, "error": str(e)}
def load_model(self, model_path: str) -> bool:
"""
Load a trained model
Args:
model_path: Path to the trained model
Returns:
bool: Success status
"""
try:
if os.path.isdir(model_path):
# Try to load as transformer model
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.transformer_model = AutoModelForSequenceClassification.from_pretrained(model_path)
# Load label mapping
label_map_path = os.path.join(model_path, "label_map.json")
if os.path.exists(label_map_path):
with open(label_map_path, "r") as f:
self.label_map = json.load(f)
self.inverse_label_map = {int(i): label for label, i in self.label_map.items()}
self.model_type = "transformer"
logger.info(f"Loaded transformer model from {model_path}")
return True
except Exception as e:
logger.error(f"Error loading transformer model: {e}")
return False
else:
# Try to load as traditional model
with open(model_path, "rb") as f:
model_data = pickle.load(f)
self.traditional_model = model_data["model"]
self.vectorizer = model_data["vectorizer"]
self.label_map = model_data["label_map"]
self.inverse_label_map = model_data["inverse_label_map"]
self.model_type = "traditional"
self.algorithm = model_data.get("algorithm", "unknown")
logger.info(f"Loaded traditional model from {model_path}")
return True
except Exception as e:
logger.error(f"Error loading model: {e}")
return False
def save_traditional_model(self, output_path: str) -> bool:
"""
Save a trained traditional model
Args:
output_path: Path to save the model
Returns:
bool: Success status
"""
if not self.traditional_model or not self.vectorizer:
logger.error("No trained model to save")
return False
try:
model_data = {
"model": self.traditional_model,
"vectorizer": self.vectorizer,
"label_map": self.label_map,
"inverse_label_map": self.inverse_label_map,
"algorithm": self.algorithm
}
with open(output_path, "wb") as f:
pickle.dump(model_data, f)
logger.info(f"Model saved to {output_path}")
return True
except Exception as e:
logger.error(f"Error saving model: {e}")
return False
def predict(self, text: str) -> Dict[str, Any]:
"""
Predict class for clinical text
Args:
text: Clinical text
Returns:
Dict[str, Any]: Prediction results
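
        Example (illustrative; ``train_texts``/``train_labels`` are placeholders
        and a model must be trained or loaded before calling ``predict``):
            >>> clf = ClinicalTextClassifier()
            >>> clf.train_traditional_model(train_texts, train_labels)  # doctest: +SKIP
            >>> clf.predict("Crushing chest pain radiating to the left arm.")  # doctest: +SKIP
            {'success': True, 'predicted_class': 'cardiac', 'confidence': 0.91, ...}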
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return {"success": False, "error": "Invalid input text"}
try:
if self.model_type == "transformer" and self.transformer_model and self.tokenizer:
# Tokenize input
inputs = self.tokenizer(
text,
padding="max_length",
truncation=True,
max_length=self.config["classification"]["max_length"],
return_tensors="pt"
)
# Get prediction
with torch.no_grad():
outputs = self.transformer_model(**inputs)
logits = outputs.logits
probabilities = torch.nn.functional.softmax(logits, dim=1)
# Get predicted class and probabilities
predicted_class_id = torch.argmax(probabilities, dim=1).item()
predicted_class = self.inverse_label_map.get(predicted_class_id, "Unknown")
# Convert probabilities to dictionary
probs_dict = {
self.inverse_label_map.get(i, f"Class_{i}"): prob.item()
for i, prob in enumerate(probabilities[0])
}
return {
"success": True,
"predicted_class": predicted_class,
"confidence": probabilities[0][predicted_class_id].item(),
"probabilities": probs_dict
}
elif self.model_type == "traditional" and self.traditional_model and self.vectorizer:
# Vectorize input
X_vec = self.vectorizer.transform([text])
# Get prediction
predicted_class_id = self.traditional_model.predict(X_vec)[0]
predicted_class = self.inverse_label_map.get(predicted_class_id, "Unknown")
# Get probabilities if available
probabilities = {}
if hasattr(self.traditional_model, "predict_proba"):
probs = self.traditional_model.predict_proba(X_vec)[0]
probabilities = {
self.inverse_label_map.get(i, f"Class_{i}"): prob
for i, prob in enumerate(probs)
}
confidence = probs[predicted_class_id]
else:
confidence = 1.0
return {
"success": True,
"predicted_class": predicted_class,
"confidence": confidence,
"probabilities": probabilities
}
else:
logger.error("No trained model available")
return {"success": False, "error": "No trained model available"}
except Exception as e:
logger.error(f"Error in prediction: {e}")
return {"success": False, "error": str(e)}
def batch_predict(self, texts: List[str]) -> List[Dict[str, Any]]:
"""
Predict classes for a batch of clinical texts
Args:
texts: List of clinical texts
Returns:
List[Dict[str, Any]]: List of prediction results
"""
if not texts:
logger.warning("Empty input list")
return []
results = []
for text in texts:
result = self.predict(text)
results.append(result)
return results
def evaluate(self, texts: List[str], labels: List[str]) -> Dict[str, Any]:
"""
Evaluate model on test data
Args:
texts: List of clinical texts
labels: List of corresponding labels
Returns:
Dict[str, Any]: Evaluation results
"""
if len(texts) != len(labels):
logger.error("Number of texts and labels must be equal")
return {"success": False, "error": "Number of texts and labels must be equal"}
if len(texts) == 0:
logger.error("Empty evaluation data")
return {"success": False, "error": "Empty evaluation data"}
try:
# Convert labels to numeric if needed
if all(label in self.label_map for label in labels):
numeric_labels = [self.label_map[label] for label in labels]
else:
# Create new label mapping for unknown labels
unique_labels = sorted(set(labels))
label_map = {label: i for i, label in enumerate(unique_labels)}
numeric_labels = [label_map[label] for label in labels]
predictions = []
for text in texts:
result = self.predict(text)
if result["success"]:
predicted_label = result["predicted_class"]
predictions.append(predicted_label)
else:
logger.warning(f"Prediction failed for text: {text[:50]}...")
predictions.append("Unknown")
# Convert predictions to numeric
numeric_predictions = []
for pred in predictions:
if pred in self.label_map:
numeric_predictions.append(self.label_map[pred])
else:
# Assign a default value for unknown predictions
numeric_predictions.append(-1)
            # Calculate metrics; unknown predictions (encoded as -1) count as errors
            accuracy = accuracy_score(numeric_labels, numeric_predictions)
            # Generate classification report over the label ids that actually occur
            present_ids = sorted(set(numeric_labels) | set(numeric_predictions))
            label_names = [self.inverse_label_map.get(i, f"Class_{i}") for i in present_ids]
            report = classification_report(numeric_labels, numeric_predictions,
                                           labels=present_ids, target_names=label_names,
                                           output_dict=True)
logger.info(f"Evaluation completed with accuracy: {accuracy:.4f}")
return {
"success": True,
"accuracy": accuracy,
"report": report,
"num_samples": len(texts)
}
except Exception as e:
logger.error(f"Error in evaluation: {e}")
return {"success": False, "error": str(e)}
def visualize_predictions(self, text: str, prediction: Dict[str, Any],
output_path: Optional[str] = None) -> str:
"""
Create a visualization of prediction results
Args:
text: Original text
prediction: Prediction result
output_path: Optional path to save visualization
Returns:
str: HTML representation of prediction visualization
"""
if not prediction["success"]:
return f"<div>Error: {prediction.get('error', 'Unknown error')}</div>"
# Create HTML visualization
html = f"""
<div style="font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; border: 1px solid #ddd; border-radius: 5px;">
<h2 style="color: #333;">Text Classification Result</h2>
<div style="margin: 20px 0; padding: 15px; background-color: #f9f9f9; border-radius: 5px;">
<h3 style="margin-top: 0; color: #555;">Input Text:</h3>
<p style="line-height: 1.5;">{text[:500]}{'...' if len(text) > 500 else ''}</p>
</div>
<div style="margin: 20px 0;">
<h3 style="color: #555;">Prediction:</h3>
<div style="display: flex; align-items: center;">
<div style="font-size: 18px; font-weight: bold; color: #2c3e50; flex: 1;">
{prediction["predicted_class"]}
</div>
<div style="background-color: #3498db; color: white; padding: 5px 10px; border-radius: 20px; font-weight: bold;">
Confidence: {prediction["confidence"]:.2%}
</div>
</div>
</div>
"""
# Add probability distribution if available
if "probabilities" in prediction and prediction["probabilities"]:
html += """
<div style="margin: 20px 0;">
<h3 style="color: #555;">Class Probabilities:</h3>
<div style="display: flex; flex-direction: column; gap: 10px;">
"""
# Sort probabilities by value (descending)
sorted_probs = sorted(prediction["probabilities"].items(), key=lambda x: x[1], reverse=True)
for class_name, prob in sorted_probs:
# Calculate width based on probability
width = int(prob * 100)
# Choose color based on whether it's the predicted class
color = "#3498db" if class_name == prediction["predicted_class"] else "#95a5a6"
html += f"""
<div style="display: flex; align-items: center; gap: 10px;">
<div style="width: 150px; font-weight: {
'bold' if class_name == prediction["predicted_class"] else 'normal'
};">{class_name}</div>
<div style="flex: 1; background-color: #eee; border-radius: 5px; height: 20px;">
<div style="width: {width}%; background-color: {color}; height: 100%; border-radius: 5px;"></div>
</div>
<div style="width: 60px; text-align: right;">{prob:.2%}</div>
</div>
"""
html += """
</div>
</div>
"""
html += """
</div>
"""
# Save to file if output path is provided
if output_path:
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(html)
logger.info(f"Visualization saved to {output_path}")
except Exception as e:
logger.error(f"Error saving visualization: {e}")
return html
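

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The tiny inline dataset below is
# invented for demonstration; real use would load a labeled clinical corpus
# with far more examples.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    classifier = ClinicalTextClassifier()
    texts = [
        "Crushing substernal chest pain radiating to the left arm.",
        "Chest tightness and diaphoresis on exertion.",
        "Palpitations with an irregular heart rate.",
        "Exertional angina relieved by nitroglycerin.",
        "Elevated troponin after an episode of chest pain.",
        "Productive cough with fever and rales on auscultation.",
        "Wheezing and dyspnea responsive to bronchodilators.",
        "Pleuritic chest pain with shortness of breath.",
        "Infiltrate on chest x-ray with purulent sputum.",
        "Chronic cough and decreased breath sounds.",
    ]
    labels = ["cardiac"] * 5 + ["pulmonary"] * 5
    results = classifier.train_traditional_model(texts, labels, test_size=0.2)
    print("Training succeeded:", results.get("success"), "accuracy:", results.get("accuracy"))
    print(classifier.predict("Exertional chest pain and diaphoresis."))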
modules\text_analysis\text_preprocessor.py
"""
Clinical Text Preprocessor Module
This module provides functionality for preprocessing clinical text data,
including tokenization, normalization, and feature extraction.
"""
import re
import json
import logging
import string
import os
from typing import List, Dict, Any, Optional, Union, Set
import nltk
import numpy as np
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.corpus import stopwords
from nltk.stem import PorterStemmer, WordNetLemmatizer
import spacy
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
# Download required NLTK resources
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
nltk.data.find('corpora/wordnet')
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("clinical_text_processing.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class ClinicalTextPreprocessor:
"""
Class for preprocessing clinical text data
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the clinical text preprocessor
Args:
config_path: Path to the configuration file
"""
self.config = self._load_config(config_path)
self.stop_words = set(stopwords.words('english'))
self.stemmer = PorterStemmer()
self.lemmatizer = WordNetLemmatizer()
# Add medical stop words
self.medical_stop_words = {
'patient', 'doctor', 'hospital', 'clinic', 'medical', 'medicine',
'treatment', 'therapy', 'diagnosis', 'prognosis', 'report', 'date',
'name', 'id', 'mr', 'mrs', 'dr', 'test', 'result', 'normal', 'abnormal'
}
self.stop_words.update(self.medical_stop_words)
# Load spaCy model for advanced NLP tasks
try:
self.nlp = spacy.load('en_core_web_sm')
logger.info("Loaded spaCy model: en_core_web_sm")
except OSError:
logger.warning("spaCy model not found. Using spaCy's default model.")
self.nlp = spacy.blank('en')
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file
Args:
config_path: Path to the configuration file
Returns:
Dict[str, Any]: Configuration dictionary
"""
default_config = {
"preprocessing": {
"remove_stopwords": True,
"stemming": True,
"lemmatization": True,
"lowercase": True,
"remove_punctuation": True,
"remove_numbers": False,
"min_token_length": 2
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration from {config_path}: {e}")
return default_config
else:
logger.info("Using default configuration")
return default_config
def preprocess_text(self, text: str) -> str:
"""
Preprocess clinical text
Args:
text: Input clinical text
Returns:
str: Preprocessed text
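
        Example (illustrative; output depends on the active preprocessing flags):
            >>> pre = ClinicalTextPreprocessor()
            >>> pre.preprocess_text("The patient reports severe headaches.")  # doctest: +SKIP
            'report sever headach'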
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return ""
# Convert to lowercase if configured
if self.config["preprocessing"]["lowercase"]:
text = text.lower()
# Remove punctuation if configured
if self.config["preprocessing"]["remove_punctuation"]:
            text = re.sub(f"[{re.escape(string.punctuation)}]", " ", text)
# Remove numbers if configured
if self.config["preprocessing"]["remove_numbers"]:
text = re.sub(r'\d+', '', text)
# Tokenize
tokens = word_tokenize(text)
# Remove stopwords if configured
if self.config["preprocessing"]["remove_stopwords"]:
tokens = [t for t in tokens if t not in self.stop_words]
# Apply stemming if configured
if self.config["preprocessing"]["stemming"]:
tokens = [self.stemmer.stem(t) for t in tokens]
# Apply lemmatization if configured
if self.config["preprocessing"]["lemmatization"]:
tokens = [self.lemmatizer.lemmatize(t) for t in tokens]
# Filter by token length
min_length = self.config["preprocessing"]["min_token_length"]
tokens = [t for t in tokens if len(t) >= min_length]
# Join tokens back into text
preprocessed_text = ' '.join(tokens)
return preprocessed_text
def extract_sentences(self, text: str) -> List[str]:
"""
Extract sentences from clinical text
Args:
text: Input clinical text
Returns:
List[str]: List of sentences
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return []
# Use NLTK's sentence tokenizer
sentences = sent_tokenize(text)
return sentences
def extract_medical_terms(self, text: str) -> List[str]:
"""
Extract medical terms from clinical text using rule-based approach
Args:
text: Input clinical text
Returns:
List[str]: List of medical terms
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return []
# Process text with spaCy
doc = self.nlp(text)
# Extract noun phrases as potential medical terms
medical_terms = []
for chunk in doc.noun_chunks:
# Filter out common non-medical terms
if chunk.text.lower() not in self.stop_words:
medical_terms.append(chunk.text)
return medical_terms
def extract_features(self, texts: List[str], method: str = 'tfidf',
max_features: int = 1000) -> np.ndarray:
"""
Extract features from clinical texts
Args:
texts: List of clinical texts
method: Feature extraction method ('tfidf', 'count', 'binary')
max_features: Maximum number of features
Returns:
np.ndarray: Feature matrix
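
        Example (illustrative):
            >>> pre = ClinicalTextPreprocessor()
            >>> X = pre.extract_features(["chest pain", "fever and cough"], method="tfidf")
            >>> X.shape  # doctest: +SKIP
            (2, 4)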
"""
if not texts:
logger.warning("Empty text list provided for feature extraction")
return np.array([])
# Preprocess texts
preprocessed_texts = [self.preprocess_text(text) for text in texts]
# Extract features based on method
if method == 'tfidf':
vectorizer = TfidfVectorizer(max_features=max_features)
elif method == 'count':
vectorizer = CountVectorizer(max_features=max_features)
elif method == 'binary':
vectorizer = CountVectorizer(max_features=max_features, binary=True)
else:
logger.error(f"Unknown feature extraction method: {method}")
return np.array([])
# Fit and transform
try:
feature_matrix = vectorizer.fit_transform(preprocessed_texts)
feature_names = vectorizer.get_feature_names_out()
logger.info(f"Extracted {len(feature_names)} features using {method} method")
return feature_matrix
except Exception as e:
logger.error(f"Error in feature extraction: {e}")
return np.array([])
def normalize_medical_concepts(self, terms: List[str]) -> Dict[str, str]:
"""
Normalize medical concepts to standard terminology
Args:
terms: List of medical terms
Returns:
Dict[str, str]: Mapping from original terms to normalized terms
"""
# This is a simplified implementation
# In a real system, this would connect to UMLS or SNOMED CT
# Example mapping for common variations
normalization_map = {
"heart attack": "myocardial infarction",
"heart failure": "cardiac failure",
"high blood pressure": "hypertension",
"low blood pressure": "hypotension",
"stroke": "cerebrovascular accident",
"sugar": "glucose",
"cancer": "malignant neoplasm",
"chest pain": "angina pectoris"
}
result = {}
for term in terms:
term_lower = term.lower()
if term_lower in normalization_map:
result[term] = normalization_map[term_lower]
else:
result[term] = term
return result
def detect_section_headers(self, text: str) -> List[Dict[str, Any]]:
"""
Detect clinical note section headers
Args:
text: Clinical note text
Returns:
List[Dict[str, Any]]: List of detected sections with positions
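
        Example (illustrative; ``clinical_note`` is a placeholder for a note that
        contains CHIEF COMPLAINT, MEDICATIONS and ASSESSMENT headers):
            >>> pre = ClinicalTextPreprocessor()
            >>> [s["name"] for s in pre.detect_section_headers(clinical_note)]  # doctest: +SKIP
            ['CHIEF COMPLAINT:', 'MEDICATIONS:', 'ASSESSMENT:']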
"""
# Common section headers in clinical notes
section_patterns = [
r"(?i)(?:^|\n)[\s]*CHIEF\s+COMPLAINT[\s:]*",
r"(?i)(?:^|\n)[\s]*HISTORY\s+(?:OF\s+)?(?:PRESENT\s+)?ILLNESS[\s:]*",
r"(?i)(?:^|\n)[\s]*PAST\s+MEDICAL\s+HISTORY[\s:]*",
r"(?i)(?:^|\n)[\s]*MEDICATIONS[\s:]*",
r"(?i)(?:^|\n)[\s]*ALLERGIES[\s:]*",
r"(?i)(?:^|\n)[\s]*FAMILY\s+HISTORY[\s:]*",
r"(?i)(?:^|\n)[\s]*SOCIAL\s+HISTORY[\s:]*",
r"(?i)(?:^|\n)[\s]*PHYSICAL\s+EXAMINATION[\s:]*",
r"(?i)(?:^|\n)[\s]*VITAL\s+SIGNS[\s:]*",
r"(?i)(?:^|\n)[\s]*LABORATORY\s+(?:DATA|RESULTS)[\s:]*",
r"(?i)(?:^|\n)[\s]*ASSESSMENT[\s:]*",
r"(?i)(?:^|\n)[\s]*PLAN[\s:]*",
r"(?i)(?:^|\n)[\s]*IMPRESSION[\s:]*",
r"(?i)(?:^|\n)[\s]*DIAGNOSIS[\s:]*"
]
sections = []
for pattern in section_patterns:
for match in re.finditer(pattern, text):
section_name = match.group(0).strip()
sections.append({
"name": section_name,
"start": match.start(),
"end": match.end(),
"text": section_name
})
# Sort sections by their position in the text
sections.sort(key=lambda x: x["start"])
# Extract section content: from the end of each header match to the start of the next header
for i in range(len(sections)):
start = sections[i]["end"]
end = sections[i+1]["start"] if i < len(sections) - 1 else len(text)
sections[i]["content"] = text[start:end].strip()
return sections
def extract_demographics(self, text: str) -> Dict[str, Any]:
"""
Extract patient demographics from clinical text
Args:
text: Clinical text
Returns:
Dict[str, Any]: Extracted demographics
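Example (illustrative sketch; assumes `analyzer` is an instance of this class):
>>> analyzer.extract_demographics("The patient is a 65 year old male with hypertension.")
{'age': 65, 'gender': 'male', 'race': None}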
"""
demographics = {
"age": None,
"gender": None,
"race": None
}
# Age patterns
age_patterns = [
r"(?i)(?:age|aged)[\s:]+(\d+)(?:\s+years?)?",
r"(?i)(\d+)[\s-]*(?:year|y)(?:s|\.)?\s+old",
r"(?i)(\d+)[\s-]*y\.?o\.?"
]
for pattern in age_patterns:
match = re.search(pattern, text)
if match:
demographics["age"] = int(match.group(1))
break
# Gender patterns
male_patterns = [r"(?i)\b(?:male|m)\b", r"(?i)gender[\s:]+male"]
female_patterns = [r"(?i)\b(?:female|f)\b", r"(?i)gender[\s:]+female"]
for pattern in male_patterns:
if re.search(pattern, text):
demographics["gender"] = "male"
break
if demographics["gender"] is None:
for pattern in female_patterns:
if re.search(pattern, text):
demographics["gender"] = "female"
break
# Race patterns (simplified)
race_patterns = {
"caucasian": [r"(?i)\b(?:caucasian|white)\b"],
"african american": [r"(?i)\b(?:african american|black)\b"],
"asian": [r"(?i)\b(?:asian)\b"],
"hispanic": [r"(?i)\b(?:hispanic|latino)\b"],
"other": [r"(?i)\b(?:other race|mixed race)\b"]
}
for race, patterns in race_patterns.items():
for pattern in patterns:
if re.search(pattern, text):
demographics["race"] = race
break
if demographics["race"]:
break
return demographics
modules\text_analysis\text_summarizer.py
"""
Clinical Text Summarizer Module
This module provides functionality for summarizing clinical text data,
including extractive and abstractive summarization methods.
"""
import os
import json
import logging
import re
from typing import List, Dict, Any, Optional, Union, Tuple
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
pipeline
)
# Download required NLTK resources
try:
nltk.data.find('tokenizers/punkt')
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('punkt')
nltk.download('stopwords')
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("clinical_text_summarization.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
class ClinicalTextSummarizer:
"""
Class for summarizing clinical text data
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize the clinical text summarizer
Args:
config_path: Path to the configuration file
"""
self.config = self._load_config(config_path)
self.stop_words = set(stopwords.words('english'))
# Add medical stop words
self.medical_stop_words = {
'patient', 'doctor', 'hospital', 'clinic', 'medical', 'medicine',
'treatment', 'therapy', 'diagnosis', 'prognosis', 'report', 'date',
'name', 'id', 'mr', 'mrs', 'dr', 'test', 'result', 'normal', 'abnormal'
}
self.stop_words.update(self.medical_stop_words)
# Initialize transformer models
self.tokenizer = None
self.model = None
self.summarization_pipeline = None
# Load transformer model if specified
if self.config["summarization"]["model"]:
self._load_transformer_model(self.config["summarization"]["model"])
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
"""
Load configuration from file
Args:
config_path: Path to the configuration file
Returns:
Dict[str, Any]: Configuration dictionary
"""
default_config = {
"summarization": {
"max_length": 150,
"min_length": 50,
"model": "t5-small",
"early_stopping": True
}
}
if config_path and os.path.exists(config_path):
try:
with open(config_path, 'r') as f:
config = json.load(f)
logger.info(f"Loaded configuration from {config_path}")
return config
except Exception as e:
logger.error(f"Error loading configuration from {config_path}: {e}")
return default_config
else:
logger.info("Using default configuration")
return default_config
def _load_transformer_model(self, model_name: str) -> None:
"""
Load transformer model for summarization
Args:
model_name: Name or path of the model
"""
try:
# Map model name to actual model if it's a custom name
model_mapping = {
"medical-t5": "google/flan-t5-base",
"clinical-t5": "google/flan-t5-base"
}
actual_model = model_mapping.get(model_name, model_name)
# Load tokenizer and model
self.tokenizer = AutoTokenizer.from_pretrained(actual_model)
self.model = AutoModelForSeq2SeqLM.from_pretrained(actual_model)
# Create summarization pipeline
self.summarization_pipeline = pipeline(
"summarization",
model=self.model,
tokenizer=self.tokenizer,
framework="pt"
)
logger.info(f"Loaded transformer model: {actual_model}")
except Exception as e:
logger.error(f"Error loading transformer model: {e}")
self.tokenizer = None
self.model = None
self.summarization_pipeline = None
def extractive_summarize(self, text: str, ratio: float = 0.3,
min_length: int = 3, max_length: int = 10) -> str:
"""
Generate an extractive summary of clinical text
Args:
text: Clinical text to summarize
ratio: Proportion of sentences to include in summary
min_length: Minimum number of sentences in summary
max_length: Maximum number of sentences in summary
Returns:
str: Extractive summary
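Example (illustrative sketch; assumes the NLTK punkt data is available and
`note_text` is a multi-sentence clinical note):
>>> summarizer = ClinicalTextSummarizer()
>>> summary = summarizer.extractive_summarize(note_text, ratio=0.2)
>>> # the top-scoring sentences are returned in their original order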
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return ""
try:
# Split text into sentences
sentences = sent_tokenize(text)
if len(sentences) <= min_length:
return text
# Calculate number of sentences in summary
num_sentences = max(min_length, min(max_length, int(len(sentences) * ratio)))
# Create TF-IDF vectorizer
vectorizer = TfidfVectorizer(stop_words=self.stop_words)
# Get sentence vectors
sentence_vectors = vectorizer.fit_transform(sentences)
# Calculate sentence scores based on TF-IDF
sentence_scores = {}
for i, sentence in enumerate(sentences):
# Sum the TF-IDF scores of all terms in the sentence
sentence_scores[i] = np.sum(sentence_vectors[i].toarray())
# Get top sentences
top_sentence_indices = sorted(sentence_scores, key=sentence_scores.get, reverse=True)[:num_sentences]
# Sort indices to maintain original order
top_sentence_indices = sorted(top_sentence_indices)
# Create summary
summary = " ".join([sentences[i] for i in top_sentence_indices])
return summary
except Exception as e:
logger.error(f"Error in extractive summarization: {e}")
return text[:500] + "..." if len(text) > 500 else text
def abstractive_summarize(self, text: str) -> str:
"""
Generate an abstractive summary of clinical text using transformer models
Args:
text: Clinical text to summarize
Returns:
str: Abstractive summary
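Example (illustrative sketch; requires the configured seq2seq model to have loaded,
otherwise the call falls back to extractive summarization):
>>> summarizer = ClinicalTextSummarizer()
>>> summary = summarizer.abstractive_summarize(note_text)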
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return ""
if not self.summarization_pipeline:
logger.warning("Transformer model not available, falling back to extractive summarization")
return self.extractive_summarize(text)
try:
# Truncate overly long input (counted in whitespace-separated words as a rough proxy for model tokens)
max_words = 1024  # approximate context limit of typical transformer summarization models
if len(text.split()) > max_words:
logger.info(f"Text too long ({len(text.split())} words), truncating to {max_words} words")
text = " ".join(text.split()[:max_words])
# Generate summary
summary = self.summarization_pipeline(
text,
max_length=self.config["summarization"]["max_length"],
min_length=self.config["summarization"]["min_length"],
do_sample=False,
early_stopping=self.config["summarization"]["early_stopping"]
)
# Extract summary text
summary_text = summary[0]["summary_text"]
return summary_text
except Exception as e:
logger.error(f"Error in abstractive summarization: {e}")
return self.extractive_summarize(text)
def summarize(self, text: str, method: str = "hybrid") -> Dict[str, Any]:
"""
Generate a summary of clinical text
Args:
text: Clinical text to summarize
method: Summarization method ('extractive', 'abstractive', or 'hybrid')
Returns:
Dict[str, Any]: Summarization results
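Example (illustrative sketch; `note_text` is a clinical note string):
>>> result = ClinicalTextSummarizer().summarize(note_text, method="hybrid")
>>> # on success, result contains "summary", "method", "original_length",
>>> # "summary_length" and "compression_ratio"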
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return {"success": False, "error": "Invalid input text"}
try:
if method == "extractive":
summary = self.extractive_summarize(text)
summary_type = "extractive"
elif method == "abstractive":
if self.summarization_pipeline:
summary = self.abstractive_summarize(text)
summary_type = "abstractive"
else:
logger.warning("Transformer model not available, falling back to extractive summarization")
summary = self.extractive_summarize(text)
summary_type = "extractive (fallback)"
elif method == "hybrid":
# For hybrid, first do extractive to reduce size, then abstractive
if self.summarization_pipeline:
extractive_summary = self.extractive_summarize(text, ratio=0.5)
summary = self.abstractive_summarize(extractive_summary)
summary_type = "hybrid"
else:
summary = self.extractive_summarize(text)
summary_type = "extractive (fallback)"
else:
logger.error(f"Unknown summarization method: {method}")
return {"success": False, "error": f"Unknown summarization method: {method}"}
# Calculate compression ratio
original_length = len(text.split())
summary_length = len(summary.split())
compression_ratio = summary_length / original_length if original_length > 0 else 0
return {
"success": True,
"summary": summary,
"method": summary_type,
"original_length": original_length,
"summary_length": summary_length,
"compression_ratio": compression_ratio
}
except Exception as e:
logger.error(f"Error in summarization: {e}")
return {"success": False, "error": str(e)}
def section_based_summarize(self, text: str, sections: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Generate summaries for each section of a clinical note
Args:
text: Full clinical text
sections: List of section dictionaries with name, start, and content
Returns:
Dict[str, Any]: Section-based summarization results
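Example (illustrative sketch; assumes `summarizer` is a ClinicalTextSummarizer instance and
`sections` is typically the output of detect_section_headers() in the text analysis module,
so each entry has "name" and "content" keys):
>>> result = summarizer.section_based_summarize(note_text, sections)
>>> # result["section_summaries"] maps each header to its summary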
"""
if not text or not isinstance(text, str) or not sections:
logger.warning(f"Invalid input for section-based summarization")
return {"success": False, "error": "Invalid input"}
try:
section_summaries = {}
for section in sections:
section_name = section["name"]
section_content = section["content"]
# Skip empty sections or very short sections
if not section_content or len(section_content.split()) < 10:
section_summaries[section_name] = section_content
continue
# Summarize section content
summary_result = self.summarize(section_content, method="hybrid")
if summary_result["success"]:
section_summaries[section_name] = summary_result["summary"]
else:
section_summaries[section_name] = section_content
# Combine section summaries
combined_summary = ""
for section_name, summary in section_summaries.items():
combined_summary += f"{section_name}\n{summary}\n\n"
return {
"success": True,
"combined_summary": combined_summary,
"section_summaries": section_summaries
}
except Exception as e:
logger.error(f"Error in section-based summarization: {e}")
return {"success": False, "error": str(e)}
def key_findings_extract(self, text: str) -> Dict[str, Any]:
"""
Extract key findings from clinical text
Args:
text: Clinical text
Returns:
Dict[str, Any]: Extracted key findings
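Example (illustrative sketch; assumes `summarizer` is a ClinicalTextSummarizer instance):
>>> result = summarizer.key_findings_extract("Diagnosis: community-acquired pneumonia. Medications: amoxicillin 500 mg.")
>>> # result["findings"]["diagnosis"] would contain "community-acquired pneumonia"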
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return {"success": False, "error": "Invalid input text"}
try:
# Define patterns for key findings
patterns = {
"diagnosis": [
r"(?i)diagnosis[\s:]+([^\n\.]+)",
r"(?i)impression[\s:]+([^\n\.]+)",
r"(?i)assessment[\s:]+([^\n\.]+)"
],
"medications": [
r"(?i)medications?[\s:]+([^\n]+)",
r"(?i)prescribed[\s:]+([^\n]+)",
r"(?i)taking[\s:]+([^\n]+)"
],
"allergies": [
r"(?i)allergies[\s:]+([^\n]+)",
r"(?i)allergic to[\s:]+([^\n]+)"
],
"vital_signs": [
r"(?i)vital signs[\s:]+([^\n]+)",
r"(?i)temperature[\s:]+([^\n\.,]+)",
r"(?i)blood pressure[\s:]+([^\n\.,]+)",
r"(?i)pulse[\s:]+([^\n\.,]+)",
r"(?i)respiratory rate[\s:]+([^\n\.,]+)"
],
"lab_results": [
r"(?i)lab(?:oratory)? results?[\s:]+([^\n]+)",
r"(?i)labs?[\s:]+([^\n]+)"
],
"procedures": [
r"(?i)procedures?[\s:]+([^\n]+)",
r"(?i)surgery[\s:]+([^\n]+)",
r"(?i)operation[\s:]+([^\n]+)"
]
}
findings = {}
# Extract findings using patterns
for category, category_patterns in patterns.items():
category_findings = []
for pattern in category_patterns:
matches = re.finditer(pattern, text)
for match in matches:
finding = match.group(1).strip()
if finding and finding not in category_findings:
category_findings.append(finding)
findings[category] = category_findings
# Format findings as text
formatted_findings = ""
for category, category_findings in findings.items():
if category_findings:
formatted_findings += f"{category.replace('_', ' ').title()}:\n"
for finding in category_findings:
formatted_findings += f"- {finding}\n"
formatted_findings += "\n"
return {
"success": True,
"findings": findings,
"formatted_findings": formatted_findings
}
except Exception as e:
logger.error(f"Error extracting key findings: {e}")
return {"success": False, "error": str(e)}
def generate_discharge_summary(self, text: str) -> Dict[str, Any]:
"""
Generate a discharge summary from clinical notes
Args:
text: Clinical text
Returns:
Dict[str, Any]: Generated discharge summary
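Example (illustrative sketch; assumes `summarizer` is a ClinicalTextSummarizer instance):
>>> result = summarizer.generate_discharge_summary(note_text)
>>> # result["discharge_summary"] combines the extracted key findings with a hybrid summary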
"""
if not text or not isinstance(text, str):
logger.warning(f"Invalid input text: {text}")
return {"success": False, "error": "Invalid input text"}
try:
# Extract key findings
findings_result = self.key_findings_extract(text)
if not findings_result["success"]:
return findings_result
findings = findings_result["findings"]
# Generate summary of the full text
summary_result = self.summarize(text, method="hybrid")
if not summary_result["success"]:
return summary_result
summary = summary_result["summary"]
# Combine findings and summary into discharge summary
discharge_summary = "DISCHARGE SUMMARY\n\n"
# Add key findings
for category in ["diagnosis", "medications", "allergies", "procedures"]:
if category in findings and findings[category]:
discharge_summary += f"{category.replace('_', ' ').title()}:\n"
for finding in findings[category]:
discharge_summary += f"- {finding}\n"
discharge_summary += "\n"
# Add summary
discharge_summary += "Clinical Summary:\n"
discharge_summary += summary
return {
"success": True,
"discharge_summary": discharge_summary,
"findings": findings,
"clinical_summary": summary
}
except Exception as e:
logger.error(f"Error generating discharge summary: {e}")
return {"success": False, "error": str(e)}
def visualize_summary(self, original_text: str, summary: str,
output_path: Optional[str] = None) -> str:
"""
Create a visualization comparing original text and summary
Args:
original_text: Original clinical text
summary: Generated summary
output_path: Optional path to save visualization
Returns:
str: HTML representation of summary visualization
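Example (illustrative sketch; writes a side-by-side HTML comparison when a path is given):
>>> html = summarizer.visualize_summary(note_text, summary, output_path="summary.html")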
"""
# Create HTML visualization
html = f"""
<div style="font-family: Arial, sans-serif; max-width: 1000px; margin: 0 auto; padding: 20px;">
<h2 style="color: #333; text-align: center;">Clinical Text Summarization</h2>
<div style="display: flex; gap: 20px; margin-top: 30px;">
<div style="flex: 1; padding: 15px; border: 1px solid #ddd; border-radius: 5px;">
<h3 style="color: #555; margin-top: 0;">Original Text</h3>
<p style="line-height: 1.5; white-space: pre-wrap;">{original_text[:2000]}{'...' if len(original_text) > 2000 else ''}</p>
<div style="color: #777; font-size: 14px; margin-top: 10px;">
Word count: {len(original_text.split())}
</div>
</div>
<div style="flex: 1; padding: 15px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">
<h3 style="color: #555; margin-top: 0;">Summary</h3>
<p style="line-height: 1.5; white-space: pre-wrap;">{summary}</p>
<div style="color: #777; font-size: 14px; margin-top: 10px;">
Word count: {len(summary.split())}
<br>
Compression ratio: {len(summary.split()) / len(original_text.split()):.2%}
</div>
</div>
</div>
</div>
"""
# Save to file if output path is provided
if output_path:
try:
with open(output_path, "w", encoding="utf-8") as f:
f.write(html)
logger.info(f"Visualization saved to {output_path}")
except Exception as e:
logger.error(f"Error saving visualization: {e}")
return html
modules\user_interface\app.py
"""
User Interface Module - Main Application
This module provides the main Flask application for the intelligent medical diagnostic system,
including routes and views for both doctor and patient interfaces.
"""
import os
import json
import logging
from datetime import datetime
from pathlib import Path
from flask import Flask, render_template, request, redirect, url_for, flash, jsonify, session
from werkzeug.utils import secure_filename
import sys
# Add parent directory to path to import modules
sys.path.append(str(Path(__file__).parent.parent.parent))
from modules.diagnosis_engine.diagnosis_engine import DiagnosisEngine
from modules.user_interface.auth import auth_blueprint, login_required, is_doctor
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize Flask application
app = Flask(__name__,
static_folder=os.path.join(os.path.dirname(os.path.abspath(__file__)), "static"),
template_folder=os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates"))
# Load configuration
app.config.from_pyfile(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config.py'))
# Register blueprints
app.register_blueprint(auth_blueprint)
# Initialize diagnosis engine
diagnosis_engine = DiagnosisEngine()
# File upload settings
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'bmp', 'dcm', 'dicom', 'txt', 'pdf', 'csv', 'json'}
UPLOAD_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'uploads')
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
def allowed_file(filename):
"""Check if the file extension is allowed."""
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.route('/')
def index():
"""Render the home page."""
return render_template('index.html')
@app.route('/dashboard')
@login_required
def dashboard():
"""Render the dashboard based on user role."""
if is_doctor():
return redirect(url_for('doctor_dashboard'))
else:
return redirect(url_for('patient_dashboard'))
@app.route('/doctor/dashboard')
@login_required
def doctor_dashboard():
"""Render the doctor dashboard."""
if not is_doctor():
flash('You do not have permission to access the doctor dashboard', 'error')
return redirect(url_for('dashboard'))
# Mock data for demonstration
recent_cases = [
{
'id': 'CASE001',
'patient_name': '张三',
'age': 65,
'gender': 'male',
'primary_diagnosis': 'pneumonia',
'confidence': 0.85,
'date': '2023-05-15'
},
{
'id': 'CASE002',
'patient_name': '李四',
'age': 42,
'gender': 'female',
'primary_diagnosis': 'coronary artery disease',
'confidence': 0.78,
'date': '2023-05-14'
},
{
'id': 'CASE003',
'patient_name': '王五',
'age': 58,
'gender': 'male',
'primary_diagnosis': 'type 2 diabetes',
'confidence': 0.92,
'date': '2023-05-13'
}
]
statistics = {
'total_cases': 152,
'pending_review': 8,
'completed_today': 12,
'system_accuracy': 0.87
}
return render_template('doctor/dashboard.html',
recent_cases=recent_cases,
statistics=statistics)
@app.route('/patient/dashboard')
@login_required
def patient_dashboard():
"""Render the patient dashboard."""
# Mock data for demonstration
medical_records = [
{
'id': 'REC001',
'date': '2023-05-15',
'doctor': 'Dr. Zhang',
'diagnosis': 'upper respiratory tract infection',
'treatment': 'antibiotic therapy'
},
{
'id': 'REC002',
'date': '2023-04-10',
'doctor': 'Dr. Li',
'diagnosis': 'gastritis',
'treatment': 'proton pump inhibitor'
}
]
upcoming_appointments = [
{
'id': 'APT001',
'date': '2023-05-20',
'time': '10:30',
'doctor': 'Dr. Wang',
'department': 'Internal Medicine'
}
]
return render_template('patient/dashboard.html',
medical_records=medical_records,
upcoming_appointments=upcoming_appointments)
@app.route('/doctor/new_case', methods=['GET', 'POST'])
@login_required
def new_case():
"""Handle new case creation."""
if not is_doctor():
flash('You do not have permission to create a new case', 'error')
return redirect(url_for('dashboard'))
if request.method == 'POST':
# Process form data
patient_data = {
'demographics': {
'name': request.form.get('patient_name'),
'age': int(request.form.get('patient_age') or 0),
'gender': request.form.get('patient_gender'),
'height': float(request.form.get('patient_height')) if request.form.get('patient_height') else None,
'weight': float(request.form.get('patient_weight')) if request.form.get('patient_weight') else None
},
'medical_history': request.form.get('medical_history', '').split(','),
'allergies': request.form.get('allergies', '').split(',')
}
# Process file uploads
image_analysis = None
text_analysis = None
lab_results = None
if 'medical_image' in request.files:
image_file = request.files['medical_image']
if image_file and allowed_file(image_file.filename):
filename = secure_filename(image_file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
image_file.save(file_path)
# In a real application, this would call the image analysis module
# For demo, we'll use mock data
image_analysis = {
"modality": request.form.get('image_modality', ''),
"body_part": request.form.get('body_part', ''),
"findings": request.form.get('image_findings', '').split('\n'),
"confidence": 0.85
}
if 'clinical_text' in request.files:
text_file = request.files['clinical_text']
if text_file and allowed_file(text_file.filename):
filename = secure_filename(text_file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
text_file.save(file_path)
# In a real application, this would call the text analysis module
# For demo, we'll use mock data based on form input
text_analysis = {
"document_type": "临床病历",
"content": request.form.get('clinical_notes', ''),
"entities": [], # Would be populated by NLP module
"confidence": 0.9
}
if 'lab_results' in request.files:
lab_file = request.files['lab_results']
if lab_file and allowed_file(lab_file.filename):
filename = secure_filename(lab_file.filename)
file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
lab_file.save(file_path)
# In a real application, this would parse the lab results file
# For demo, we'll use mock data
lab_results = {
"results": {}, # Would be populated by parsing the file
"abnormal_results": {},
"confidence": 0.95
}
# Generate diagnosis
try:
diagnosis_results = diagnosis_engine.diagnose(
patient_data, image_analysis, text_analysis, lab_results
)
# Generate report
report_format = request.form.get('report_format', 'html')
report = diagnosis_engine.generate_report(
diagnosis_results, patient_data, report_format
)
# Save report
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
report_filename = f"report_{timestamp}.{report_format}"
report_path = os.path.join(app.config['UPLOAD_FOLDER'], report_filename)
with open(report_path, 'w', encoding='utf-8') as f:
f.write(report)
# Store case ID in session for redirect
case_id = f"CASE{timestamp}"
session['last_case_id'] = case_id
flash('Diagnostic report generated', 'success')
return redirect(url_for('view_case', case_id=case_id))
except Exception as e:
logger.error(f"Error generating diagnosis: {e}")
flash(f'Error generating diagnostic report: {str(e)}', 'error')
return render_template('doctor/new_case.html')
@app.route('/doctor/case/<case_id>')
@login_required
def view_case(case_id):
"""View a specific case."""
if not is_doctor():
flash('You do not have permission to view cases', 'error')
return redirect(url_for('dashboard'))
# In a real application, this would fetch the case from a database
# For demo, we'll use mock data
if case_id == session.get('last_case_id'):
# This would be the most recently created case
case = {
'id': case_id,
'patient_name': '张三',
'age': 65,
'gender': 'male',
'primary_diagnosis': 'pneumonia',
'confidence': 0.85,
'differential_diagnoses': [
{'disease': 'chronic bronchitis', 'confidence': 0.65},
{'disease': 'pulmonary tuberculosis', 'confidence': 0.45}
],
'evidence': {
'pneumonia': [
{'item': 'irregular opacity in the right lower lobe', 'type': 'imaging finding'},
{'item': 'temperature 38.2°C', 'type': 'vital sign'},
{'item': 'elevated white blood cell count', 'type': 'laboratory result'}
]
},
'date': datetime.now().strftime("%Y-%m-%d")
}
else:
# Mock data for other cases
case = {
'id': case_id,
'patient_name': '李四',
'age': 42,
'gender': 'female',
'primary_diagnosis': 'coronary artery disease',
'confidence': 0.78,
'differential_diagnoses': [
{'disease': 'myocarditis', 'confidence': 0.55},
{'disease': 'chest wall pain', 'confidence': 0.40}
],
'evidence': {
'coronary artery disease': [
{'item': 'ST-segment depression', 'type': 'ECG finding'},
{'item': 'chest pain', 'type': 'symptom'},
{'item': 'family history', 'type': 'medical history'}
]
},
'date': '2023-05-14'
}
return render_template('doctor/view_case.html', case=case)
@app.route('/doctor/patients')
@login_required
def patient_list():
"""View list of patients."""
if not is_doctor():
flash('You do not have permission to view the patient list', 'error')
return redirect(url_for('dashboard'))
# Mock data for demonstration
patients = [
{
'id': 'PAT001',
'name': '张三',
'age': 65,
'gender': 'male',
'last_visit': '2023-05-15',
'primary_diagnosis': 'pneumonia'
},
{
'id': 'PAT002',
'name': '李四',
'age': 42,
'gender': 'female',
'last_visit': '2023-05-14',
'primary_diagnosis': 'coronary artery disease'
},
{
'id': 'PAT003',
'name': '王五',
'age': 58,
'gender': 'male',
'last_visit': '2023-05-13',
'primary_diagnosis': 'type 2 diabetes'
}
]
return render_template('doctor/patients.html', patients=patients)
@app.route('/patient/medical_records')
@login_required
def medical_records():
"""View patient's medical records."""
# Mock data for demonstration
records = [
{
'id': 'REC001',
'date': '2023-05-15',
'doctor': 'Dr. Zhang',
'diagnosis': 'upper respiratory tract infection',
'treatment': 'antibiotic therapy',
'notes': 'Symptoms improving; continue observation'
},
{
'id': 'REC002',
'date': '2023-04-10',
'doctor': 'Dr. Li',
'diagnosis': 'gastritis',
'treatment': 'proton pump inhibitor',
'notes': 'Recommend a bland diet and avoiding spicy, irritating foods'
},
{
'id': 'REC003',
'date': '2023-02-22',
'doctor': 'Dr. Wang',
'diagnosis': 'seasonal allergies',
'treatment': 'antihistamines',
'notes': 'Recommend avoiding allergens and keeping the room well ventilated'
}
]
return render_template('patient/medical_records.html', records=records)
@app.route('/patient/appointments')
@login_required
def appointments():
"""View and manage patient's appointments."""
# Mock data for demonstration
upcoming = [
{
'id': 'APT001',
'date': '2023-05-20',
'time': '10:30',
'doctor': 'Dr. Wang',
'department': 'Internal Medicine',
'reason': 'Follow-up visit'
}
]
past = [
{
'id': 'APT002',
'date': '2023-05-15',
'time': '14:00',
'doctor': 'Dr. Zhang',
'department': 'Respiratory Medicine',
'reason': 'Cough and fever',
'status': 'Completed'
},
{
'id': 'APT003',
'date': '2023-04-10',
'time': '09:15',
'doctor': 'Dr. Li',
'department': 'Gastroenterology',
'reason': 'Abdominal pain',
'status': 'Completed'
}
]
return render_template('patient/appointments.html',
upcoming_appointments=upcoming,
past_appointments=past)
@app.route('/api/diagnose', methods=['POST'])
@login_required
def api_diagnose():
"""API endpoint for diagnosis."""
if not request.is_json:
return jsonify({'error': 'Request must be JSON'}), 400
data = request.get_json()
# Extract data from request
patient_data = data.get('patient_data', {})
image_analysis = data.get('image_analysis')
text_analysis = data.get('text_analysis')
lab_results = data.get('lab_results')
try:
# Generate diagnosis
diagnosis_results = diagnosis_engine.diagnose(
patient_data, image_analysis, text_analysis, lab_results
)
return jsonify(diagnosis_results)
except Exception as e:
logger.error(f"API error: {e}")
return jsonify({'error': str(e)}), 500
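# Example request for /api/diagnose (illustrative only; the exact keys expected inside each
# block depend on DiagnosisEngine, and the structures below mirror the mock data used in
# new_case() above):
#
#   POST /api/diagnose
#   {
#     "patient_data": {"demographics": {"name": "张三", "age": 65, "gender": "male"},
#                      "medical_history": ["hypertension"], "allergies": []},
#     "image_analysis": {"modality": "X-ray", "body_part": "chest",
#                        "findings": ["irregular opacity in the right lower lobe"],
#                        "confidence": 0.85},
#     "text_analysis": null,
#     "lab_results": null
#   }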
@app.errorhandler(404)
def page_not_found(e):
"""Handle 404 errors."""
return render_template('errors/404.html'), 404
@app.errorhandler(500)
def server_error(e):
"""Handle 500 errors."""
logger.error(f"Server error: {e}")
return render_template('errors/500.html'), 500
if __name__ == '__main__':
app.run(debug=True, host='0.0.0.0', port=5000)
modules\user_interface\auth.py
"""
User Interface Module - Authentication
This module provides authentication functionality for the intelligent medical diagnostic system,
including user registration, login, and role-based access control.
"""
import os
import json
import logging
import functools
from datetime import datetime, timedelta
from flask import Blueprint, render_template, request, redirect, url_for, flash, session, g
from werkzeug.security import generate_password_hash, check_password_hash
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize blueprint
auth_blueprint = Blueprint('auth', __name__)
# Mock user database (in a real application, this would be a database)
USERS_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'users.json')
os.makedirs(os.path.dirname(USERS_FILE), exist_ok=True)
def get_users():
"""Get users from JSON file."""
if os.path.exists(USERS_FILE):
try:
with open(USERS_FILE, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading users: {e}")
# Default users if file doesn't exist
return {
"doctor1": {
"username": "doctor1",
"password": generate_password_hash("password123"),
"name": "Dr. Zhang",
"role": "doctor",
"department": "Internal Medicine",
"title": "Chief Physician",
"email": "doctor1@example.com"
},
"doctor2": {
"username": "doctor2",
"password": generate_password_hash("password123"),
"name": "Dr. Li",
"role": "doctor",
"department": "Surgery",
"title": "Associate Chief Physician",
"email": "doctor2@example.com"
},
"patient1": {
"username": "patient1",
"password": generate_password_hash("password123"),
"name": "张三",
"role": "patient",
"age": 65,
"gender": "male",
"email": "patient1@example.com"
},
"patient2": {
"username": "patient2",
"password": generate_password_hash("password123"),
"name": "李四",
"role": "patient",
"age": 42,
"gender": "female",
"email": "patient2@example.com"
}
}
def save_users(users):
"""Save users to JSON file."""
try:
os.makedirs(os.path.dirname(USERS_FILE), exist_ok=True)
with open(USERS_FILE, 'w', encoding='utf-8') as f:
json.dump(users, f, ensure_ascii=False, indent=2)
return True
except Exception as e:
logger.error(f"Error saving users: {e}")
return False
# Initialize users file if it doesn't exist
if not os.path.exists(USERS_FILE):
save_users(get_users())
def login_required(view):
"""Decorator to require login for views."""
@functools.wraps(view)
def wrapped_view(**kwargs):
if 'user_id' not in session:
flash('Please log in first', 'error')
return redirect(url_for('auth.login'))
return view(**kwargs)
return wrapped_view
def is_doctor():
"""Check if current user is a doctor."""
if 'user_id' not in session:
return False
users = get_users()
user = users.get(session['user_id'])
if not user:
return False
return user.get('role') == 'doctor'
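# Example usage (illustrative): the two helpers above are combined in app.py to protect
# doctor-only views, e.g.
#
#   @app.route('/doctor/some_view')
#   @login_required
#   def some_view():
#       if not is_doctor():
#           flash('You do not have permission to view this page', 'error')
#           return redirect(url_for('dashboard'))
#       ...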
@auth_blueprint.before_app_request
def load_logged_in_user():
"""Load user data before each request."""
user_id = session.get('user_id')
if user_id is None:
g.user = None
else:
users = get_users()
g.user = users.get(user_id)
@auth_blueprint.route('/login', methods=['GET', 'POST'])
def login():
"""Handle user login."""
if request.method == 'POST':
username = request.form.get('username')
password = request.form.get('password')
users = get_users()
user = users.get(username)
error = None
if user is None:
error = 'Username does not exist'
elif not check_password_hash(user['password'], password):
error = 'Incorrect password'
if error is None:
# Clear session and set user_id
session.clear()
session['user_id'] = username
# Log login
logger.info(f"User {username} logged in")
# Redirect based on role
if user.get('role') == 'doctor':
return redirect(url_for('doctor_dashboard'))
else:
return redirect(url_for('patient_dashboard'))
flash(error, 'error')
return render_template('auth/login.html')
@auth_blueprint.route('/register', methods=['GET', 'POST'])
def register():
"""Handle user registration."""
if request.method == 'POST':
username = request.form.get('username')
password = request.form.get('password')
confirm_password = request.form.get('confirm_password')
name = request.form.get('name')
email = request.form.get('email')
role = request.form.get('role')
users = get_users()
error = None
if not username:
error = 'Username cannot be empty'
elif not password:
error = 'Password cannot be empty'
elif password != confirm_password:
error = 'The two passwords do not match'
elif username in users:
error = f'Username {username} is already registered'
if error is None:
# Create new user
new_user = {
'username': username,
'password': generate_password_hash(password),
'name': name,
'email': email,
'role': role,
'created_at': datetime.now().isoformat()
}
# Add role-specific fields
if role == 'doctor':
new_user['department'] = request.form.get('department', '')
new_user['title'] = request.form.get('title', '')
elif role == 'patient':
new_user['age'] = int(request.form.get('age', 0))
new_user['gender'] = request.form.get('gender', '')
# Save user
users[username] = new_user
if save_users(users):
flash('Registration successful, please log in', 'success')
return redirect(url_for('auth.login'))
else:
error = 'Registration failed, please try again later'
flash(error, 'error')
return render_template('auth/register.html')
@auth_blueprint.route('/logout')
def logout():
"""Handle user logout."""
session.clear()
flash('You have been logged out', 'success')
return redirect(url_for('index'))
@auth_blueprint.route('/profile')
@login_required
def profile():
"""View and edit user profile."""
return render_template('auth/profile.html')
@auth_blueprint.route('/profile/edit', methods=['GET', 'POST'])
@login_required
def edit_profile():
"""Edit user profile."""
if request.method == 'POST':
users = get_users()
user = users.get(session['user_id'])
if not user:
flash('User does not exist', 'error')
return redirect(url_for('auth.profile'))
# Update user information
user['name'] = request.form.get('name', user['name'])
user['email'] = request.form.get('email', user['email'])
# Update role-specific fields
if user.get('role') == 'doctor':
user['department'] = request.form.get('department', user.get('department', ''))
user['title'] = request.form.get('title', user.get('title', ''))
elif user.get('role') == 'patient':
user['age'] = int(request.form.get('age', user.get('age', 0)))
user['gender'] = request.form.get('gender', user.get('gender', ''))
# Check if password is being updated
new_password = request.form.get('new_password')
confirm_password = request.form.get('confirm_password')
if new_password:
if new_password != confirm_password:
flash('The two passwords do not match', 'error')
return redirect(url_for('auth.edit_profile'))
user['password'] = generate_password_hash(new_password)
# Save updated user information
if save_users(users):
flash('Profile updated', 'success')
else:
flash('Update failed, please try again later', 'error')
return redirect(url_for('auth.profile'))
return render_template('auth/edit_profile.html')
@auth_blueprint.route('/reset_password', methods=['GET', 'POST'])
def reset_password():
"""Handle password reset request."""
if request.method == 'POST':
username = request.form.get('username')
email = request.form.get('email')
users = get_users()
user = users.get(username)
if user and user.get('email') == email:
# In a real application, this would send a password reset email
# For demo, we'll just reset the password directly
new_password = 'resetpassword123'
user['password'] = generate_password_hash(new_password)
if save_users(users):
flash(f'Password has been reset to: {new_password}. Please change it immediately after logging in.', 'success')
return redirect(url_for('auth.login'))
else:
flash('Password reset failed, please try again later', 'error')
else:
flash('Incorrect username or email', 'error')
return render_template('auth/reset_password.html')