"""
Medical Image Classifier Module
This module provides classes and functions for classifying medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.
Key features:
- Traditional machine learning classifiers
- Deep learning models for medical image classification
- Model training, evaluation, and prediction
- Support for transfer learning with pre-trained models
"""
import os
import logging
import numpy as np
import json
import pickle
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
# Configure logging for the module: INFO level, timestamped records.
# basicConfig only takes effect if the root logger has no handlers yet.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger shared by every class below
logger = logging.getLogger(__name__)
class MedicalImageClassifier:
    """Base class for medical image classification

    The model family is chosen by ``config["model"]["type"]``:

    * ``"traditional"``: hand-crafted features (per-channel statistics,
      intensity histograms and a subsampled HOG descriptor) fed into a
      scikit-learn classifier (``"svm"``, ``"random_forest"`` or ``"knn"``).
    * any other value: a Keras backbone (``"resnet50"``, ``"vgg16"`` or
      ``"mobilenetv2"``) topped with a small dense classification head.

    Attributes:
        config: Effective configuration (user values overlaid on defaults).
        model: The fitted scikit-learn estimator / Keras model, or None.
        label_encoder: scikit-learn LabelEncoder fitted on ``classes``; None
            before training and after loading a deep learning model.
        classes: Sorted class labels seen during training.
        is_trained: True once a model has been trained or loaded.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image classifier

        Args:
            config_path: Optional path to a JSON configuration file. Values it
                provides override the defaults; keys or whole sections it
                omits keep their default values.
        """
        user_config: Dict[str, Any] = {}
        if config_path and os.path.isfile(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as e:
                logger.error(f"Error loading configuration: {str(e)}")
            else:
                if isinstance(loaded, dict):
                    user_config = loaded
                    logger.info(f"Loaded configuration from {config_path}")
                else:
                    logger.error(f"Error loading configuration: expected a JSON object in {config_path}")
        elif config_path:
            logger.warning(f"Configuration file not found: {config_path}; using defaults")

        # Overlay the user configuration on the defaults so that a partial
        # file (e.g. one without a "preprocessing" section) cannot cause
        # KeyErrors later on.
        self.config = self._merge_config(self._default_config(), user_config)

        self.model = None
        self.label_encoder = None
        self.classes: List[str] = []
        self.is_trained = False

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        """Return a fresh copy of the default configuration"""
        return {
            "model": {
                "type": "traditional",  # traditional or deep_learning
                "algorithm": "svm",  # svm, random_forest or knn (traditional only)
                "architecture": "resnet50",  # resnet50, vgg16 or mobilenetv2 (deep learning only)
                "pretrained": True,
                "batch_size": 32,
                "learning_rate": 0.001,
                "epochs": 10
            },
            "preprocessing": {
                "target_size": [224, 224],  # [height, width]
                "normalize": True,
                "augmentation": True  # NOTE(review): not read anywhere in this class
            }
        }

    @classmethod
    def _merge_config(cls, defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively overlay ``overrides`` on ``defaults``

        Nested dicts are merged key by key; any other value in ``overrides``
        replaces the default. Neither argument is modified.
        """
        merged = dict(defaults)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = cls._merge_config(merged[key], value)
            else:
                merged[key] = value
        return merged

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """Convert an image to uint8 on a 0-255 scale

        Images whose maximum is <= 1.0 are treated as normalized and scaled by
        255; anything else is assumed to already be on a 0-255 scale. Values
        are clipped, so out-of-range pixels saturate instead of wrapping
        around as a bare ``astype(np.uint8)`` would.
        """
        scale = 255.0 if image.size and np.max(image) <= 1.0 else 1.0
        return np.clip(image * scale, 0, 255).astype(np.uint8)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess an image for classification

        Resizes to ``config["preprocessing"]["target_size"]`` ([height, width]),
        brings the image to exactly three channels (grayscale and single-channel
        images are replicated, a fourth/alpha channel is dropped) and, when
        ``normalize`` is enabled, converts to float and scales 0-255 data to
        [0, 1].

        Args:
            image: Input image

        Returns:
            np.ndarray: Preprocessed image, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot preprocess empty image")
            return np.array([])

        # Resize to target size (config is [height, width]; cv2 wants (width, height))
        target_size = self.config["preprocessing"]["target_size"]
        target_h, target_w = int(target_size[0]), int(target_size[1])
        if image.shape[0] != target_h or image.shape[1] != target_w:
            image = cv2.resize(image, (target_w, target_h))

        # Bring every input to three channels so feature vectors and network
        # input shapes are consistent
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        if image.ndim == 2:
            # Replicating the channel gives the same result as COLOR_GRAY2RGB,
            # but also works for float64, which cv2.cvtColor rejects
            image = np.stack([image, image, image], axis=-1)
        elif image.ndim == 3 and image.shape[2] == 4:
            # Drop the alpha channel
            image = image[:, :, :3]

        # Normalize pixel values
        if self.config["preprocessing"]["normalize"]:
            if image.dtype != np.float32 and image.dtype != np.float64:
                image = image.astype(np.float32)
            # Heuristic: anything above 1.0 is assumed to be on a 0-255 scale
            if np.max(image) > 1.0:
                image = image / 255.0
        return image

    def extract_features(self, image: np.ndarray) -> np.ndarray:
        """
        Extract features from an image for traditional ML models

        The vector is: mean/std/min/max per channel, a 32-bin normalized
        histogram per channel, then every ``step``-th HOG value (about 100-200
        values). HOG is skipped, with a warning, if it cannot be computed.

        Args:
            image: Preprocessed image

        Returns:
            np.ndarray: 1-D float64 feature vector
        """
        features: List[float] = []

        # Basic statistical features, per channel for multi-channel images
        channels = ([image[:, :, c] for c in range(image.shape[2])]
                    if image.ndim == 3 else [image])
        for channel_data in channels:
            features.extend([
                np.mean(channel_data),
                np.std(channel_data),
                np.min(channel_data),
                np.max(channel_data)
            ])

        # Histogram features
        features.extend(self._extract_histogram_features(image))

        # HOG features (simplified)
        try:
            from skimage.feature import hog
            img_uint8 = self._to_uint8(image)
            if img_uint8.ndim == 3:
                img_uint8 = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
            hog_features = hog(img_uint8, orientations=9, pixels_per_cell=(8, 8),
                               cells_per_block=(2, 2), block_norm='L2-Hys', feature_vector=True)
            # Take a subset of HOG features to keep the feature vector manageable
            step = max(1, len(hog_features) // 100)
            features.extend(hog_features[::step])
        except Exception as e:
            logger.warning(f"Error extracting HOG features: {str(e)}")
        return np.asarray(features, dtype=np.float64)

    def _extract_histogram_features(self, image: np.ndarray) -> List[float]:
        """
        Extract histogram features from an image

        Args:
            image: Input image (normalized [0, 1] or 0-255 scale)

        Returns:
            List[float]: 32 normalized bin frequencies per channel
        """
        hist_features: List[float] = []
        img_uint8 = self._to_uint8(image)
        channel_count = img_uint8.shape[2] if img_uint8.ndim == 3 else 1
        for channel in range(channel_count):
            hist = cv2.calcHist([img_uint8], [channel], None, [32], [0, 256]).flatten()
            # Normalize to frequencies (total is the pixel count, so non-zero)
            hist_features.extend((hist / np.sum(hist)).tolist())
        return hist_features

    def train(self, images: List[np.ndarray], labels: List[str],
              validation_split: float = 0.2, shuffle: bool = True,
              random_state: Optional[int] = 42) -> Dict[str, Any]:
        """
        Train the classifier on a set of images

        Args:
            images: List of input images
            labels: List of corresponding labels
            validation_split: Fraction of data held out for validation, in
                [0, 1); 0 disables the hold-out set
            shuffle: Shuffle images and labels together before splitting.
                Without it the validation set is simply the tail of the input,
                which for class-sorted data holds only the last classes.
            random_state: Seed for the shuffle (reproducible splits)

        Returns:
            Dict[str, Any]: Training results; always has "success", plus
            "error" on failure
        """
        if len(images) != len(labels):
            logger.error("Number of images and labels must match")
            return {"success": False, "error": "Number of images and labels must match"}
        if len(images) == 0:
            logger.error("No training data provided")
            return {"success": False, "error": "No training data provided"}
        if not 0.0 <= validation_split < 1.0:
            message = f"validation_split must be in [0, 1), got {validation_split}"
            logger.error(message)
            return {"success": False, "error": message}

        labels = list(labels)
        self.classes = sorted(set(labels))
        logger.info(f"Training classifier with {len(self.classes)} classes: {self.classes}")

        processed_images = [self.preprocess_image(image) for image in images]

        if shuffle:
            order = np.random.default_rng(random_state).permutation(len(processed_images))
            processed_images = [processed_images[i] for i in order]
            labels = [labels[i] for i in order]

        # Split into training and validation sets (validation_split == 0 -> no validation)
        split_idx = int(len(processed_images) * (1 - validation_split))
        if split_idx == 0:
            message = "validation_split leaves no images for training"
            logger.error(message)
            return {"success": False, "error": message}
        train_images, val_images = processed_images[:split_idx], processed_images[split_idx:]
        train_labels, val_labels = labels[:split_idx], labels[split_idx:]

        if self.config["model"]["type"] == "traditional":
            return self._train_traditional(train_images, train_labels, val_images, val_labels)
        return self._train_deep_learning(train_images, train_labels, val_images, val_labels)

    def _train_traditional(self, train_images: List[np.ndarray], train_labels: List[str],
                           val_images: List[np.ndarray], val_labels: List[str]) -> Dict[str, Any]:
        """
        Train a traditional machine learning model

        Args:
            train_images: Training images
            train_labels: Training labels
            val_images: Validation images
            val_labels: Validation labels

        Returns:
            Dict[str, Any]: Training results
        """
        try:
            from sklearn.preprocessing import LabelEncoder
            from sklearn.svm import SVC
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.neighbors import KNeighborsClassifier
            from sklearn.metrics import accuracy_score, classification_report

            logger.info("Extracting features from training images...")
            train_features = np.array([self.extract_features(img) for img in train_images])

            # Fitted on the sorted self.classes, so encoded label i == self.classes[i]
            label_encoder = LabelEncoder()
            label_encoder.fit(self.classes)
            y_train = label_encoder.transform(train_labels)

            algorithm = self.config["model"]["algorithm"]
            if algorithm == "svm":
                model = SVC(probability=True)
            elif algorithm == "random_forest":
                model = RandomForestClassifier(n_estimators=100)
            elif algorithm == "knn":
                # k may not exceed the number of training samples
                model = KNeighborsClassifier(n_neighbors=min(5, len(train_features)))
            else:
                logger.error(f"Unsupported algorithm: {algorithm}")
                return {"success": False, "error": f"Unsupported algorithm: {algorithm}"}

            logger.info(f"Training {algorithm} classifier...")
            model.fit(train_features, y_train)
            self.model = model
            self.label_encoder = label_encoder

            train_accuracy = float(accuracy_score(y_train, model.predict(train_features)))
            results = {
                "success": True,
                "model_type": "traditional",
                "algorithm": algorithm,
                "train_accuracy": train_accuracy,
                "classes": self.classes
            }

            if len(val_images) > 0:
                logger.info("Evaluating on validation set...")
                val_features = np.array([self.extract_features(img) for img in val_images])
                y_val = label_encoder.transform(val_labels)
                y_val_pred = model.predict(val_features)
                results["val_accuracy"] = float(accuracy_score(y_val, y_val_pred))
                # Pass the full label set: a small validation split may not
                # contain every class, and classification_report raises when
                # target_names and the observed labels differ in length.
                results["classification_report"] = classification_report(
                    y_val, y_val_pred,
                    labels=list(range(len(self.classes))),
                    target_names=self.classes,
                    output_dict=True,
                    zero_division=0,
                )

            self.is_trained = True
            logger.info(f"Model training completed with train accuracy: {train_accuracy:.4f}")
            return results
        except Exception as e:
            logger.error(f"Error training traditional model: {str(e)}")
            return {"success": False, "error": str(e)}

    def _train_deep_learning(self, train_images: List[np.ndarray], train_labels: List[str],
                             val_images: List[np.ndarray], val_labels: List[str]) -> Dict[str, Any]:
        """
        Train a deep learning model

        Builds the configured Keras backbone (ImageNet weights and frozen
        layers when ``pretrained`` is set), adds a 256-unit dense layer and a
        softmax output, and fits with Adam / categorical cross-entropy.

        Args:
            train_images: Training images
            train_labels: Training labels
            val_images: Validation images
            val_labels: Validation labels

        Returns:
            Dict[str, Any]: Training results, including per-epoch history
        """
        try:
            from tensorflow.keras.applications import ResNet50, VGG16, MobileNetV2
            from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
            from tensorflow.keras.models import Model
            from tensorflow.keras.optimizers import Adam
            from tensorflow.keras.utils import to_categorical
            from sklearn.preprocessing import LabelEncoder

            model_config = self.config["model"]
            num_classes = len(self.classes)
            has_validation = len(val_images) > 0

            # One-hot encode labels
            label_encoder = LabelEncoder()
            label_encoder.fit(self.classes)
            y_train_cat = to_categorical(label_encoder.transform(train_labels), num_classes=num_classes)
            X_train = np.asarray(train_images, dtype=np.float32)
            if has_validation:
                y_val_cat = to_categorical(label_encoder.transform(val_labels), num_classes=num_classes)
                X_val = np.asarray(val_images, dtype=np.float32)

            architecture = model_config["architecture"]
            backbones = {"resnet50": ResNet50, "vgg16": VGG16, "mobilenetv2": MobileNetV2}
            if architecture not in backbones:
                logger.error(f"Unsupported architecture: {architecture}")
                return {"success": False, "error": f"Unsupported architecture: {architecture}"}

            input_shape = (*self.config["preprocessing"]["target_size"], 3)
            # NOTE(review): ImageNet weights were trained with each backbone's
            # own preprocess_input scaling, whereas inputs here are whatever
            # preprocess_image produced (typically [0, 1]) — confirm intended.
            base_model = backbones[architecture](
                weights='imagenet' if model_config["pretrained"] else None,
                include_top=False, input_shape=input_shape)

            # Add classification head
            x = GlobalAveragePooling2D()(base_model.output)
            x = Dense(256, activation='relu')(x)
            predictions = Dense(num_classes, activation='softmax')(x)
            model = Model(inputs=base_model.input, outputs=predictions)

            # Freeze base model layers if using pretrained weights (must happen before compile)
            if model_config["pretrained"]:
                for layer in base_model.layers:
                    layer.trainable = False

            model.compile(optimizer=Adam(learning_rate=model_config["learning_rate"]),
                          loss='categorical_crossentropy',
                          metrics=['accuracy'])

            logger.info(f"Training {architecture} model...")
            if has_validation:
                history = model.fit(
                    X_train, y_train_cat,
                    epochs=model_config["epochs"],
                    batch_size=model_config["batch_size"],
                    validation_data=(X_val, y_val_cat),
                    verbose=1
                )
                val_loss, val_accuracy = model.evaluate(X_val, y_val_cat)
            else:
                # No explicit validation set: Keras holds out the last 10% of
                # the training arrays for per-epoch monitoring only
                history = model.fit(
                    X_train, y_train_cat,
                    epochs=model_config["epochs"],
                    batch_size=model_config["batch_size"],
                    validation_split=0.1,
                    verbose=1
                )
                val_loss, val_accuracy = None, None

            self.model = model
            self.label_encoder = label_encoder
            self.is_trained = True

            train_loss, train_accuracy = model.evaluate(X_train, y_train_cat)
            results = {
                "success": True,
                "model_type": "deep_learning",
                "architecture": architecture,
                "train_accuracy": float(train_accuracy),
                "train_loss": float(train_loss),
                "classes": self.classes
            }
            if val_loss is not None and val_accuracy is not None:
                results["val_accuracy"] = float(val_accuracy)
                results["val_loss"] = float(val_loss)

            # Add training history
            results["history"] = {
                "accuracy": [float(acc) for acc in history.history['accuracy']],
                "loss": [float(loss) for loss in history.history['loss']]
            }
            if 'val_accuracy' in history.history:
                results["history"]["val_accuracy"] = [float(acc) for acc in history.history['val_accuracy']]
                results["history"]["val_loss"] = [float(loss) for loss in history.history['val_loss']]

            logger.info(f"Model training completed with train accuracy: {train_accuracy:.4f}")
            return results
        except Exception as e:
            logger.error(f"Error training deep learning model: {str(e)}")
            return {"success": False, "error": str(e)}

    def predict(self, image: np.ndarray) -> Dict[str, Any]:
        """
        Predict the class of an image

        Args:
            image: Input image

        Returns:
            Dict[str, Any]: On success, "predicted_class", "confidence" and
            "probabilities" (an entry for every known class; classes the
            model never saw during training get 0.0). On failure, "error".
        """
        if not self.is_trained or self.model is None:
            logger.error("Model not trained")
            return {"success": False, "error": "Model not trained"}
        try:
            processed_image = self.preprocess_image(image)

            if self.config["model"]["type"] == "traditional":
                features = self.extract_features(processed_image).reshape(1, -1)
                pred_proba = np.asarray(self.model.predict_proba(features)[0])
                # predict_proba columns follow model.classes_ (the encoded
                # labels present in the training data), which can be a subset
                # of self.classes; the encoder was fitted on self.classes, so
                # encoded label i maps back to self.classes[i].
                encoded = getattr(self.model, "classes_", range(len(pred_proba)))
                column_classes = [self.classes[int(code)] for code in encoded]
            else:  # deep_learning: one softmax output per class
                input_image = np.expand_dims(processed_image, axis=0)
                pred_proba = np.asarray(self.model.predict(input_image)[0])
                column_classes = list(self.classes)

            class_probabilities = {name: 0.0 for name in self.classes}
            class_probabilities.update(
                {name: float(prob) for name, prob in zip(column_classes, pred_proba)})
            best = int(np.argmax(pred_proba))
            return {
                "success": True,
                "predicted_class": column_classes[best],
                "confidence": float(pred_proba[best]),
                "probabilities": class_probabilities
            }
        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            return {"success": False, "error": str(e)}

    def save_model(self, model_path: str) -> bool:
        """
        Save the trained model to a file

        Writes ``<model_path without extension>_metadata.json`` (config and
        classes) plus the model itself: a pickle of the estimator and label
        encoder for traditional models, or Keras ``model.save`` otherwise.

        Args:
            model_path: Path to save the model

        Returns:
            bool: True if successful, False otherwise
        """
        if not self.is_trained or self.model is None:
            logger.error("Cannot save untrained model")
            return False
        try:
            # dirname is "" for a bare filename; os.makedirs("") would raise
            directory = os.path.dirname(model_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            metadata = {
                "config": self.config,
                "classes": list(self.classes),
                "model_type": self.config["model"]["type"]
            }
            metadata_path = os.path.splitext(model_path)[0] + "_metadata.json"
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            if self.config["model"]["type"] == "traditional":
                with open(model_path, 'wb') as f:
                    pickle.dump({
                        "model": self.model,
                        "label_encoder": self.label_encoder
                    }, f)
            else:
                self.model.save(model_path)

            logger.info(f"Model saved to {model_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving model: {str(e)}")
            return False

    def load_model(self, model_path: str) -> bool:
        """
        Load a trained model from a file

        The instance is only updated once everything has loaded, so a failed
        load leaves the previous state intact.

        Args:
            model_path: Path to the saved model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            metadata_path = os.path.splitext(model_path)[0] + "_metadata.json"
            if not os.path.isfile(metadata_path):
                logger.error(f"Model metadata not found: {metadata_path}")
                return False
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            if metadata["model_type"] == "traditional":
                # SECURITY: pickle.load can execute arbitrary code embedded in
                # the file. Only load model files from trusted sources.
                with open(model_path, 'rb') as f:
                    model_data = pickle.load(f)
                model, label_encoder = model_data["model"], model_data["label_encoder"]
            else:
                import tensorflow as tf
                model, label_encoder = tf.keras.models.load_model(model_path), None

            # Overlay on defaults so metadata written with a partial config still works
            self.config = self._merge_config(self._default_config(), metadata["config"])
            self.classes = list(metadata["classes"])
            self.model = model
            self.label_encoder = label_encoder
            self.is_trained = True
            logger.info(f"Model loaded from {model_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False

    def visualize_predictions(self, image: np.ndarray, save_path: Optional[str] = None) -> None:
        """
        Visualize the prediction results

        Shows the preprocessed image next to a bar chart of class
        probabilities (most likely class on top). The figure is saved to
        ``save_path`` if given, otherwise displayed with ``plt.show()``; it is
        always closed afterwards.

        Args:
            image: Input image
            save_path: Optional path to save the visualization
        """
        if not self.is_trained or self.model is None:
            logger.error("Model not trained")
            return

        prediction = self.predict(image)
        if not prediction["success"]:
            logger.error(f"Prediction failed: {prediction.get('error', 'Unknown error')}")
            return

        processed_image = self.preprocess_image(image)

        fig = plt.figure(figsize=(10, 6))
        try:
            # Display image
            plt.subplot(1, 2, 1)
            plt.title("Input Image")
            plt.imshow(processed_image, cmap='gray' if processed_image.ndim == 2 else None)
            plt.axis('off')

            # Display prediction probabilities, sorted from most to least likely
            plt.subplot(1, 2, 2)
            plt.title(f"Prediction: {prediction['predicted_class']}\nConfidence: {prediction['confidence']:.2f}")
            ranked = sorted(prediction["probabilities"].items(), key=lambda item: item[1], reverse=True)
            plt.barh(range(len(ranked)), [prob for _, prob in ranked],
                     tick_label=[name for name, _ in ranked])
            # barh draws index 0 at the bottom; flip so the top class is on top
            plt.gca().invert_yaxis()
            plt.xlim(0, 1)
            plt.xlabel('Probability')
            plt.tight_layout()

            if save_path:
                try:
                    directory = os.path.dirname(save_path)
                    if directory:
                        os.makedirs(directory, exist_ok=True)
                    plt.savefig(save_path, dpi=300, bbox_inches='tight')
                    logger.info(f"Visualization saved to {save_path}")
                except Exception as e:
                    logger.error(f"Error saving visualization: {str(e)}")
            else:
                plt.show()
        finally:
            plt.close(fig)
class ChestXRayClassifier(MedicalImageClassifier):
    """Specialized classifier for chest X-ray images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the chest X-ray classifier

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)
        # X-ray specific defaults; values from the config file take precedence
        xray_config = self.config.setdefault("xray", {})
        xray_config.setdefault("lung_segmentation", True)
        # NOTE(review): informational only; not read by this class
        xray_config.setdefault("common_conditions", [
            "normal", "pneumonia", "tuberculosis", "covid-19", "lung cancer"
        ])

    @staticmethod
    def _to_gray_uint8(image: np.ndarray) -> np.ndarray:
        """Return a single-channel uint8 (0-255) version of ``image``

        Multi-channel input is converted with cv2.COLOR_RGB2GRAY (float64 is
        first cast to float32, which cvtColor requires). Images with a maximum
        <= 1.0 are scaled by 255; values are clipped so they saturate rather
        than wrap around.
        """
        if image.ndim == 3 and image.shape[2] == 1:
            gray = image[:, :, 0]
        elif image.ndim == 3:
            source = image.astype(np.float32) if image.dtype == np.float64 else image
            gray = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        scale = 255.0 if np.max(gray) <= 1.0 else 1.0
        return np.clip(gray * scale, 0, 255).astype(np.uint8)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess a chest X-ray image

        Applies the base preprocessing, then (when ``xray.lung_segmentation``
        is enabled) CLAHE contrast enhancement on the grayscale image. No
        actual lung segmentation is performed.

        Args:
            image: Input chest X-ray image

        Returns:
            np.ndarray: Preprocessed image. With enhancement enabled it is
            float32 in [0, 1], RGB if the base output was RGB, else grayscale.
            Empty input yields an empty array.
        """
        image = super().preprocess_image(image)
        if image.size == 0:
            # The base class has already logged the problem
            return image
        if not self.config["xray"].get("lung_segmentation", False):
            return image

        # This would ideally use a lung segmentation model; for simplicity
        # only a contrast enhancement is applied
        is_rgb = image.ndim == 3 and image.shape[2] == 3
        gray = self._to_gray_uint8(image)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)
        if is_rgb:
            enhanced = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)
        return enhanced.astype(np.float32) / 255.0
class BrainMRIClassifier(MedicalImageClassifier):
    """Specialized classifier for brain MRI images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the brain MRI classifier

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)
        # MRI specific defaults; values from the config file take precedence
        mri_config = self.config.setdefault("mri", {})
        mri_config.setdefault("brain_extraction", True)
        # NOTE(review): informational only; not read by this class
        mri_config.setdefault("common_conditions", [
            "normal", "tumor", "alzheimer", "multiple sclerosis", "stroke"
        ])

    @staticmethod
    def _to_gray_uint8(image: np.ndarray) -> np.ndarray:
        """Return a single-channel uint8 (0-255) version of ``image``

        Multi-channel input is converted with cv2.COLOR_RGB2GRAY (float64 is
        first cast to float32, which cvtColor requires). Images with a maximum
        <= 1.0 are scaled by 255; values are clipped so they saturate rather
        than wrap around.
        """
        if image.ndim == 3 and image.shape[2] == 1:
            gray = image[:, :, 0]
        elif image.ndim == 3:
            source = image.astype(np.float32) if image.dtype == np.float64 else image
            gray = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)
        else:
            gray = image
        scale = 255.0 if np.max(gray) <= 1.0 else 1.0
        return np.clip(gray * scale, 0, 255).astype(np.uint8)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess a brain MRI image

        Applies the base preprocessing, then (when ``mri.brain_extraction`` is
        enabled) a rough foreground mask: Otsu threshold, morphological close
        then open with a 5x5 kernel, and zeroing of pixels outside the mask.

        Args:
            image: Input brain MRI image

        Returns:
            np.ndarray: Preprocessed image. With extraction enabled it is
            float32 in [0, 1], RGB if the base output was RGB, else grayscale.
            Empty input yields an empty array.
        """
        image = super().preprocess_image(image)
        if image.size == 0:
            # The base class has already logged the problem
            return image
        if not self.config["mri"].get("brain_extraction", False):
            return image

        # This would ideally use a brain extraction model; for simplicity a
        # threshold-based mask is used
        is_rgb = image.ndim == 3 and image.shape[2] == 3
        gray = self._to_gray_uint8(image)

        # Otsu chooses the threshold itself; the 0 passed in is ignored
        _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Closing fills small holes in the foreground, opening removes small specks
        kernel = np.ones((5, 5), np.uint8)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)

        masked = cv2.bitwise_and(gray, gray, mask=binary)
        if is_rgb:
            masked = cv2.cvtColor(masked, cv2.COLOR_GRAY2RGB)
        return masked.astype(np.float32) / 255.0
if __name__ == "__main__":
    # Example usage: build a classifier with the default configuration
    demo_classifier = MedicalImageClassifier()
    print("Medical Image Classifier initialized")
# File: modules\image_analysis\image_processor.py
"""
Medical Image Processor Module
This module provides classes and functions for processing and analyzing medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.
Key features:
- Image loading and standardization
- Image enhancement and filtering
- Feature extraction
- Segmentation
- Registration
"""
import os
import logging
import numpy as np
import cv2
import pydicom
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
from skimage import measure, morphology, segmentation, filters, feature
# Configure logging for the module: INFO level, timestamped records.
# basicConfig only takes effect if the root logger has no handlers yet.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module logger used by MedicalImageProcessor
logger = logging.getLogger(__name__)
class MedicalImageProcessor:
"""Base class for medical image processing"""
def __init__(self, config_path: Optional[str] = None):
    """
    Initialize the medical image processor

    Args:
        config_path: Optional path to a JSON configuration file. Values it
            provides override the defaults below; keys or whole sections it
            omits keep their default values.
    """
    import json

    defaults: Dict[str, Any] = {
        "preprocessing": {
            "normalize": True,  # NOTE(review): not read by preprocess(), which always normalizes
            "clahe": True,
            "clahe_clip_limit": 2.0,
            "clahe_grid_size": [8, 8],
            "denoise": True,
            "target_size": [512, 512]  # [height, width]
        },
        "segmentation": {
            "threshold_method": "otsu",  # otsu, adaptive, or anything else for a fixed 0.5 threshold
            "morphology_operations": True,
            "remove_small_objects": True,
            "min_size": 100
        }
    }

    user_config: Dict[str, Any] = {}
    if config_path and os.path.isfile(config_path):
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            logger.error(f"Error loading configuration: {str(e)}")
        else:
            if isinstance(loaded, dict):
                user_config = loaded
                logger.info(f"Loaded configuration from {config_path}")
            else:
                logger.error(f"Error loading configuration: expected a JSON object in {config_path}")
    elif config_path:
        logger.warning(f"Configuration file not found: {config_path}; using defaults")

    def merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively overlay *overrides* on *base* without mutating either."""
        merged = dict(base)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = merge(merged[key], value)
            else:
                merged[key] = value
        return merged

    # A partial config file must not drop whole sections, which other
    # methods index directly
    self.config = merge(defaults, user_config)
def load_image(self, image_path: str) -> Optional[np.ndarray]:
    """
    Load an image from file

    DICOM files (``.dcm`` / ``.dicom``) go through ``_load_dicom``. Other
    formats are decoded by OpenCV as 8-bit: single-channel files are returned
    as a 2-D array (so the grayscale-only steps in ``preprocess`` apply to
    them), colour files as RGB (converted from OpenCV's native BGR).

    Args:
        image_path: Path to the image file

    Returns:
        np.ndarray: Loaded image or None if loading fails
    """
    if not os.path.isfile(image_path):
        logger.error(f"Image file not found: {image_path}")
        return None
    try:
        if image_path.lower().endswith(('.dcm', '.dicom')):
            return self._load_dicom(image_path)

        # IMREAD_ANYCOLOR keeps grayscale files single-channel instead of
        # expanding them to three identical channels like IMREAD_COLOR does
        image = cv2.imread(image_path, cv2.IMREAD_ANYCOLOR)
        if image is None:
            logger.error(f"Failed to load image: {image_path}")
            return None

        if image.ndim == 3 and image.shape[2] == 3:
            # OpenCV decodes colour as BGR; the rest of this class (RGB2GRAY
            # conversions, matplotlib display) expects RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        return image
    except Exception as e:
        logger.error(f"Error loading image {image_path}: {str(e)}")
        return None
def _load_dicom(self, dicom_path: str) -> Optional[np.ndarray]:
    """
    Load a DICOM image

    Steps:
    1. Apply the modality LUT (RescaleSlope / RescaleIntercept). This maps
       stored values to output units, e.g. Hounsfield units for CT.
    2. Window with the first WindowCenter/WindowWidth pair, or min-max
       normalize when no window is given.
    3. Invert MONOCHROME1 images so that higher values are always brighter.

    Args:
        dicom_path: Path to the DICOM file

    Returns:
        np.ndarray: float32 image in [0, 1], or None if loading fails
    """
    def as_float(value: Any, default: Optional[float]) -> Optional[float]:
        """Read a (possibly multi-valued or empty) DICOM numeric attribute."""
        if isinstance(value, pydicom.multival.MultiValue):
            value = value[0] if len(value) else None
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    try:
        dicom_data = pydicom.dcmread(dicom_path)
        image = dicom_data.pixel_array.astype(np.float32)

        # Modality LUT first: window center/width are expressed in the
        # rescaled units, not in raw stored values
        slope = as_float(getattr(dicom_data, 'RescaleSlope', None), 1.0)
        intercept = as_float(getattr(dicom_data, 'RescaleIntercept', None), 0.0)
        if slope != 1.0 or intercept != 0.0:
            image = image * slope + intercept

        center = as_float(getattr(dicom_data, 'WindowCenter', None), None)
        width = as_float(getattr(dicom_data, 'WindowWidth', None), None)
        if center is not None and width is not None:
            image = self._apply_window_level(image, center, width)
        else:
            # Simple normalization if window settings are not available
            min_val = float(np.min(image))
            max_val = float(np.max(image))
            if max_val > min_val:
                image = (image - min_val) / (max_val - min_val)
            else:
                # Constant image: map to 0 rather than returning raw values
                image = np.zeros_like(image)

        # MONOCHROME1 displays the minimum value as white; invert to match
        # the MONOCHROME2 convention used everywhere else
        if getattr(dicom_data, 'PhotometricInterpretation', '') == 'MONOCHROME1':
            image = 1.0 - image

        return image
    except Exception as e:
        logger.error(f"Error loading DICOM {dicom_path}: {str(e)}")
        return None
def _apply_window_level(self, image: np.ndarray, center: float, width: float) -> np.ndarray:
    """
    Apply window center and width to the image

    Values in [center - width/2, center + width/2] are mapped linearly to
    [0, 1]; values outside are clipped to the nearest end.

    Args:
        image: Input image
        center: Window center
        width: Window width. DICOM requires width >= 1; a non-positive width
            is treated as a hard threshold at ``center``, which avoids
            dividing by zero.

    Returns:
        np.ndarray: Windowed image in [0, 1]
    """
    center = float(center)
    width = float(width)
    if width <= 0:
        return (np.asarray(image) >= center).astype(np.float32)

    min_val = center - width / 2
    max_val = center + width / 2
    image = np.clip(image, min_val, max_val)
    return (image - min_val) / (max_val - min_val)
def preprocess(self, image: np.ndarray) -> np.ndarray:
    """
    Preprocess the medical image

    Steps:
    1. Convert to float; anything with values above 1.0 is divided by 255.
    2. Resize to ``preprocessing.target_size`` ([height, width]) if set.
    3. For 2-D (grayscale) images only, optionally apply CLAHE and
       non-local-means denoising. Colour images skip this step.

    Args:
        image: Input image

    Returns:
        np.ndarray: Preprocessed float image, or an empty array for
        None/empty input
    """
    if image is None or image.size == 0:
        logger.error("Cannot preprocess None image" if image is None else "Cannot preprocess empty image")
        return np.array([])

    prep_config = self.config.get("preprocessing", {})

    # Convert to float and ensure range [0, 1] (heuristic: >1.0 means 0-255 data)
    if image.dtype != np.float32 and image.dtype != np.float64:
        image = image.astype(np.float32)
    if np.max(image) > 1.0:
        image = image / 255.0

    # Resize if needed (cv2 takes (width, height))
    target_size = prep_config.get("target_size")
    if target_size:
        if image.shape[0] != target_size[0] or image.shape[1] != target_size[1]:
            image = cv2.resize(image, (target_size[1], target_size[0]))

    if image.ndim == 2:
        # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization)
        if prep_config.get("clahe", False):
            clip_limit = prep_config.get("clahe_clip_limit", 2.0)
            grid_size = tuple(prep_config.get("clahe_grid_size", [8, 8]))
            # CLAHE needs uint8; clip so out-of-range values saturate instead of wrapping
            img_uint8 = np.clip(image * 255, 0, 255).astype(np.uint8)
            clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
            image = clahe.apply(img_uint8).astype(np.float32) / 255.0

        # Apply denoising (h=10, templateWindowSize=7, searchWindowSize=21)
        if prep_config.get("denoise", False):
            img_uint8 = np.clip(image * 255, 0, 255).astype(np.uint8)
            img_uint8 = cv2.fastNlMeansDenoising(img_uint8, None, 10, 7, 21)
            image = img_uint8.astype(np.float32) / 255.0

    return image
def segment(self, image: np.ndarray) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
    """
    Segment the medical image

    Steps:
    1. Threshold the grayscale image: Otsu, adaptive Gaussian, or a fixed
       0.5 cut-off for any other ``threshold_method``.
    2. Optionally clean the mask with binary opening/closing.
    3. Optionally remove objects smaller than ``min_size``.
    4. Label the connected components.

    Args:
        image: Preprocessed input image (2-D, or 3-D with 1, 3 or 4 channels)

    Returns:
        Tuple[np.ndarray, List[Dict]]: Labeled mask (0 = background) and one
        property dict per labeled region
    """
    if image is None or image.size == 0:
        logger.error("Cannot segment empty image")
        return np.array([]), []

    seg_config = self.config.get("segmentation", {})

    # Ensure image is grayscale
    if image.ndim == 3:
        if image.shape[2] == 1:
            image = image[:, :, 0]
        else:
            if image.dtype == np.float64:
                # cv2.cvtColor does not accept float64 input
                image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Apply thresholding
    threshold_method = seg_config.get("threshold_method", "otsu")
    if threshold_method == "otsu":
        binary = image > filters.threshold_otsu(image)
    elif threshold_method == "adaptive":
        # adaptiveThreshold needs uint8: scale [0, 1] images, and clip so
        # out-of-range values saturate instead of wrapping around
        scale = 255.0 if np.max(image) <= 1.0 else 1.0
        img_uint8 = np.clip(image * scale, 0, 255).astype(np.uint8)
        binary = cv2.adaptiveThreshold(
            img_uint8, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, 11, 2
        ) > 0
    else:
        # Fixed threshold; assumes a normalized [0, 1] image
        binary = image > 0.5

    # Opening removes small specks, closing fills small gaps
    if seg_config.get("morphology_operations", True):
        binary = morphology.binary_opening(binary)
        binary = morphology.binary_closing(binary)

    if seg_config.get("remove_small_objects", True):
        binary = morphology.remove_small_objects(binary, min_size=seg_config.get("min_size", 100))

    # Label connected regions and extract their properties
    labeled_mask = measure.label(binary)
    regions = [
        {
            "label": prop.label,
            "area": prop.area,
            "centroid": prop.centroid,
            "bbox": prop.bbox,
            "eccentricity": prop.eccentricity,
            "equivalent_diameter": prop.equivalent_diameter,
            "orientation": prop.orientation,
            "perimeter": prop.perimeter
        }
        for prop in measure.regionprops(labeled_mask)
    ]
    return labeled_mask, regions
def extract_features(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
    """
    Extract features from the medical image

    Computes three groups of features:
    - First-order intensity statistics, restricted to the mask when one is
      given.
    - GLCM texture properties, always over the whole image.
    - Edge and gradient statistics, always over the whole image.

    Args:
        image: Preprocessed input image (2-D, or 3-D with 1, 3 or 4 channels)
        mask: Optional segmentation mask with the image's height/width; pixels
            where mask > 0 are used for the intensity statistics

    Returns:
        Dict[str, Any]: Feature name -> float. Intensity statistics are
        omitted if the mask selects no pixels. Texture and edge features are
        omitted (with a logged warning) if their computation fails.
    """
    if image is None or image.size == 0:
        logger.error("Cannot extract features from empty image")
        return {}

    features: Dict[str, Any] = {}

    # Ensure image is grayscale
    if image.ndim == 3 and image.shape[2] == 1:
        gray_image = image[:, :, 0]
    elif image.ndim == 3:
        # cv2.cvtColor does not accept float64 input
        source = image.astype(np.float32) if image.dtype == np.float64 else image
        gray_image = cv2.cvtColor(source, cv2.COLOR_RGB2GRAY)
    else:
        gray_image = image

    # Basic statistical features, on a flat array so the moment helpers see
    # every pixel as one sample
    values = gray_image[mask > 0] if mask is not None else gray_image.ravel()
    if values.size > 0:
        features["mean"] = float(np.mean(values))
        features["std"] = float(np.std(values))
        features["min"] = float(np.min(values))
        features["max"] = float(np.max(values))
        features["median"] = float(np.median(values))
        features["skewness"] = float(self._calculate_skewness(values))
        features["kurtosis"] = float(self._calculate_kurtosis(values))

    # Texture features (GLCM)
    try:
        from skimage.feature import graycomatrix, graycoprops
        # GLCM needs uint8: scale [0, 1] images, clip so values cannot wrap
        scale = 255.0 if np.max(gray_image) <= 1.0 else 1.0
        img_uint8 = np.clip(gray_image * scale, 0, 255).astype(np.uint8)
        distances = [1]
        angles = [0, np.pi / 4, np.pi / 2, 3 * np.pi / 4]
        glcm = graycomatrix(img_uint8, distances, angles, levels=256, symmetric=True, normed=True)
        # Average each property over the four angles at distance 1
        for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation']:
            features[f"glcm_{prop}"] = float(np.mean(graycoprops(glcm, prop)[0]))
    except Exception as e:
        logger.warning(f"Error calculating GLCM features: {str(e)}")

    # Edge features
    try:
        edges = feature.canny(gray_image)
        features["edge_ratio"] = float(np.sum(edges) / edges.size)
        sobel = filters.sobel(gray_image)
        features["gradient_mean"] = float(np.mean(sobel))
        features["gradient_std"] = float(np.std(sobel))
    except Exception as e:
        logger.warning(f"Error calculating edge features: {str(e)}")

    return features
def _calculate_skewness(self, data: np.ndarray) -> float:
    """
    Calculate the skewness of the data

    Population (biased) skewness ``E[(x - mean)^3] / std^3`` over all values.
    The input is flattened, so a 2-D image counts every pixel as a sample
    (``len()`` of a 2-D array would count rows only).

    Args:
        data: Input data of any shape

    Returns:
        float: Skewness value; 0.0 for empty or constant data
    """
    values = np.asarray(data, dtype=np.float64).ravel()
    if values.size == 0:
        return 0.0
    deviations = values - values.mean()
    std = np.sqrt(np.mean(deviations ** 2))
    if std == 0:
        return 0.0
    return float(np.mean(deviations ** 3) / std ** 3)
def _calculate_kurtosis(self, data: np.ndarray) -> float:
    """
    Calculate the kurtosis of the data

    Population excess kurtosis ``E[(x - mean)^4] / std^4 - 3`` over all
    values. The input is flattened, so a 2-D image counts every pixel as a
    sample (``len()`` of a 2-D array would count rows only).

    Args:
        data: Input data of any shape

    Returns:
        float: Excess kurtosis value; 0.0 for empty or constant data
    """
    values = np.asarray(data, dtype=np.float64).ravel()
    if values.size == 0:
        return 0.0
    deviations = values - values.mean()
    variance = np.mean(deviations ** 2)
    if variance == 0:
        return 0.0
    return float(np.mean(deviations ** 4) / variance ** 2 - 3)
def visualize(self, image: np.ndarray, mask: Optional[np.ndarray] = None,
              regions: Optional[List[Dict[str, Any]]] = None,
              save_path: Optional[str] = None) -> None:
    """
    Visualize the image, segmentation mask, and regions
    Draws up to three panels: the original image, the mask (if given), and
    the image with mask boundaries plus region centroids/labels (if both
    mask and regions are given). The figure is written to disk only when
    save_path is set, and it is always closed afterwards.
    Args:
        image: Input image
        mask: Optional segmentation mask
        regions: Optional list of region properties; each entry must provide
            "centroid" as (row, col) and "label"
        save_path: Optional path to save the visualization; a bare filename
            is saved into the current working directory
    """
    if image is None or image.size == 0:
        logger.error("Cannot visualize empty image")
        return
    fig = plt.figure(figsize=(15, 10))
    try:
        # Plot original image
        plt.subplot(1, 3, 1)
        plt.title("Original Image")
        plt.imshow(image, cmap='gray')
        plt.axis('off')
        # Plot segmentation mask if available
        if mask is not None:
            plt.subplot(1, 3, 2)
            plt.title("Segmentation Mask")
            plt.imshow(mask, cmap='viridis')
            plt.axis('off')
        # Plot image with region boundaries if available
        if regions is not None and mask is not None:
            plt.subplot(1, 3, 3)
            plt.title("Regions")
            # mark_boundaries expects an RGB image
            if len(image.shape) == 2:
                rgb_img = np.stack([image] * 3, axis=2)
            else:
                rgb_img = image.copy()
            # Treat images with values above 1 as 8-bit and scale to [0, 1]
            if rgb_img.max() > 1.0:
                rgb_img = rgb_img / 255.0
            from skimage.segmentation import mark_boundaries
            boundary_img = mark_boundaries(rgb_img, mask, color=(1, 0, 0), mode='thick')
            # Draw centroids and labels; centroids are (row, col) = (y, x)
            for region in regions:
                y, x = region["centroid"]
                plt.plot(x, y, 'ro', markersize=5)
                plt.text(x, y, str(region["label"]), color='white',
                         bbox=dict(facecolor='red', alpha=0.5))
            plt.imshow(boundary_img)
            plt.axis('off')
        plt.tight_layout()
        if save_path:
            try:
                directory = os.path.dirname(save_path)
                # os.makedirs('') raises, so only create a directory when the
                # path actually has one
                if directory:
                    os.makedirs(directory, exist_ok=True)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"Visualization saved to {save_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {str(e)}")
    finally:
        # Close even if plotting fails so figures do not accumulate
        plt.close(fig)
def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
    """
    Save processing results to a JSON file
    NumPy scalars and arrays (e.g. feature values or region statistics) are
    converted to plain Python values so they can be serialized.
    Args:
        results: Processing results
        output_path: Path to save the results; a bare filename is written to
            the current working directory
    Returns:
        bool: True if successful, False otherwise
    """
    import json

    def _to_builtin(obj: Any) -> Any:
        # json cannot encode NumPy types that are not Python subclasses
        # (e.g. np.int64, np.bool_, ndarray)
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    try:
        # Serialize first so a failure does not leave a truncated file behind
        text = json.dumps(results, indent=2, default=_to_builtin)
        directory = os.path.dirname(output_path)
        # os.makedirs('') raises for bare filenames
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(output_path, 'w') as f:
            f.write(text)
        logger.info(f"Results saved to {output_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving results: {str(e)}")
        return False
class XRayProcessor(MedicalImageProcessor):
    """Specialized processor for X-ray images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the X-ray processor
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)
        # X-ray specific configuration; values from a loaded config take precedence.
        # NOTE(review): "lung_segmentation" is not read anywhere in this class.
        self.config.setdefault("xray", {})
        self.config["xray"].setdefault("bone_enhancement", True)
        self.config["xray"].setdefault("lung_segmentation", True)

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the X-ray image
        Runs the base-class preprocessing, then sharpens bone structures when
        config["xray"]["bone_enhancement"] is truthy.
        Args:
            image: Input X-ray image
        Returns:
            np.ndarray: Preprocessed image
        """
        image = super().preprocess(image)
        if self.config["xray"].get("bone_enhancement", False):
            image = self._enhance_bone_structures(image)
        return image

    def _enhance_bone_structures(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance bone structures in X-ray images with unsharp masking
        Computes 1.5 * image - 0.5 * GaussianBlur(image, sigma=3), which
        amplifies high-frequency detail such as edges.
        Args:
            image: Input X-ray image
        Returns:
            np.ndarray: Enhanced image
        """
        blurred = cv2.GaussianBlur(image, (0, 0), 3)
        return cv2.addWeighted(image, 1.5, blurred, -0.5, 0)

    def segment_lungs(self, image: np.ndarray) -> np.ndarray:
        """
        Segment lungs in chest X-ray images
        Pixels darker than the Otsu threshold are cleaned with a morphological
        opening/closing. Components touching the image border are removed, and
        the (up to) two largest remaining components are kept.
        Args:
            image: Preprocessed X-ray image
        Returns:
            np.ndarray: Boolean lung mask with the image's 2-D shape, or an
            empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment lungs from empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        thresh_val = filters.threshold_otsu(image)
        binary = image < thresh_val  # Invert: lungs are darker
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))
        # Remove components that touch the image border
        cleared = segmentation.clear_border(binary)
        label_image = measure.label(cleared)
        largest = sorted(measure.regionprops(label_image),
                         key=lambda region: region.area, reverse=True)[:2]
        if len(largest) < 2:
            # The lungs can merge into a single component; keep whatever was
            # found instead of discarding the segmentation entirely.
            logger.warning(f"Expected two lung regions, found {len(largest)}")
        return np.isin(label_image, [region.label for region in largest])
class CTProcessor(MedicalImageProcessor):
    """Specialized processor for CT scan images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the CT processor
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)
        # CT specific configuration; window presets are (center, width) pairs
        self.config.setdefault("ct", {})
        self.config["ct"].setdefault("hounsfield_normalization", True)
        self.config["ct"].setdefault("window_presets", {
            "brain": {"center": 40, "width": 80},
            "lung": {"center": -600, "width": 1500},
            "bone": {"center": 400, "width": 1800},
            "soft_tissue": {"center": 50, "width": 400}
        })

    def preprocess(self, image: np.ndarray, window_preset: str = "soft_tissue") -> np.ndarray:
        """
        Preprocess the CT image with specific window settings
        Runs the base-class preprocessing, then, when
        config["ct"]["hounsfield_normalization"] is truthy, applies the window
        named by window_preset. An unknown preset is logged and windowing is
        skipped.
        Args:
            image: Input CT image
            window_preset: Window preset name (key of config["ct"]["window_presets"])
        Returns:
            np.ndarray: Preprocessed image
        """
        image = super().preprocess(image)
        ct_config = self.config["ct"]
        if ct_config.get("hounsfield_normalization", False):
            presets = ct_config.get("window_presets", {})
            preset = presets.get(window_preset)
            if preset is None:
                logger.warning(
                    f"Unknown CT window preset '{window_preset}'; available presets: "
                    f"{sorted(presets)}. Skipping windowing."
                )
            else:
                # _apply_window_level is provided by the base class (not shown here)
                image = self._apply_window_level(image, preset["center"], preset["width"])
        return image

    def segment_organs(self, image: np.ndarray, target_organ: str = "liver") -> np.ndarray:
        """
        Segment specific organs in CT images by intensity thresholding
        Thresholds are applied to the preprocessed image's intensities, which are
        assumed to lie in [0, 1]. They are heuristics, not HU values; TODO confirm
        them against the windowing used upstream.
        Args:
            image: Preprocessed CT image
            target_organ: "liver", "lung", "bone"; any other value uses image > 0.5
        Returns:
            np.ndarray: Boolean organ mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment organs from empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        if target_organ == "liver":
            binary = (image > 0.5) & (image < 0.7)  # mid-to-high intensities
        elif target_organ == "lung":
            binary = image < 0.3  # low intensities (air-filled tissue)
        elif target_organ == "bone":
            binary = image > 0.7  # high intensities
        else:
            binary = image > 0.5  # default segmentation
        # Smooth the mask and drop specks
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(5))
        binary = morphology.remove_small_objects(binary, min_size=100)
        return binary
class MRIProcessor(MedicalImageProcessor):
    """Specialized processor for MRI images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the MRI processor
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)
        # MRI specific configuration.
        # NOTE(review): "bias_field_correction" is not read anywhere in this class.
        self.config.setdefault("mri", {})
        self.config["mri"].setdefault("bias_field_correction", True)
        self.config["mri"].setdefault("intensity_normalization", True)

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the MRI image
        Runs the base-class preprocessing, then percentile-based intensity
        normalization when config["mri"]["intensity_normalization"] is truthy.
        Args:
            image: Input MRI image
        Returns:
            np.ndarray: Preprocessed image
        """
        image = super().preprocess(image)
        if self.config["mri"].get("intensity_normalization", False):
            image = self._normalize_intensity(image)
        return image

    def _normalize_intensity(self, image: np.ndarray) -> np.ndarray:
        """
        Normalize MRI intensities to [0, 1] using the 1st/99th percentiles
        Values are clipped to [p1, p99] and linearly rescaled. This is NOT a
        bias-field (N4) correction; SimpleITK's N4BiasFieldCorrection would be
        needed for that.
        Args:
            image: Input MRI image
        Returns:
            np.ndarray: Normalized image; all zeros when p99 == p1 (e.g. a
            constant image), where rescaling would divide by zero
        """
        p1, p99 = np.percentile(image, (1, 99))
        if p99 <= p1:
            return np.zeros(image.shape, dtype=np.float64)
        image = np.clip(image, p1, p99)
        return (image - p1) / (p99 - p1)

    def segment_brain(self, image: np.ndarray) -> np.ndarray:
        """
        Segment brain in MRI images
        Otsu foreground, morphological cleanup, hole filling, then only the
        largest connected component is kept.
        Args:
            image: Preprocessed MRI image
        Returns:
            np.ndarray: Boolean brain mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment brain from empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        thresh_val = filters.threshold_otsu(image)
        binary = image > thresh_val
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))
        binary = morphology.remove_small_holes(binary, area_threshold=100)
        # Keep only the largest connected component. Guard against an empty
        # foreground: argmax of an empty bincount slice would raise.
        labels = measure.label(binary)
        if np.max(labels) > 0:
            largest_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
            binary = labels == largest_label
        return binary
if __name__ == "__main__":
    # Smoke check: build a processor from the default configuration
    demo_processor = MedicalImageProcessor()
    print("Medical Image Processor initialized")
modules\image_analysis\image_segmentation.py
"""
Medical Image Segmentation Module
This module provides specialized segmentation algorithms for medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.
Key features:
- Thresholding-based segmentation
- Region growing
- Watershed segmentation
- Deep learning-based segmentation
- Organ-specific segmentation (lungs, brain, etc.)
"""
import os
import logging
import numpy as np
import cv2
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
from skimage import measure, morphology, segmentation, filters
# Module-wide logging: INFO level, timestamped "time - logger - level - message" lines
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class MedicalImageSegmenter:
    """Base class for medical image segmentation"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image segmenter
        A JSON config is loaded when config_path points to an existing file;
        load errors are logged. If nothing was loaded, a default
        "segmentation" section is used.
        Args:
            config_path: Path to configuration file
        """
        self.config = {}
        if config_path and os.path.isfile(config_path):
            try:
                import json
                with open(config_path, 'r') as f:
                    self.config = json.load(f)
                logger.info(f"Loaded configuration from {config_path}")
            except Exception as e:
                logger.error(f"Error loading configuration: {str(e)}")
        # Default configuration
        if not self.config:
            self.config = {
                "segmentation": {
                    "threshold_method": "otsu",
                    "morphology_operations": True,
                    "remove_small_objects": True,
                    "min_size": 100
                }
            }

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """
        Convert an image to uint8 for OpenCV routines that require 8-bit input
        Values in [0, 1] are treated as normalized and scaled by 255. Values in
        [0, 255] are cast directly. Anything else is min-max rescaled to
        [0, 255], so it is not wrapped modulo 256 by a bare astype.
        Args:
            image: Input image of any numeric or boolean dtype
        Returns:
            np.ndarray: uint8 image of the same shape
        """
        min_val = float(np.min(image))
        max_val = float(np.max(image))
        if min_val >= 0 and max_val <= 1.0:
            return (image * 255).astype(np.uint8)
        if min_val >= 0 and max_val <= 255:
            return image.astype(np.uint8)
        if max_val == min_val:
            return np.zeros(image.shape, dtype=np.uint8)
        scaled = (image.astype(np.float64) - min_val) / (max_val - min_val) * 255
        return scaled.astype(np.uint8)

    def threshold_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Perform threshold-based segmentation
        The method comes from config["segmentation"]["threshold_method"]:
        "otsu", "adaptive" (OpenCV Gaussian adaptive threshold), or a fixed
        0.5 threshold for anything else. Optional morphological cleanup and
        small-object removal follow.
        Args:
            image: Input image (grayscale)
        Returns:
            np.ndarray: Binary segmentation mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        seg_config = self.config.get("segmentation", {})
        threshold_method = seg_config.get("threshold_method", "otsu")
        if threshold_method == "otsu":
            thresh_val = filters.threshold_otsu(image)
            binary = image > thresh_val
        elif threshold_method == "adaptive":
            # adaptiveThreshold only accepts 8-bit input
            img_uint8 = self._to_uint8(image)
            binary = cv2.adaptiveThreshold(
                img_uint8, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY, 11, 2
            )
            binary = binary > 0
        else:
            # Default to simple thresholding
            binary = image > 0.5
        if seg_config.get("morphology_operations", True):
            binary = morphology.binary_opening(binary)
            binary = morphology.binary_closing(binary)
        if seg_config.get("remove_small_objects", True):
            min_size = seg_config.get("min_size", 100)
            binary = morphology.remove_small_objects(binary, min_size=min_size)
        return binary

    def region_growing(self, image: np.ndarray, seed_point: Tuple[int, int],
                       threshold: float = 0.1) -> np.ndarray:
        """
        Perform region growing segmentation
        Breadth-first flood fill over 4-connected pixels whose intensity differs
        from the seed pixel by at most `threshold`.
        Args:
            image: Input image (grayscale)
            seed_point: Starting point for region growing (y, x); any
                two-element sequence of integers is accepted
            threshold: Maximum absolute intensity difference from the seed,
                in the image's own intensity units
        Returns:
            np.ndarray: Boolean segmentation mask (all False for an invalid seed)
        """
        from collections import deque

        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # Compare in float: differences of unsigned pixels would otherwise wrap
        # around (e.g. uint8 95 - 100 == 251) and block growth toward darker pixels
        intensities = image.astype(np.float64)
        height, width = intensities.shape
        mask = np.zeros((height, width), dtype=bool)
        # Normalize the seed to a tuple of ints; a list seed would otherwise
        # trigger NumPy fancy indexing (selecting whole rows)
        try:
            seed_y, seed_x = (int(c) for c in seed_point)
        except (TypeError, ValueError):
            logger.error(f"Invalid seed point: {seed_point}")
            return mask
        if not (0 <= seed_y < height and 0 <= seed_x < width):
            logger.error(f"Invalid seed point: {seed_point}")
            return mask
        seed_value = intensities[seed_y, seed_x]
        mask[seed_y, seed_x] = True
        frontier = deque([(seed_y, seed_x)])  # O(1) pops from the left
        neighbors = ((0, 1), (1, 0), (0, -1), (-1, 0))
        while frontier:
            y, x = frontier.popleft()
            for dy, dx in neighbors:
                ny, nx = y + dy, x + dx
                if (0 <= ny < height and 0 <= nx < width and not mask[ny, nx]
                        and abs(intensities[ny, nx] - seed_value) <= threshold):
                    mask[ny, nx] = True
                    frontier.append((ny, nx))
        return mask

    def watershed_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Perform marker-based watershed segmentation
        Markers: pixels below 0.6 x Otsu threshold are background (1), pixels
        above 1.2 x Otsu threshold are foreground (2). Watershed then runs on
        the Sobel gradient.
        Args:
            image: Input image (grayscale)
        Returns:
            np.ndarray: Boolean foreground mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        if image.dtype != np.uint8:
            image = self._to_uint8(image)
        gradient = filters.sobel(image)
        thresh_val = filters.threshold_otsu(image)
        markers = np.zeros_like(image, dtype=np.int32)
        markers[image < thresh_val * 0.6] = 1  # Background
        markers[image > thresh_val * 1.2] = 2  # Foreground
        segmented = segmentation.watershed(gradient, markers)
        return segmented == 2

    def extract_region_properties(self, mask: np.ndarray) -> List[Dict[str, Any]]:
        """
        Extract properties of segmented regions
        Boolean masks and masks containing exactly the values {0, 1} are
        labeled by connectivity first; other masks are treated as label images.
        Args:
            mask: Binary or labeled segmentation mask
        Returns:
            List[Dict[str, Any]]: One dict per region with label, area, centroid
            (row, col), bbox, eccentricity, equivalent_diameter, orientation
            and perimeter
        """
        if mask is None or mask.size == 0:
            logger.error("Cannot extract properties from empty mask")
            return []
        if mask.dtype == bool or np.array_equal(np.unique(mask), [0, 1]):
            labeled_mask, num_labels = measure.label(mask, return_num=True)
        else:
            # regionprops rejects non-integer label images
            labeled_mask = mask if np.issubdtype(mask.dtype, np.integer) else mask.astype(np.int64)
            num_labels = np.max(labeled_mask)
        if num_labels <= 0:
            return []
        return [
            {
                "label": prop.label,
                "area": prop.area,
                "centroid": prop.centroid,
                "bbox": prop.bbox,
                "eccentricity": prop.eccentricity,
                "equivalent_diameter": prop.equivalent_diameter,
                "orientation": prop.orientation,
                "perimeter": prop.perimeter
            }
            for prop in measure.regionprops(labeled_mask)
        ]

    def visualize_segmentation(self, image: np.ndarray, mask: np.ndarray,
                               regions: Optional[List[Dict[str, Any]]] = None,
                               save_path: Optional[str] = None) -> None:
        """
        Visualize the segmentation results
        Shows the original image, the mask, and an overlay with mask boundaries
        plus optional region centroids/labels. The figure is saved only when
        save_path is set and is always closed afterwards.
        Args:
            image: Original image
            mask: Segmentation mask
            regions: Optional list of region properties ("centroid" as
                (row, col) and "label")
            save_path: Optional path to save the visualization; a bare
                filename is saved into the current working directory
        """
        if image is None or mask is None or image.size == 0 or mask.size == 0:
            logger.error("Cannot visualize empty image or mask")
            return
        fig = plt.figure(figsize=(15, 5))
        try:
            plt.subplot(1, 3, 1)
            plt.title("Original Image")
            plt.imshow(image, cmap='gray')
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.title("Segmentation Mask")
            plt.imshow(mask, cmap='viridis')
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.title("Overlay")
            # mark_boundaries expects an RGB image in [0, 1]
            if len(image.shape) == 2:
                rgb_img = np.stack([image] * 3, axis=2)
            else:
                rgb_img = image.copy()
            if rgb_img.max() > 1.0:
                rgb_img = rgb_img / 255.0
            from skimage.segmentation import mark_boundaries
            boundary_img = mark_boundaries(rgb_img, mask, color=(1, 0, 0), mode='thick')
            if regions:
                for region in regions:
                    y, x = region["centroid"]
                    plt.plot(x, y, 'ro', markersize=5)
                    plt.text(x, y, str(region["label"]), color='white',
                             bbox=dict(facecolor='red', alpha=0.5))
            plt.imshow(boundary_img)
            plt.axis('off')
            plt.tight_layout()
            if save_path:
                try:
                    directory = os.path.dirname(save_path)
                    # os.makedirs('') raises for bare filenames
                    if directory:
                        os.makedirs(directory, exist_ok=True)
                    plt.savefig(save_path, dpi=300, bbox_inches='tight')
                    logger.info(f"Visualization saved to {save_path}")
                except Exception as e:
                    logger.error(f"Error saving visualization: {str(e)}")
        finally:
            plt.close(fig)
class LungSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for lung segmentation in chest X-rays"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the lung segmenter
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    @staticmethod
    def _as_uint8(image: np.ndarray) -> np.ndarray:
        """
        Convert an image to uint8 for CLAHE
        uint8 input is returned as-is. Images with max <= 1 are scaled by 255.
        Other images are clipped to [0, 255] rather than wrapped by astype.
        """
        if image.dtype == np.uint8:
            return image
        if np.max(image) <= 1.0:
            return (image * 255).astype(np.uint8)
        return np.clip(image, 0, 255).astype(np.uint8)

    def segment_lungs(self, image: np.ndarray) -> np.ndarray:
        """
        Segment lungs in chest X-ray images
        Pixels darker than the Otsu threshold are cleaned with a morphological
        opening/closing. Border-touching components are removed, and the (up
        to) two largest remaining components are kept.
        Args:
            image: Input chest X-ray image
        Returns:
            np.ndarray: Boolean lung mask with the image's 2-D shape, or an
            empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment lungs from empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        thresh_val = filters.threshold_otsu(image)
        binary = image < thresh_val  # Invert: lungs are darker
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))
        cleared = segmentation.clear_border(binary)
        label_image = measure.label(cleared)
        largest = sorted(measure.regionprops(label_image),
                         key=lambda region: region.area, reverse=True)[:2]
        if len(largest) < 2:
            # The lungs can merge into a single component; keep whatever was
            # found instead of returning an empty mask.
            logger.warning(f"Expected two lung regions, found {len(largest)}")
        return np.isin(label_image, [region.label for region in largest])

    def enhance_lung_structures(self, image: np.ndarray, lung_mask: np.ndarray) -> np.ndarray:
        """
        Enhance lung structures in chest X-ray images
        CLAHE is computed on a copy whose non-lung pixels are zeroed. Only the
        lung pixels of the result take the enhanced values; all other pixels
        keep the (uint8-converted) original.
        Args:
            image: Input chest X-ray image
            lung_mask: Lung segmentation mask (bool or 0/non-zero integers),
                same 2-D shape as the image
        Returns:
            np.ndarray: Enhanced uint8 grayscale image, or an empty array on
            invalid input
        """
        if image is None or image.size == 0 or lung_mask is None or lung_mask.size == 0:
            logger.error("Cannot enhance empty image or mask")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # `~` on an integer mask is a bitwise NOT, not a logical one
        lung_mask = np.asarray(lung_mask).astype(bool)
        if lung_mask.shape != image.shape:
            logger.error(f"Lung mask shape {lung_mask.shape} does not match image shape {image.shape}")
            return np.array([])
        # Copy so the caller's image is never modified in place
        result = self._as_uint8(image).copy()
        masked_image = result.copy()
        masked_image[~lung_mask] = 0
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(masked_image)
        result[lung_mask] = enhanced[lung_mask]
        return result
class BrainSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for brain segmentation in MRI images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the brain segmenter
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    def segment_brain(self, image: np.ndarray) -> np.ndarray:
        """
        Segment brain in MRI images
        Otsu foreground, morphological cleanup, hole filling, then only the
        largest connected component is kept.
        Args:
            image: Input MRI image
        Returns:
            np.ndarray: Boolean brain mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment brain from empty image")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        thresh_val = filters.threshold_otsu(image)
        binary = image > thresh_val
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))
        binary = morphology.remove_small_holes(binary, area_threshold=100)
        # Keep only the largest connected component (if any foreground exists)
        labels = measure.label(binary)
        if np.max(labels) > 0:
            largest_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
            binary = labels == largest_label
        return binary

    def segment_tumor(self, image: np.ndarray, brain_mask: np.ndarray) -> np.ndarray:
        """
        Segment potential tumor regions in brain MRI
        Candidates are the pixels inside the brain mask that are brighter than
        the 95th percentile of the brain's intensities, followed by
        morphological cleanup.
        Args:
            image: Input MRI image
            brain_mask: Brain segmentation mask (bool or 0/non-zero integers),
                same 2-D shape as the image
        Returns:
            np.ndarray: Boolean tumor mask (all False if the brain mask is
            empty), or an empty array on invalid input
        """
        if image is None or image.size == 0 or brain_mask is None or brain_mask.size == 0:
            logger.error("Cannot segment tumor from empty image or mask")
            return np.array([])
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        # `~`/indexing with an integer mask would not behave like a boolean mask
        brain_mask = np.asarray(brain_mask).astype(bool)
        if brain_mask.shape != image.shape:
            logger.error(f"Brain mask shape {brain_mask.shape} does not match image shape {image.shape}")
            return np.array([])
        if not brain_mask.any():
            # np.percentile of an empty selection would raise
            logger.warning("Brain mask is empty; no tumor candidates to segment")
            return np.zeros(image.shape, dtype=bool)
        # Assumes tumors are hyperintense (as in T2-weighted MRI): take the
        # brightest 5% of brain pixels, restricted explicitly to the brain so
        # background can never pass the threshold
        thresh_val = np.percentile(image[brain_mask], 95)
        tumor_mask = (image > thresh_val) & brain_mask
        tumor_mask = morphology.binary_opening(tumor_mask, morphology.disk(2))
        tumor_mask = morphology.binary_closing(tumor_mask, morphology.disk(5))
        tumor_mask = morphology.remove_small_objects(tumor_mask, min_size=50)
        return tumor_mask
class OrganSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for organ segmentation in CT images

    All thresholds apply to intensities assumed normalized to [0, 1]. The
    defaults are heuristics; how HU values map into [0, 1] depends on the
    upstream windowing (TODO confirm the defaults against it).
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the organ segmenter
        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    @staticmethod
    def _to_gray(image: np.ndarray) -> np.ndarray:
        """Return a single-channel view of the image (RGB is converted)."""
        if len(image.shape) > 2:
            return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        return image

    def segment_liver(self, image: np.ndarray, lower: float = 0.5,
                      upper: float = 0.7) -> np.ndarray:
        """
        Segment liver in CT images
        Keeps pixels in the open interval (lower, upper), cleans the mask, and
        retains only the largest connected component.
        Args:
            image: Input CT image (normalized intensities)
            lower: Exclusive lower intensity bound
            upper: Exclusive upper intensity bound
        Returns:
            np.ndarray: Boolean liver mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment liver from empty image")
            return np.array([])
        image = self._to_gray(image)
        liver_mask = (image > lower) & (image < upper)
        liver_mask = morphology.binary_opening(liver_mask, morphology.disk(2))
        liver_mask = morphology.binary_closing(liver_mask, morphology.disk(10))
        liver_mask = morphology.remove_small_objects(liver_mask, min_size=500)
        # Keep only the largest connected component, if any survived
        labels = measure.label(liver_mask)
        if np.max(labels) > 0:
            largest_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
            liver_mask = labels == largest_label
        return liver_mask

    def segment_kidneys(self, image: np.ndarray, lower: float = 0.3,
                        upper: float = 0.5) -> np.ndarray:
        """
        Segment kidneys in CT images
        Keeps pixels in the open interval (lower, upper) and cleans the mask;
        all sufficiently large components are kept.
        Args:
            image: Input CT image (normalized intensities)
            lower: Exclusive lower intensity bound
            upper: Exclusive upper intensity bound
        Returns:
            np.ndarray: Boolean kidney mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment kidneys from empty image")
            return np.array([])
        image = self._to_gray(image)
        kidney_mask = (image > lower) & (image < upper)
        kidney_mask = morphology.binary_opening(kidney_mask, morphology.disk(2))
        kidney_mask = morphology.binary_closing(kidney_mask, morphology.disk(5))
        kidney_mask = morphology.remove_small_objects(kidney_mask, min_size=200)
        return kidney_mask

    def segment_bones(self, image: np.ndarray, threshold: float = 0.7) -> np.ndarray:
        """
        Segment bones in CT images
        Keeps pixels brighter than `threshold` and cleans the mask.
        Args:
            image: Input CT image (normalized intensities)
            threshold: Exclusive lower intensity bound
        Returns:
            np.ndarray: Boolean bone mask, or an empty array for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment bones from empty image")
            return np.array([])
        image = self._to_gray(image)
        bone_mask = image > threshold
        bone_mask = morphology.binary_opening(bone_mask, morphology.disk(1))
        bone_mask = morphology.binary_closing(bone_mask, morphology.disk(3))
        bone_mask = morphology.remove_small_objects(bone_mask, min_size=100)
        return bone_mask
class DeepLearningSegmenter(MedicalImageSegmenter):
    """Deep learning-based segmentation for medical images"""

    def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
        """
        Initialize the deep learning segmenter
        Args:
            config_path: Path to configuration file
            model_path: Path to pre-trained model; loaded immediately if it exists
        """
        super().__init__(config_path)
        self.model = None
        self.model_path = model_path
        if model_path and os.path.exists(model_path):
            self.load_model(model_path)

    def load_model(self, model_path: str) -> bool:
        """
        Load a pre-trained Keras segmentation model (without compiling it)
        Args:
            model_path: Path to pre-trained model
        Returns:
            bool: True if model loaded successfully, False otherwise
        """
        try:
            # Imported lazily so TensorFlow is only required when a model is used
            from tensorflow.keras.models import load_model as keras_load_model
        except ImportError:
            logger.error("TensorFlow not available. Cannot load deep learning model.")
            return False
        try:
            self.model = keras_load_model(model_path, compile=False)
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
        # Keep model_path in sync when load_model is called directly
        self.model_path = model_path
        logger.info(f"Loaded model from {model_path}")
        return True

    def preprocess_for_model(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess image for model input
        Converts to grayscale, resizes to config["image_processing"]["target_size"]
        (height, width; default 256x256), scales to [0, 1] when values exceed 1,
        and adds batch and channel axes.
        Args:
            image: Input image
        Returns:
            np.ndarray: Array of shape (1, height, width, 1), or an empty array
            for empty input
        """
        if image is None or image.size == 0:
            logger.error("Cannot preprocess empty image")
            return np.array([])
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        target_size = self.config.get("image_processing", {}).get("target_size", [256, 256])
        # cv2.resize takes (width, height)
        resized = cv2.resize(image, (target_size[1], target_size[0]))
        if resized.max() > 1.0:
            resized = resized / 255.0
        return resized[np.newaxis, ..., np.newaxis]

    def segment_with_model(self, image: np.ndarray) -> np.ndarray:
        """
        Segment image using the loaded deep learning model
        A prediction with more than one channel is treated as multi-class
        (argmax over channels). Otherwise it is a binary probability map
        thresholded at 0.5. The mask is always 2-D and resized (nearest
        neighbour) to the input's height and width.
        Args:
            image: Input image
        Returns:
            np.ndarray: 2-D segmentation mask, or an empty array on error
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        if self.model is None:
            logger.error("No model loaded. Call load_model() first.")
            return np.array([])
        preprocessed = self.preprocess_for_model(image)
        try:
            prediction = np.asarray(self.model.predict(preprocessed)[0])
            if prediction.ndim == 3 and prediction.shape[-1] > 1:
                # Multi-class segmentation
                mask = np.argmax(prediction, axis=-1)
            else:
                # Binary segmentation; drop a singleton channel axis so the mask
                # is 2-D whether or not it gets resized below
                if prediction.ndim == 3:
                    prediction = prediction[..., 0]
                mask = (prediction > 0.5).astype(np.uint8)
            if mask.shape[:2] != image.shape[:2]:
                mask = cv2.resize(mask.astype(np.uint8),
                                  (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)
            return mask
        except Exception as e:
            logger.error(f"Error during model prediction: {str(e)}")
            return np.array([])
if __name__ == "__main__":
    # Example usage: segment one image and write masks plus a visualization
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Medical Image Segmentation')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--type', type=str, choices=['xray', 'ct', 'mri'], required=True,
                        help='Type of medical image')
    parser.add_argument('--output', type=str, default='./output', help='Output directory')
    parser.add_argument('--config', type=str, default=None, help='Path to configuration file')
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)
    # cv2.imread returns a 3-channel BGR image (or None on failure)
    image = cv2.imread(args.image)
    if image is None:
        print(f"Error: Could not load image {args.image}")
        sys.exit(1)
    if args.type == 'xray':
        segmenter = LungSegmenter(args.config)
        mask = segmenter.segment_lungs(image)
        regions = segmenter.extract_region_properties(mask)
        enhanced = segmenter.enhance_lung_structures(image, mask)
        enhanced_path = os.path.join(args.output, 'enhanced_lungs.png')
        cv2.imwrite(enhanced_path, enhanced)
        print(f"Enhanced image saved to {enhanced_path}")
    elif args.type == 'mri':
        segmenter = BrainSegmenter(args.config)
        brain_mask = segmenter.segment_brain(image)
        # The brain mask is the primary result saved and visualized below
        mask = brain_mask
        regions = segmenter.extract_region_properties(brain_mask)
        tumor_mask = segmenter.segment_tumor(image, brain_mask)
        if np.any(tumor_mask):
            tumor_path = os.path.join(args.output, 'tumor_mask.png')
            cv2.imwrite(tumor_path, tumor_mask.astype(np.uint8) * 255)
            print(f"Tumor mask saved to {tumor_path}")
    elif args.type == 'ct':
        segmenter = OrganSegmenter(args.config)
        liver_mask = segmenter.segment_liver(image)
        kidney_mask = segmenter.segment_kidneys(image)
        bone_mask = segmenter.segment_bones(image)
        # Labeled mask: 1 = liver, 2 = kidneys, 3 = bones (later labels overwrite earlier)
        combined_mask = np.zeros(image.shape[:2], dtype=np.uint8)
        combined_mask[liver_mask] = 1
        combined_mask[kidney_mask] = 2
        combined_mask[bone_mask] = 3
        mask = combined_mask
        regions = segmenter.extract_region_properties(combined_mask)
        cv2.imwrite(os.path.join(args.output, 'liver_mask.png'), liver_mask.astype(np.uint8) * 255)
        cv2.imwrite(os.path.join(args.output, 'kidney_mask.png'), kidney_mask.astype(np.uint8) * 255)
        cv2.imwrite(os.path.join(args.output, 'bone_mask.png'), bone_mask.astype(np.uint8) * 255)
        print(f"Organ masks saved to {args.output}")
    # Save the foreground (any non-zero label) as a 0/255 image; multiplying a
    # multi-label uint8 mask by 255 directly would wrap around
    mask_path = os.path.join(args.output, 'segmentation_mask.png')
    cv2.imwrite(mask_path, (mask > 0).astype(np.uint8) * 255)
    print(f"Segmentation mask saved to {mask_path}")
    # matplotlib expects RGB, OpenCV loaded BGR
    display_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    vis_path = os.path.join(args.output, 'segmentation_visualization.png')
    segmenter.visualize_segmentation(display_image, mask, regions, vis_path)
    print(f"Visualization saved to {vis_path}")
    print(f"Detected {len(regions)} regions in the image")
    for i, region in enumerate(regions[:5]):  # Show top 5 regions
        print(f"Region {i+1}: Area = {region['area']}, Centroid = {region['centroid']}")
modules\image_analysis\__init__.py
modules\preprocessing\data_preprocessor.py
"""
Data Preprocessor Module for Medical Diagnosis System
This module handles the preprocessing of various types of medical data including:
- Medical images (DICOM, JPEG, PNG)
- Clinical text data (reports, notes)
- Laboratory results
The module provides standardization, normalization, and transformation functions.
"""
import os
import json
import logging
import numpy as np
import pandas as pd
import cv2
import pydicom
from typing import Dict, List, Union, Optional, Any, Tuple
from pathlib import Path
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from sklearn.preprocessing import StandardScaler, MinMaxScaler
# Module-wide logging: INFO level, timestamped "time - logger - level - message" lines
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class DataPreprocessor:
    """Base class for data preprocessing operations.

    Loads an optional JSON configuration into ``self.config`` (always a dict)
    and provides a generic ``save_preprocessed_data`` helper.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the data preprocessor with optional configuration.

        Args:
            config_path: Path to a JSON configuration file. A path that is not
                an existing file is ignored. An unreadable or malformed file, or
                one whose top-level value is not a JSON object, is logged and
                ignored. In both cases ``self.config`` stays empty, so
                subclasses use their defaults.
        """
        self.config: Dict[str, Any] = {}
        if config_path and os.path.isfile(config_path):
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError
                logger.error(f"Error loading configuration from {config_path}: {e}")
            else:
                if isinstance(loaded, dict):
                    self.config = loaded
                else:
                    # Subclasses read their settings with self.config.get(...)
                    logger.error(f"Configuration in {config_path} is not a JSON object; ignoring it")
        logger.info("DataPreprocessor initialized")

    def save_preprocessed_data(self, data: Any, output_path: str) -> bool:
        """
        Save preprocessed data to a file whose format is chosen by extension.

        Supported extensions: .npy (numpy), .json, .csv (DataFrame or anything
        ``pd.DataFrame`` accepts), .jpg/.jpeg/.png (written with cv2.imwrite).
        Parent directories are created when the path has one.

        Args:
            data: Data to save
            output_path: Path to save the data

        Returns:
            bool: True if successful, False otherwise (errors are logged)
        """
        try:
            directory = os.path.dirname(output_path)
            if directory:
                # os.makedirs('') raises, so only create a parent when one is given
                os.makedirs(directory, exist_ok=True)
            ext = os.path.splitext(output_path)[1].lower()
            if ext == '.npy':
                np.save(output_path, data)
            elif ext == '.json':
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=2)
            elif ext == '.csv':
                frame = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
                frame.to_csv(output_path, index=False)
            elif ext in ('.jpg', '.jpeg', '.png'):
                # cv2.imwrite reports failure through its return value, not an exception
                if not cv2.imwrite(output_path, data):
                    logger.error(f"cv2.imwrite failed to write {output_path}")
                    return False
            else:
                logger.warning(f"Unsupported file extension: {ext}")
                return False
            logger.info(f"Saved preprocessed data to {output_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving preprocessed data: {str(e)}")
            return False
class MedicalImagePreprocessor(DataPreprocessor):
    """Class for preprocessing medical image data (DICOM, JPEG, PNG).

    Each image goes through optional CLAHE contrast enhancement, a resize to
    ``target_size`` and optional scaling to the [0, 1] range.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the medical image preprocessor.

        Args:
            config_path: Optional JSON config with keys target_size, normalize,
                clahe, clahe_clip_limit and clahe_grid_size.
        """
        super().__init__(config_path)
        # target_size is passed to cv2.resize as dsize, i.e. (width, height).
        # JSON configs yield lists, so store int tuples for comparison/OpenCV.
        self.target_size = tuple(int(v) for v in self.config.get('target_size', (224, 224)))
        self.normalize = self.config.get('normalize', True)
        self.clahe = self.config.get('clahe', True)
        self.clahe_clip_limit = self.config.get('clahe_clip_limit', 2.0)
        self.clahe_grid_size = tuple(int(v) for v in self.config.get('clahe_grid_size', (8, 8)))
        logger.info(f"MedicalImagePreprocessor initialized with target size: {self.target_size}")

    def _enhance_contrast(self, image: np.ndarray) -> np.ndarray:
        """Apply CLAHE (if enabled) to a 2-D image or to the L channel of a 3-channel RGB image."""
        if not self.clahe:
            return image
        clahe = cv2.createCLAHE(clipLimit=self.clahe_clip_limit, tileGridSize=self.clahe_grid_size)
        if image.ndim == 2:
            return clahe.apply(image)
        if image.ndim == 3 and image.shape[2] == 3:
            # Equalise luminance only, so colours are not shifted
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            lab[:, :, 0] = clahe.apply(lab[:, :, 0])
            return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        # Other layouts (e.g. multi-frame stacks) are left unchanged
        return image

    def _resize(self, image: np.ndarray) -> np.ndarray:
        """Resize to target_size, interpreted as (width, height) like cv2.resize."""
        height, width = image.shape[:2]
        # Compare in the same (width, height) order that cv2.resize uses; comparing
        # shape[:2] (height, width) directly skips resizing transposed images.
        if (width, height) != tuple(self.target_size):
            image = cv2.resize(image, tuple(self.target_size), interpolation=cv2.INTER_CUBIC)
        return image

    def _scale(self, image: np.ndarray) -> np.ndarray:
        """Convert to float32 in [0, 1] when normalization is enabled."""
        if self.normalize:
            return image.astype(np.float32) / 255.0
        return image

    def _build_result(self, image: np.ndarray, source: Dict[str, Any], file_path: str) -> Dict[str, Any]:
        """Package a preprocessed image with its metadata and the settings used."""
        return {
            'preprocessed_image': image,
            'metadata': source.get('metadata', {}),
            'original_file_path': file_path,
            'preprocessing_info': {
                'target_size': self.target_size,
                'normalize': self.normalize,
                'clahe': self.clahe,
            },
        }

    def preprocess_dicom(self, dicom_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess DICOM image data.

        Inverts MONOCHROME1 images, rescales non-uint8 pixel data to 0-255,
        applies CLAHE and resizing, and replicates a 2-D result to 3 channels.

        Args:
            dicom_data: DICOM data from the collector (expects a 'file_path' key)

        Returns:
            Dict: Preprocessed DICOM data, or {} on any error (errors are logged)
        """
        try:
            file_path = dicom_data.get('file_path')
            if not file_path or not os.path.exists(file_path):
                logger.error(f"DICOM file not found: {file_path}")
                return {}
            dicom = pydicom.dcmread(file_path)
            pixel_array = dicom.pixel_array
            # In MONOCHROME1 the minimum value is displayed as white; invert so
            # all images share the MONOCHROME2 convention. The attribute may be absent.
            if getattr(dicom, 'PhotometricInterpretation', None) == "MONOCHROME1":
                pixel_array = np.max(pixel_array) - pixel_array
            if pixel_array.dtype != np.uint8:
                pixel_array = cv2.normalize(pixel_array, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            pixel_array = self._enhance_contrast(pixel_array)
            pixel_array = self._resize(pixel_array)
            # Ensure 3 channels for model compatibility
            if pixel_array.ndim == 2:
                pixel_array = np.stack([pixel_array] * 3, axis=2)
            pixel_array = self._scale(pixel_array)
            result = self._build_result(pixel_array, dicom_data, file_path)
            logger.info(f"Successfully preprocessed DICOM image from {file_path}")
            return result
        except Exception as e:
            logger.error(f"Error preprocessing DICOM image: {str(e)}")
            return {}

    def preprocess_image(self, image_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess regular image data (JPEG, PNG).

        The file is read as 3-channel colour and converted from OpenCV's BGR
        order to RGB before CLAHE, resizing and scaling.

        Args:
            image_data: Image data from the collector (expects a 'file_path' key)

        Returns:
            Dict: Preprocessed image data, or {} on any error (errors are logged)
        """
        try:
            file_path = image_data.get('file_path')
            if not file_path or not os.path.exists(file_path):
                logger.error(f"Image file not found: {file_path}")
                return {}
            image = cv2.imread(file_path)
            if image is None:
                # cv2.imread returns None instead of raising for unreadable files
                logger.error(f"Could not read image file: {file_path}")
                return {}
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = self._enhance_contrast(image)
            image = self._resize(image)
            image = self._scale(image)
            result = self._build_result(image, image_data, file_path)
            logger.info(f"Successfully preprocessed image from {file_path}")
            return result
        except Exception as e:
            logger.error(f"Error preprocessing image: {str(e)}")
            return {}

    def preprocess_batch(self, image_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of images.

        Items whose 'type' is 'dicom' go through preprocess_dicom, all others
        through preprocess_image; failed items are dropped.

        Args:
            image_data_list: List of image data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed image data
        """
        results = []
        for image_data in image_data_list:
            if image_data.get('type') == 'dicom':
                result = self.preprocess_dicom(image_data)
            else:
                result = self.preprocess_image(image_data)
            if result:
                results.append(result)
        logger.info(f"Preprocessed {len(results)} images out of {len(image_data_list)} total")
        return results
class ClinicalTextPreprocessor(DataPreprocessor):
    """Class for preprocessing clinical text data.

    Optionally lowercases and strips punctuation, tokenizes with NLTK,
    removes stopwords (NLTK's list for ``language`` plus generic medical
    words), and tags tokens found in small built-in entity vocabularies.
    """

    # (NLTK resource path, download package). Newer NLTK releases load the
    # Punkt tokenizer used by word_tokenize from 'punkt_tab' instead of 'punkt'.
    _NLTK_RESOURCES = (
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
        ('corpora/stopwords', 'stopwords'),
    )

    # Demonstration vocabularies for _extract_medical_entities; a real system
    # would use a medical NER model or terminology lookup instead.
    _SYMPTOMS = frozenset({'pain', 'fever', 'cough', 'headache', 'nausea', 'fatigue', 'dizziness'})
    _DISEASES = frozenset({'diabetes', 'hypertension', 'cancer', 'asthma', 'pneumonia', 'arthritis'})
    _MEDICATIONS = frozenset({'aspirin', 'ibuprofen', 'paracetamol', 'insulin', 'antibiotics'})

    _PUNCTUATION_RE = re.compile(r'[^\w\s]')

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the clinical text preprocessor.

        Args:
            config_path: Optional JSON config with keys remove_stopwords,
                remove_punctuation, lowercase, language and medical_stopwords.
        """
        super().__init__(config_path)
        # Download NLTK resources that are not installed yet
        for resource_path, package in self._NLTK_RESOURCES:
            try:
                nltk.data.find(resource_path)
            except LookupError:
                nltk.download(package)
        self.remove_stopwords = self.config.get('remove_stopwords', True)
        self.remove_punctuation = self.config.get('remove_punctuation', True)
        self.lowercase = self.config.get('lowercase', True)
        self.language = self.config.get('language', 'english')
        # Stopwords are stored lowercased; tokens are lowercased only for the
        # membership test, so removal also works when 'lowercase' is disabled.
        self.stopwords = set(stopwords.words(self.language)) if self.remove_stopwords else set()
        # Medical specific stopwords (common words in medical texts that aren't informative)
        medical_stopwords = self.config.get('medical_stopwords', [
            'patient', 'doctor', 'hospital', 'treatment', 'medicine',
            'dose', 'report', 'medical', 'clinical', 'health'
        ])
        self.stopwords.update(word.lower() for word in medical_stopwords)
        logger.info(f"ClinicalTextPreprocessor initialized with {len(self.stopwords)} stopwords")

    def preprocess_text(self, text_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess clinical text data.

        Args:
            text_data: Text data from the collector; its 'content' may be a
                string, or a list/dict which is serialised to JSON first

        Returns:
            Dict: Preprocessed text data, or {} on missing content or error
        """
        try:
            content = text_data.get('content')
            if not content:
                logger.error("No content found in text data")
                return {}
            # For structured data, convert to string for text processing
            if isinstance(content, (list, dict)):
                content_str = json.dumps(content)
            else:
                content_str = str(content)
            if self.lowercase:
                content_str = content_str.lower()
            if self.remove_punctuation:
                content_str = self._PUNCTUATION_RE.sub(' ', content_str)
            tokens = word_tokenize(content_str)
            if self.remove_stopwords:
                tokens = [token for token in tokens if token.lower() not in self.stopwords]
            preprocessed_text = ' '.join(tokens)
            medical_entities = self._extract_medical_entities(preprocessed_text)
            result = {
                'original_content': content,
                'preprocessed_text': preprocessed_text,
                'tokens': tokens,
                'medical_entities': medical_entities,
                'metadata': text_data.get('metadata', {}),
                'preprocessing_info': {
                    'remove_stopwords': self.remove_stopwords,
                    'remove_punctuation': self.remove_punctuation,
                    'lowercase': self.lowercase
                }
            }
            logger.info(f"Successfully preprocessed text from {text_data.get('file_path', 'unknown')}")
            return result
        except Exception as e:
            logger.error(f"Error preprocessing text: {str(e)}")
            return {}

    def _extract_medical_entities(self, text: str) -> Dict[str, List[str]]:
        """
        Extract medical entities from text (simplified dictionary lookup).

        Whitespace-separated tokens are matched case-insensitively against
        the class vocabularies; matched tokens are returned as they appear.

        Args:
            text: Preprocessed text

        Returns:
            Dict: Lists of matched tokens under 'symptoms', 'diseases' and 'medications'
        """
        found_entities: Dict[str, List[str]] = {
            'symptoms': [],
            'diseases': [],
            'medications': []
        }
        categories = (
            ('symptoms', self._SYMPTOMS),
            ('diseases', self._DISEASES),
            ('medications', self._MEDICATIONS),
        )
        for token in text.split():
            key = token.lower()
            for category, vocabulary in categories:
                if key in vocabulary:
                    found_entities[category].append(token)
                    break
        return found_entities

    def preprocess_batch(self, text_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of clinical texts.

        Args:
            text_data_list: List of text data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed text data
        """
        results = []
        for text_data in text_data_list:
            result = self.preprocess_text(text_data)
            if result:
                results.append(result)
        logger.info(f"Preprocessed {len(results)} texts out of {len(text_data_list)} total")
        return results
class LabResultPreprocessor(DataPreprocessor):
    """Class for preprocessing laboratory result data.

    Steps:
    1. Build a DataFrame from one record or a list of records.
    2. Fill or drop missing values.
    3. Replace z-score outliers in numeric columns with the column mean.
    4. Scale numeric columns with a StandardScaler or MinMaxScaler.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the lab result preprocessor.

        Args:
            config_path: Optional JSON config with keys normalization,
                handle_missing, outlier_detection and outlier_threshold.
        """
        super().__init__(config_path)
        self.normalization = self.config.get('normalization', 'standard')  # 'standard', 'minmax', or None
        self.handle_missing = self.config.get('handle_missing', 'mean')  # 'mean', 'median', 'zero', or 'drop'
        self.outlier_detection = self.config.get('outlier_detection', True)
        self.outlier_threshold = self.config.get('outlier_threshold', 3.0)  # Z-score threshold
        self.standard_scaler = StandardScaler() if self.normalization == 'standard' else None
        self.minmax_scaler = MinMaxScaler() if self.normalization == 'minmax' else None
        logger.info(f"LabResultPreprocessor initialized with normalization: {self.normalization}")

    def _replace_outliers(self, df: pd.DataFrame, numeric_cols) -> pd.DataFrame:
        """Replace values whose |z-score| exceeds the threshold with the column mean (modifies df)."""
        # z-scores use pandas' sample std (ddof=1), so with n rows |z| <= (n-1)/sqrt(n):
        # at the default threshold of 3.0 nothing can be flagged below 11 rows.
        for col in numeric_cols:
            column = df[col]
            mean, std = column.mean(), column.std()
            if pd.isna(std) or std == 0:
                # Constant or single-row column: z-scores are undefined
                continue
            outliers = ((column - mean) / std).abs() > self.outlier_threshold
            if outliers.any():
                # Cast first: writing a float mean into an integer column
                # relies on an implicit dtype change that pandas deprecates.
                df[col] = column.astype(float).mask(outliers, mean)
        return df

    def _normalize(self, numeric_df: pd.DataFrame) -> pd.DataFrame:
        """Scale numeric columns with the configured scaler (fitted on this data)."""
        if self.normalization == 'standard':
            scaler = self.standard_scaler
        elif self.normalization == 'minmax':
            scaler = self.minmax_scaler
        else:
            scaler = None
        # sklearn scalers reject inputs with zero rows or zero columns
        if scaler is None or numeric_df.empty:
            return numeric_df
        normalized = scaler.fit_transform(numeric_df)
        return pd.DataFrame(normalized, columns=numeric_df.columns, index=numeric_df.index)

    def preprocess_lab_data(self, lab_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess laboratory result data.

        Args:
            lab_data: Lab data from the collector; 'data' must be a dict
                (one record) or a list of record dicts

        Returns:
            Dict: Original and preprocessed records, per-column statistics
            (computed after missing-value handling and outlier replacement,
            before scaling), metadata and settings; {} on error
        """
        try:
            data = lab_data.get('data')
            if not data:
                logger.error("No data found in lab results")
                return {}
            if isinstance(data, list):
                df = pd.DataFrame(data)
            elif isinstance(data, dict):
                df = pd.DataFrame([data])
            else:
                logger.error(f"Unsupported data type: {type(data)}")
                return {}
            original_df = df.copy()
            # Handle missing values
            if self.handle_missing == 'mean':
                df = df.fillna(df.mean(numeric_only=True))
            elif self.handle_missing == 'median':
                df = df.fillna(df.median(numeric_only=True))
            elif self.handle_missing == 'zero':
                df = df.fillna(0)
            elif self.handle_missing == 'drop':
                df = df.dropna()
            numeric_cols = df.select_dtypes(include=['number']).columns
            non_numeric_cols = [col for col in df.columns if col not in numeric_cols]
            if self.outlier_detection:
                df = self._replace_outliers(df, numeric_cols)
            normalized_df = self._normalize(df[numeric_cols])
            # Combine normalized numeric columns with non-numeric columns
            if non_numeric_cols:
                result_df = pd.concat([normalized_df, df[non_numeric_cols]], axis=1)
            else:
                result_df = normalized_df
            stats = {}
            for col in numeric_cols:
                stats[col] = {
                    'mean': float(df[col].mean()),
                    'median': float(df[col].median()),
                    'std': float(df[col].std()),
                    'min': float(df[col].min()),
                    'max': float(df[col].max())
                }
            result = {
                'original_data': original_df.to_dict(orient='records'),
                'preprocessed_data': result_df.to_dict(orient='records'),
                'statistics': stats,
                'metadata': lab_data.get('metadata', {}),
                'preprocessing_info': {
                    'normalization': self.normalization,
                    'handle_missing': self.handle_missing,
                    'outlier_detection': self.outlier_detection,
                    'outlier_threshold': self.outlier_threshold
                }
            }
            logger.info(f"Successfully preprocessed lab data from {lab_data.get('file_path', 'unknown')}")
            return result
        except Exception as e:
            logger.error(f"Error preprocessing lab data: {str(e)}")
            return {}

    def preprocess_batch(self, lab_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of lab results.

        Args:
            lab_data_list: List of lab data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed lab data
        """
        results = []
        for lab_data in lab_data_list:
            result = self.preprocess_lab_data(lab_data)
            if result:
                results.append(result)
        logger.info(f"Preprocessed {len(results)} lab results out of {len(lab_data_list)} total")
        return results
# Composite preprocessor that uses all specialized preprocessors
class MedicalDataPreprocessor:
    """Composite class that routes each data type to its specialised preprocessor."""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize all data preprocessors with the same optional config path."""
        self.image_preprocessor = MedicalImagePreprocessor(config_path)
        self.text_preprocessor = ClinicalTextPreprocessor(config_path)
        self.lab_preprocessor = LabResultPreprocessor(config_path)
        logger.info("MedicalDataPreprocessor initialized with all specialized preprocessors")

    def preprocess_all(self, collection: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Preprocess all types of medical data in a collection.

        Args:
            collection: Collection from the collector; recognised keys are
                'images', 'clinical_texts' and 'lab_results'

        Returns:
            Dict: The recognised keys present in ``collection``, each mapped
            to its list of preprocessed items
        """
        handlers = (
            ('images', self.image_preprocessor),
            ('clinical_texts', self.text_preprocessor),
            ('lab_results', self.lab_preprocessor),
        )
        preprocessed = {}
        for data_type, handler in handlers:
            if data_type in collection:
                preprocessed[data_type] = handler.preprocess_batch(collection[data_type])
        logger.info("Preprocessed all data in collection")
        return preprocessed

    @staticmethod
    def _json_default(value: Any) -> Any:
        """json.dump fallback: convert numpy values, store anything else as text."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        # Metadata comes from the collector and may hold non-JSON objects
        # (presumably e.g. DICOM value types); store their text form rather
        # than aborting the whole save.
        return str(value)

    def save_preprocessed_data(self, preprocessed_data: Dict[str, List[Dict[str, Any]]], output_dir: str) -> None:
        """
        Save all preprocessed data to files.

        Writes ``preprocessing_summary.json`` plus one subdirectory per data
        type containing ``item_<i>_meta.json`` and, for each item, an
        ``.npy`` image, a text ``.json`` or a lab ``.csv`` depending on type.

        Args:
            preprocessed_data: Preprocessed data collection (as returned by preprocess_all)
            output_dir: Directory to save the data (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)
        summary = {
            'total_items': sum(len(items) for items in preprocessed_data.values()),
            'counts': {data_type: len(items) for data_type, items in preprocessed_data.items()}
        }
        with open(os.path.join(output_dir, 'preprocessing_summary.json'), 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)
        for data_type, items in preprocessed_data.items():
            type_dir = os.path.join(output_dir, data_type)
            os.makedirs(type_dir, exist_ok=True)
            for i, item in enumerate(items):
                meta_file = os.path.join(type_dir, f'item_{i}_meta.json')
                meta_data = {
                    'metadata': item.get('metadata', {}),
                    'preprocessing_info': item.get('preprocessing_info', {})
                }
                with open(meta_file, 'w', encoding='utf-8') as f:
                    json.dump(meta_data, f, indent=2, default=self._json_default)
                if data_type == 'images':
                    image = item.get('preprocessed_image')
                    if image is None:
                        # np.save(None) would write a pickled object array
                        logger.warning(f"No preprocessed image for {data_type} item {i}; skipping .npy file")
                    else:
                        np.save(os.path.join(type_dir, f'item_{i}_image.npy'), image)
                elif data_type == 'clinical_texts':
                    text_file = os.path.join(type_dir, f'item_{i}_text.json')
                    text_data = {
                        'preprocessed_text': item.get('preprocessed_text', ''),
                        'tokens': item.get('tokens', []),
                        'medical_entities': item.get('medical_entities', {})
                    }
                    with open(text_file, 'w', encoding='utf-8') as f:
                        json.dump(text_data, f, indent=2, default=self._json_default)
                elif data_type == 'lab_results':
                    lab_file = os.path.join(type_dir, f'item_{i}_lab.csv')
                    # Separate name so the preprocessed_data parameter is not shadowed
                    lab_records = item.get('preprocessed_data', [])
                    pd.DataFrame(lab_records).to_csv(lab_file, index=False)
        logger.info(f"Saved all preprocessed data to {output_dir}")
if __name__ == "__main__":
    # Demo run: collect everything under ./sample_data, preprocess it and
    # write the results to ./data/preprocessed.
    from data_collector import MedicalDataCollector

    sample_collection = MedicalDataCollector().collect_all_from_directory("./sample_data")
    pipeline = MedicalDataPreprocessor()
    pipeline.save_preprocessed_data(pipeline.preprocess_all(sample_collection), "./data/preprocessed")
modules\preprocessing\__init__.py
modules\text_analysis\clinical_text_analyzer.py
"""
Clinical Text Analyzer Module
This module integrates all text analysis components (preprocessing, NER, classification,
summarization, and relation extraction) to provide a comprehensive clinical text analysis.
"""
import os
import json
import logging
from typing import List, Dict, Any, Optional, Union, Tuple
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import seaborn as sns
from .text_preprocessor import ClinicalTextPreprocessor
from .medical_ner import MedicalNamedEntityRecognizer
from .text_classifier import ClinicalTextClassifier
from .text_summarizer import ClinicalTextSummarizer
from .relation_extractor import MedicalRelationExtractor
# Configure logging.
# basicConfig does nothing when the root logger already has handlers (e.g.
# configured by another module imported first). delay=True defers opening the
# log file until the first record, so an ignored FileHandler does not leave an
# open, unused file handle and an empty log file behind. utf-8 keeps non-ASCII
# clinical text from failing to encode under a narrower locale encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("clinical_text_analysis.log", encoding="utf-8", delay=True),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
class ClinicalTextAnalyzer:
"""
Class for comprehensive clinical text analysis
"""
def __init__(self, config_path: Optional[str] = None):
    """
    Create the analyzer together with every pipeline component it delegates to.
    Args:
        config_path: Optional path to a JSON configuration file. The analyzer
            reads its own settings from it, and the same path is handed to
            each component.
    """
    self.config = self._load_config(config_path)
    # Components are created in pipeline order, all from the same config path.
    self.preprocessor = ClinicalTextPreprocessor(config_path)
    self.ner = MedicalNamedEntityRecognizer(config_path)
    self.classifier = ClinicalTextClassifier(config_path)
    self.summarizer = ClinicalTextSummarizer(config_path)
    self.relation_extractor = MedicalRelationExtractor(config_path)
    logger.info("Initialized ClinicalTextAnalyzer")
def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
    """
    Load configuration from file, filling in analyzer defaults.
    The same file is passed to every component, so it may not contain an
    ``analysis`` section. Keys missing from the file's ``analysis`` section
    are taken from the defaults; all other top-level keys are kept as loaded.
    Args:
        config_path: Path to a JSON configuration file, or None
    Returns:
        Dict[str, Any]: Configuration dictionary that always has an
        ``analysis`` section with ``default_pipeline``, ``output_format``
        and ``visualization``
    """
    default_config = {
        "analysis": {
            "default_pipeline": ["preprocess", "ner", "classify", "summarize", "extract_relations"],
            "output_format": "json",
            "visualization": True
        }
    }
    if not config_path or not os.path.exists(config_path):
        logger.info("Using default configuration")
        return default_config
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            loaded = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        logger.error(f"Error loading configuration from {config_path}: {e}")
        return default_config
    if not isinstance(loaded, dict):
        logger.error(f"Error loading configuration from {config_path}: top-level JSON value is not an object")
        return default_config
    config = dict(loaded)
    analysis = loaded.get("analysis")
    if not isinstance(analysis, dict):
        analysis = {}
    # Merge so analyze_text can always read config["analysis"][...]
    config["analysis"] = {**default_config["analysis"], **analysis}
    logger.info(f"Loaded configuration from {config_path}")
    return config
def analyze_text(self, text: str, pipeline: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Analyze clinical text using the specified pipeline.
    Recognised steps are "preprocess", "ner", "classify", "summarize" and
    "extract_relations". The last one needs "ner", because relations are
    extracted between recognised entities. Unrecognised step names are ignored.
    Args:
        text: Clinical text to analyze (non-empty string)
        pipeline: Steps to run; defaults to the configured default pipeline
    Returns:
        Dict[str, Any]: Analysis results. If the input is invalid or a step
        raises, ``success`` is False and ``error`` holds the message. A
        classification failure is recorded under ``classification`` without
        failing the whole analysis.
    """
    if not text or not isinstance(text, str):
        # Log only the type: the rejected value may contain patient data
        logger.warning(f"Invalid input text of type {type(text).__name__}")
        return {"success": False, "error": "Invalid input text"}
    # Tolerate configs without a usable "analysis" section
    analysis_config = self.config.get("analysis")
    if not isinstance(analysis_config, dict):
        analysis_config = {}
    if pipeline is None:
        pipeline = analysis_config.get(
            "default_pipeline",
            ["preprocess", "ner", "classify", "summarize", "extract_relations"],
        )
    results: Dict[str, Any] = {
        "success": True,
        "text": text,
        "pipeline": pipeline
    }
    try:
        if "preprocess" in pipeline:
            logger.info("Preprocessing text")
            preprocessed_text = self.preprocessor.preprocess_text(text)
            sentences = self.preprocessor.extract_sentences(text)
            medical_terms = self.preprocessor.extract_medical_terms(text)
            results["preprocessing"] = {
                "preprocessed_text": preprocessed_text,
                "sentences": sentences,
                "sentence_count": len(sentences),
                "medical_terms": medical_terms,
                "medical_term_count": len(medical_terms)
            }
            demographics = self.preprocessor.extract_demographics(text)
            if demographics:
                results["demographics"] = demographics
        if "ner" in pipeline:
            logger.info("Performing Named Entity Recognition")
            entities = self.ner.extract_entities(text, method="transformer")
            entities_by_type: Dict[str, List[Any]] = {}
            for entity in entities:
                entities_by_type.setdefault(entity["label"], []).append(entity["text"])
            results["entities"] = {
                "all_entities": entities,
                "entity_count": len(entities),
                "entities_by_type": entities_by_type
            }
        if "classify" in pipeline:
            logger.info("Performing text classification")
            # Classification failures are recorded but do not abort the analysis
            try:
                results["classification"] = self.classifier.predict(text, method="transformer")
            except Exception as e:
                logger.error(f"Error in classification: {e}")
                results["classification"] = {"error": str(e)}
        if "summarize" in pipeline:
            logger.info("Generating text summary")
            summary_result = self.summarizer.summarize(text, method="hybrid")
            if summary_result.get("success"):
                results["summary"] = summary_result
            else:
                error = summary_result.get("error", "unknown summarization error")
                logger.error(f"Error in summarization: {error}")
                results["summary"] = {"error": error}
            findings_result = self.summarizer.key_findings_extract(text)
            if findings_result.get("success"):
                results["key_findings"] = findings_result.get("findings", {})
        if "extract_relations" in pipeline:
            if "entities" in results:
                logger.info("Extracting relations between entities")
                relations = self.relation_extractor.extract_relations(
                    text,
                    results["entities"]["all_entities"],
                    method="hybrid"
                )
                knowledge_graph = self.relation_extractor.build_knowledge_graph(relations)
                results["relations"] = {
                    "all_relations": relations,
                    "relation_count": len(relations),
                    "knowledge_graph": knowledge_graph
                }
            else:
                logger.warning("Skipping relation extraction: it requires the 'ner' step")
        if analysis_config.get("visualization", True):
            logger.info("Generating visualizations")
            results["visualizations"] = self._generate_visualizations(results)
        return results
    except Exception as e:
        logger.error(f"Error in text analysis: {e}")
        return {"success": False, "error": str(e)}
def analyze_document(self, document_path: str, pipeline: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Read a UTF-8 clinical document from disk and analyze its contents.
    Args:
        document_path: Path to the document
        pipeline: Analysis steps to run (None means the configured default)
    Returns:
        Dict[str, Any]: The ``analyze_text`` results with ``document_path``
        added, or ``{"success": False, "error": ...}`` when the file is
        missing or cannot be read/analyzed
    """
    path_is_usable = bool(document_path) and os.path.exists(document_path)
    if not path_is_usable:
        logger.error(f"Document not found: {document_path}")
        return {"success": False, "error": f"Document not found: {document_path}"}
    try:
        with open(document_path, 'r', encoding='utf-8') as handle:
            document_text = handle.read()
        analysis = self.analyze_text(document_text, pipeline)
        analysis["document_path"] = document_path
        return analysis
    except Exception as err:
        logger.error(f"Error analyzing document: {err}")
        return {"success": False, "error": str(err)}
def analyze_batch(self, texts: List[str], pipeline: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """
    Run ``analyze_text`` on every text in order.
    Args:
        texts: Clinical texts to analyze
        pipeline: Analysis steps to run (None means the configured default)
    Returns:
        List[Dict[str, Any]]: One result per input text, in input order;
        an empty list for an empty batch
    """
    if not texts:
        logger.warning("Empty batch of texts")
        return []
    total = len(texts)
    batch_results: List[Dict[str, Any]] = []
    for position, item in enumerate(texts, start=1):
        logger.info(f"Analyzing text {position}/{total}")
        batch_results.append(self.analyze_text(item, pipeline))
    return batch_results
def _generate_visualizations(self, results: Dict[str, Any]) -> Dict[str, Any]:
"""
Generate visualizations for analysis results
Args:
results: Analysis results
Returns:
Dict[str, Any]: Visualization data
"""
visualizations = {}
# Entity distribution visualization
if "entities" in results and results["entities"]["entity_count"] > 0:
try:
entities_by_type = results["entities"]["entities_by_type"]
entity_counts = {entity_type: len(entities) for entity_type, entities in entities_by_type.items()}
# Create entity distribution chart data
visualizations["entity_distribution"] = {
"type": "bar_chart",
"data": entity_counts
}
except Exception as e:
logger.error(f"Error generating entity visualization: {e}")
# Relation visualization
if "relations" in results and results["relations"]["relation_count"] > 0:
try:
# Use knowledge graph for relation visualization
visualizations["relations"] = {
"type": "knowledge_graph",
"data": results["relations"]["knowledge_graph"]
}
except Exception as e:
logger.error(f"Error generating relation visualization: {e}")
# Text summary visualization
if "summary" in results and "summary" in results["summary"]:
try:
original_text = results["text"]
summary = results["summary"]["summary"]
visualizations["summary_comparison"] = {
"type": "text_comparison",
"data": {
"original": {
"text": original_text[:1000] + ("..." if len(original_text) > 1000 else ""),
"word_count": len(original_text.split())
},
"summary": {
"text": summary,
"word_count": len(summary.split())
},
"compression_ratio": len(summary.split()) / len(original_text.split()) if len(original_text.split()) > 0 else 0
}
}
except Exception as e:
logger.error(f"Error generating summary visualization: {e}")
return visualizations
def generate_report(self, analysis_results: Dict[str, Any],
                    output_path: Optional[str] = None,
                    format: str = "html") -> str:
    """
    Render analysis results as an HTML, JSON or plain-text report.
    Args:
        analysis_results: Results from ``analyze_text``; must have a truthy ``success``
        output_path: If given, the report is written there (UTF-8)
        format: One of 'html', 'json' or 'text'
    Returns:
        str: The report itself, or ``output_path`` once it has been written.
        Problems are reported as a string starting with "Error" rather than raised.
    """
    results_ok = bool(analysis_results) and analysis_results.get("success", False)
    if not results_ok:
        logger.warning("Invalid analysis results for report generation")
        return "Error: Invalid analysis results"
    try:
        if format == "json":
            content = json.dumps(analysis_results, indent=2)
        elif format == "html":
            content = self._generate_html_report(analysis_results)
        elif format == "text":
            content = self._generate_text_report(analysis_results)
        else:
            logger.error(f"Unsupported report format: {format}")
            return f"Error: Unsupported report format: {format}"
        if not output_path:
            return content
        with open(output_path, 'w', encoding='utf-8') as out:
            out.write(content)
        logger.info(f"Report saved to {output_path}")
        return output_path
    except Exception as exc:
        logger.error(f"Error generating report: {exc}")
        return f"Error generating report: {exc}"
def _generate_html_report(self, results: Dict[str, Any]) -> str:
"""
Generate an HTML report from analysis results
Args:
results: Analysis results
Returns:
str: HTML report
"""
html = f"""
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Clinical Text Analysis Report</title>
<style>
body {{
font-family: Arial, sans-serif;
line-height: 1.6;
color: #333;
max-width: 1200px;
margin: 0 auto;
padding: 20px;
}}
h1, h2, h3 {{
color: #2c3e50;
}}
.section {{
margin-bottom: 30px;
padding: 20px;
border-radius: 5px;
background-color: #f9f9f9;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
}}
.entity {{
display: inline-block;
margin: 5px;
padding: 5px 10px;
border-radius: 15px;
font-size: 14px;
}}
.entity-DISEASE {{
background-color: #ffcccc;
}}
.entity-SYMPTOM {{
background-color: #ffffcc;
}}
.entity-MEDICATION {{
background-color: #ccffcc;
}}
.entity-PROCEDURE {{
background-color: #ccccff;
}}
.entity-TEST {{
background-color: #ffccff;
}}
table {{
width: 100%;
border-collapse: collapse;
margin-top: 10px;
}}
th, td {{
padding: 8px;
text-align: left;
border-bottom: 1px solid #ddd;
}}
th {{
background-color: #f2f2f2;
}}
.text-container {{
max-height: 300px;
overflow-y: auto;
padding: 10px;
border: 1px solid #ddd;
border-radius: 5px;
background-color: #fff;
}}
</style>
</head>
<body>
<h1>Clinical Text Analysis Report</h1>
<div class="section">
<h2>Overview</h2>
<p><strong>Analysis Pipeline:</strong> {', '.join(results['pipeline'])}</p>
<p><strong>Text Length:</strong> {len(results['text'])} characters, {len(results['text'].split())} words</p>
"""
# Add preprocessing section if available
if "preprocessing" in results:
html += f"""
<div class="section">
<h2>Text Preprocessing</h2>
<p><strong>Sentence Count:</strong> {results['preprocessing']['sentence_count']}</p>
<p><strong>Medical Terms:</strong> {results['preprocessing']['medical_term_count']}</p>
<h3>Medical Terms</h3>
<div class="text-container">
{', '.join(results['preprocessing']['medical_terms'][:100])}
{'...' if len(results['preprocessing']['medical_terms']) > 100 else ''}
</div>
</div>
"""
# Add demographics section if available
if "demographics" in results:
html += f"""
<div class="section">
<h2>Patient Demographics</h2>
<table>
<tr>
<th>Attribute</th>
<th>Value</th>
</tr>
"""
for key, value in results["demographics"].items():
html += f"""
<tr>
<td>{key.replace('_', ' ').title()}</td>
<td>{value}</td>
</tr>
"""
html += """
</table>
</div>
"""
# Add entities section if available
if "entities" in results:
html += f"""
<div class="section">
<h2>Named Entity Recognition</h2>
<p><strong>Total Entities:</strong> {results['entities']['entity_count']}</p>
<h3>Entities by Type</h3>
"""
for entity_type, entities in results["entities"]["entities_by_type"].items():
html += f"""
<h4>{entity_type} ({len(entities)})</h4>
<div>
"""
for entity in entities[:30]:
html += f"""
<span class="entity entity-{entity_type}">{entity}</span>
"""
if len(entities) > 30:
html += "..."
html += """
</div>
"""
html += """
</div>
"""
# Add classification section if available
if "classification" in results and "error" not in results["classification"]:
html += f"""
<div class="section">
<h2>Text Classification</h2>
<table>
<tr>
<th>Label</th>
<th>Confidence</th>
</tr>
"""
for label, score in results["classification"]["predictions"].items():
html += f"""
<tr>
<td>{label}</td>
<td>{score:.4f}</td>
</tr>
"""
html += """
</table>
</div>
"""
# Add summary section if available
if "summary" in results and "error" not in results["summary"]:
html += f"""
<div class="section">
<h2>Text Summarization</h2>
<p><strong>Method:</strong> {results['summary']['method']}</p>
<p><strong>Compression Ratio:</strong> {results['summary']['compression_ratio']:.2%}</p>
<h3>Summary</h3>
<div class="text-container">
{results['summary']['summary']}
</div>
</div>
"""
# Add key findings section if available
if "key_findings" in results:
html += f"""
<div class="section">
<h2>Key Findings</h2>
<table>
<tr>
<th>Category</th>
<th>Findings</th>
</tr>
"""
for category, findings in results["key_findings"].items():
if findings:
html += f"""
<tr>
<td>{category.replace('_', ' ').title()}</td>
<td>{', '.join(findings)}</td>
</tr>
"""
html += """
</table>
</div>
"""
# Add relations section if available
if "relations" in results and results["relations"]["relation_count"] > 0:
html += f"""
<div class="section">
<h2>Relation Extraction</h2>
<p><strong>Total Relations:</strong> {results['relations']['relation_count']}</p>
<h3>Extracted Relations</h3>
<table>
<tr>
<th>Relation Type</th>
<th>Entity 1</th>
<th>Entity 2</th>
<th>Confidence</th>
</tr>
"""
for relation in results["relations"]["all_relations"][:30]:
html += f"""
<tr>
<td>{relation['relation_type']}</td>
<td>{relation['entity1']['text']} ({relation['entity1']['label']})</td>
<td>{relation['entity2']['text']} ({relation['entity2']['label']})</td>
<td>{relation['confidence']:.2f}</td>
</tr>
"""
if len(results["relations"]["all_relations"]) > 30:
html += """
<tr>
<td colspan="4">...</td>
</tr>
"""
html += """
</table>
</div>
"""
# Close HTML tags
html += """
</div>
</body>
</html>
"""
return html
def _generate_text_report(self, results: Dict[str, Any]) -> str:
"""
Generate a plain text report from analysis results
Args:
results: Analysis results
Returns:
str: Text report
"""
report = "CLINICAL TEXT ANALYSIS REPORT\n"
report += "=" * 30 + "\n\n"
# Overview
report += "OVERVIEW\n"
report += "-" * 8 + "\n"
report += f"Analysis Pipeline: {', '.join(results['pipeline'])}\n"
report += f"Text Length: {len(results['text'])} characters, {len(results['text'].split())} words\n\n"
# Preprocessing
if "preprocessing" in results:
report += "TEXT PREPROCESSING\n"
report += "-" * 17 + "\n"
report += f"Sentence Count: {results['preprocessing']['sentence_count']}\n"
report += f"Medical Terms: {results['preprocessing']['medical_term_count']}\n"
# List some medical terms
report += "Sample Medical Terms: "
report += ", ".join(results['preprocessing']['medical_terms'][:20])
if len(results['preprocessing']['medical_terms']) > 20:
report += "..."
report += "\n\n"
# Demographics
if "demographics" in results:
report += "PATIENT DEMOGRAPHICS\n"
report += "-" * 20 + "\n"
for key, value in results["demographics"].items():
report += f"{key.replace('_', ' ').title()}: {value}\n"
report += "\n"
# Entities
if "entities" in results:
report += "NAMED ENTITY RECOGNITION\n"
report += "-" * 24 + "\n"
report += f"Total Entities: {results['entities']['entity_count']}\n\n"
report += "Entities by Type:\n"
for entity_type, entities in results["entities"]["entities_by_type"].items():
report += f"{entity_type} ({len(entities)}): "
report += ", ".join(entities[:15])
if len(entities) > 15:
report += "..."
report += "\n"
report += "\n"
# Classification
if "classification" in results and "error" not in results["classification"]:
report += "TEXT CLASSIFICATION\n"
report += "-" * 18 + "\n"
for label, score in results["classification"]["predictions"].items():
report += f"{label}: {score:.4f}\n"
report += "\n"
# Summary
if "summary" in results and "error" not in results["summary"]:
report += "TEXT SUMMARIZATION\n"
report += "-" * 17 + "\n"
report += f"Method: {results['summary']['method']}\n"
report += f"Compression Ratio: {results['summary']['compression_ratio']:.2%}\n\n"
report += "Summary:\n"
report += results['summary']['summary'] + "\n\n"
# Key Findings
if "key_findings" in results:
report += "KEY FINDINGS\n"
report += "-" * 12 + "\n"
for category, findings in results["key_findings"].items():
if findings:
report += f"{category.replace('_', ' ').title()}:\n"
for finding in findings:
report += f"- {finding}\n"
report += "\n"
# Relations
if "relations" in results and results["relations"]["relation_count"] > 0:
report += "RELATION EXTRACTION\n"
report += "-" * 19 + "\n"
report += f"Total Relations: {results['relations']['relation_count']}\n\n"
report += "Sample Relations:\n"
for relation in results["relations"]["all_relations"][:15]:
report += f"{relation['entity1']['text']} --[{relation['relation_type']}]--> {relation['entity2']['text']} (Confidence: {relation['confidence']:.2f})\n"
if len(results["relations"]["all_relations"]) > 15:
report += "...\n"
report += "\n"
return report
def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
    """
    Save analysis results to a JSON file.

    The results are serialized to a string before the file is opened. A
    serialization failure therefore no longer truncates or partially
    overwrites ``output_path``. NumPy scalars and arrays are converted to
    native Python values; any other non-JSON-serializable value still makes
    the save fail.

    Args:
        results: Analysis results
        output_path: Path to save results (written as UTF-8, 2-space indent)
    Returns:
        bool: True if the file was written; False if serialization or
            writing failed (the error is logged, not raised).
    """
    def to_builtin(obj: Any) -> Any:
        # json fallback hook: unwrap NumPy values, reject anything else.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    try:
        # Serialize first so a failure here never touches the target file.
        payload = json.dumps(results, indent=2, default=to_builtin)
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(payload)
    except Exception as e:
        # Best-effort save: report failure through the return value.
        logger.error(f"Error saving results: {e}")
        return False
    logger.info(f"Results saved to {output_path}")
    return True