Python项目:智能医疗辅助诊断系统开发(3)

源代码续

"""
Medical Image Classifier Module

This module provides classes and functions for classifying medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.

Key features:
- Traditional machine learning classifiers
- Deep learning models for medical image classification
- Model training, evaluation, and prediction
- Support for transfer learning with pre-trained models
"""

import os
import logging
import numpy as np
import json
import pickle
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
import cv2

# Module-wide logging setup: INFO-level records with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

class MedicalImageClassifier:
    """Base class for medical image classification.

    The model family is selected by ``config["model"]["type"]``:

    * ``"traditional"`` -- hand-crafted features (see :meth:`extract_features`)
      fed to a scikit-learn estimator (SVM, random forest or k-NN).
    * anything else (``"deep_learning"``) -- a Keras backbone (ResNet50,
      VGG16 or MobileNetV2) with a new dense classification head.

    Attributes:
        config: Nested configuration dictionary (built-in defaults with any
            user-supplied configuration file merged on top).
        model: The trained scikit-learn estimator / Keras model, or ``None``.
        classes: Sorted list of class names; position ``i`` is encoded label ``i``.
        label_encoder: Fitted scikit-learn ``LabelEncoder`` or ``None``.
        is_trained: Whether :attr:`model` is ready for prediction.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image classifier

        Values read from the JSON file at ``config_path`` are deep-merged over
        the built-in defaults, so a partial configuration file only has to
        contain the keys it wants to change.

        Args:
            config_path: Path to configuration file
        """
        self.config: Dict[str, Any] = self._default_config()
        if config_path:
            if os.path.isfile(config_path):
                try:
                    with open(config_path, 'r', encoding='utf-8') as f:
                        loaded = json.load(f)
                except (OSError, ValueError) as e:
                    # ValueError covers JSONDecodeError and UnicodeDecodeError
                    logger.error(f"Error loading configuration: {str(e)}")
                else:
                    if isinstance(loaded, dict):
                        self.config = self._merge_config(self.config, loaded)
                        logger.info(f"Loaded configuration from {config_path}")
                    else:
                        logger.error("Error loading configuration: top-level JSON value must be an object")
            else:
                logger.warning(f"Configuration file not found: {config_path}; using defaults")

        self.model = None
        self.classes: List[str] = []
        self.label_encoder = None
        self.is_trained = False

    @staticmethod
    def _default_config() -> Dict[str, Any]:
        """Return a fresh copy of the built-in default configuration."""
        return {
            "model": {
                "type": "traditional",  # traditional or deep_learning
                "algorithm": "svm",     # svm, random_forest or knn (traditional only)
                "architecture": "resnet50",  # resnet50, vgg16 or mobilenetv2 (deep learning only)
                "pretrained": True,
                "batch_size": 32,
                "learning_rate": 0.001,
                "epochs": 10
            },
            "preprocessing": {
                "target_size": [224, 224],  # [height, width]
                "normalize": True,
                "augmentation": True  # NOTE(review): not read anywhere in this class
            }
        }

    @staticmethod
    def _merge_config(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``overrides`` onto ``defaults`` and return a new dict.

        Nested dictionaries are merged key by key; any other value in
        ``overrides`` replaces the default outright. Neither input dict is
        mutated.
        """
        merged = dict(defaults)
        for key, value in overrides.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = MedicalImageClassifier._merge_config(merged[key], value)
            else:
                merged[key] = value
        return merged

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess an image for classification

        Steps: drop a singleton or alpha channel, resize to
        ``config["preprocessing"]["target_size"]`` ([height, width]), expand
        grayscale to 3 identical channels and, if
        ``config["preprocessing"]["normalize"]`` is set, rescale intensities
        into ``[0, 1]`` as float32 (see :meth:`_normalize_intensity`).

        Args:
            image: Input image, shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4)

        Returns:
            np.ndarray: Preprocessed (target_h, target_w, 3) image, or an empty
            array if ``image`` is ``None`` or empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot preprocess empty image")
            return np.array([])

        if image.ndim == 3 and image.shape[2] == 1:
            # (H, W, 1) -> (H, W); treated as grayscale below
            image = image[:, :, 0]
        elif image.ndim == 3 and image.shape[2] == 4:
            # Drop the alpha channel; copy so OpenCV gets a contiguous buffer
            image = np.ascontiguousarray(image[:, :, :3])

        # Resize to target size (cv2 expects dsize as (width, height))
        target_size = self.config["preprocessing"]["target_size"]
        target_h, target_w = target_size[0], target_size[1]
        if image.shape[0] != target_h or image.shape[1] != target_w:
            image = cv2.resize(image, (target_w, target_h))

        # Grayscale -> RGB by channel replication. Unlike cv2.COLOR_GRAY2RGB
        # this also works for float64 input.
        if image.ndim == 2:
            image = np.repeat(image[:, :, np.newaxis], 3, axis=2)

        if self.config["preprocessing"]["normalize"]:
            image = self._normalize_intensity(image)

        return image

    @staticmethod
    def _normalize_intensity(image: np.ndarray) -> np.ndarray:
        """Rescale pixel intensities into ``[0, 1]`` as float32.

        * uint8 images are divided by 255.
        * Data outside the 8-bit range (negative values or values above 255,
          e.g. 16-bit or signed medical images) is min-max rescaled. A
          constant image becomes all zeros.
        * Remaining data with a maximum above 1 is divided by 255; data
          already within ``[0, 1]`` is left unchanged.
        """
        source_dtype = image.dtype
        data = image.astype(np.float32)
        if source_dtype == np.uint8:
            return data / 255.0

        min_val = float(data.min())
        max_val = float(data.max())
        if min_val < 0.0 or max_val > 255.0:
            value_range = max_val - min_val
            if value_range == 0.0:
                return np.zeros_like(data)
            return (data - min_val) / value_range
        if max_val > 1.0:
            return data / 255.0
        return data

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """Convert an image to uint8 for OpenCV / scikit-image routines.

        uint8 input is returned unchanged. Data whose maximum is at most 1 is
        treated as ``[0, 1]`` and scaled by 255. Everything is then clipped to
        ``[0, 255]`` so out-of-range values saturate instead of wrapping
        around on the cast.
        """
        if image.dtype == np.uint8:
            return image
        scaled = image * 255 if image.size and image.max() <= 1.0 else image
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def extract_features(self, image: np.ndarray) -> np.ndarray:
        """
        Extract features from an image for traditional ML models

        The feature vector is the concatenation of:
        per-channel mean/std/min/max, a 32-bin normalized histogram per
        channel, and a subsampled HOG descriptor of the grayscale image. HOG is
        best-effort: if it fails (e.g. scikit-image missing), a warning is
        logged and it is omitted.

        Args:
            image: Preprocessed image

        Returns:
            np.ndarray: Extracted features
        """
        features: List[float] = []

        # Basic statistical features; a 2-D image is treated as a single channel
        if image.ndim == 3:
            channels = [image[:, :, c] for c in range(image.shape[2])]
        else:
            channels = [image]
        for channel_data in channels:
            features.extend([
                np.mean(channel_data),
                np.std(channel_data),
                np.min(channel_data),
                np.max(channel_data)
            ])

        # Histogram features
        features.extend(self._extract_histogram_features(image))

        # HOG features (simplified)
        try:
            from skimage.feature import hog

            img_uint8 = self._to_uint8(image)
            if img_uint8.ndim == 3:
                img_uint8 = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)

            hog_features = hog(img_uint8, orientations=9, pixels_per_cell=(8, 8),
                               cells_per_block=(2, 2), block_norm='L2-Hys', feature_vector=True)

            # Keep roughly 100 evenly spaced HOG values to bound the vector size
            step = max(1, len(hog_features) // 100)
            features.extend(hog_features[::step])

        except Exception as e:
            logger.warning(f"Error extracting HOG features: {str(e)}")

        return np.array(features)

    def _extract_histogram_features(self, image: np.ndarray) -> List[float]:
        """
        Extract histogram features from an image

        Computes a 32-bin histogram over [0, 256) for each channel (or the
        single channel of a 2-D image), each normalized to sum to 1.

        Args:
            image: Input image

        Returns:
            List[float]: Histogram features
        """
        img_uint8 = self._to_uint8(image)
        channel_indices = range(img_uint8.shape[2]) if img_uint8.ndim == 3 else [0]

        hist_features: List[float] = []
        for channel in channel_indices:
            hist = cv2.calcHist([img_uint8], [channel], None, [32], [0, 256]).flatten()
            hist_features.extend(hist / np.sum(hist))  # Normalize
        return hist_features

    @staticmethod
    def _split_indices(n_samples: int, validation_split: float, shuffle: bool = True,
                       random_state: Optional[int] = 42) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(train_indices, val_indices)`` for ``n_samples`` items.

        The first ``int(n_samples * (1 - validation_split))`` indices (at least
        one) go to training and the rest to validation. When ``shuffle`` is
        set, the indices are permuted first with a NumPy generator seeded by
        ``random_state``. ``validation_split <= 0`` yields an empty
        validation set.
        """
        if shuffle:
            indices = np.random.default_rng(random_state).permutation(n_samples)
        else:
            indices = np.arange(n_samples)
        if validation_split <= 0:
            return indices, indices[:0]
        split_idx = int(n_samples * (1 - validation_split))
        split_idx = min(max(split_idx, 1), n_samples)
        return indices[:split_idx], indices[split_idx:]

    def train(self, images: List[np.ndarray], labels: List[str],
              validation_split: float = 0.2, *, shuffle: bool = True,
              random_state: Optional[int] = 42) -> Dict[str, Any]:
        """
        Train the classifier on a set of images

        Args:
            images: List of input images
            labels: List of corresponding labels
            validation_split: Fraction of data, in ``[0, 1)``, held out for validation
            shuffle: Shuffle samples before splitting, so class-sorted input
                does not leave whole classes out of the training set
            random_state: Seed for the shuffle (``None`` = non-deterministic)

        Returns:
            Dict[str, Any]: Training results; ``{"success": False, "error": ...}``
            on failure (in which case ``self.classes`` is left unchanged)
        """
        if len(images) != len(labels):
            logger.error("Number of images and labels must match")
            return {"success": False, "error": "Number of images and labels must match"}

        if len(images) == 0:
            logger.error("No training data provided")
            return {"success": False, "error": "No training data provided"}

        if not 0.0 <= validation_split < 1.0:
            msg = f"validation_split must be in [0, 1), got {validation_split}"
            logger.error(msg)
            return {"success": False, "error": msg}

        labels = list(labels)
        classes = sorted(set(labels))
        if len(classes) < 2:
            msg = "At least two distinct classes are required for training"
            logger.error(msg)
            return {"success": False, "error": msg}
        logger.info(f"Training classifier with {len(classes)} classes: {classes}")

        # Preprocess images
        processed_images = [self.preprocess_image(image) for image in images]
        if any(img.size == 0 for img in processed_images):
            msg = "One or more images could not be preprocessed"
            logger.error(msg)
            return {"success": False, "error": msg}

        # Split into training and validation sets
        train_idx, val_idx = self._split_indices(
            len(processed_images), validation_split, shuffle=shuffle, random_state=random_state)
        train_images = [processed_images[i] for i in train_idx]
        train_labels = [labels[i] for i in train_idx]
        val_images = [processed_images[i] for i in val_idx]
        val_labels = [labels[i] for i in val_idx]

        # The trainers read self.classes; restore the old value if training fails
        previous_classes = self.classes
        self.classes = classes

        # Train based on model type
        if self.config["model"]["type"] == "traditional":
            results = self._train_traditional(train_images, train_labels, val_images, val_labels)
        else:  # deep_learning
            results = self._train_deep_learning(train_images, train_labels, val_images, val_labels)

        if not results.get("success"):
            self.classes = previous_classes
        return results

    def _train_traditional(self, train_images: List[np.ndarray], train_labels: List[str],
                           val_images: List[np.ndarray], val_labels: List[str]) -> Dict[str, Any]:
        """
        Train a traditional machine learning model

        The estimator, label encoder and ``is_trained`` flag are stored on the
        instance only after training and validation have both succeeded.

        Args:
            train_images: Training images
            train_labels: Training labels
            val_images: Validation images
            val_labels: Validation labels

        Returns:
            Dict[str, Any]: Training results
        """
        try:
            from sklearn.preprocessing import LabelEncoder
            from sklearn.svm import SVC
            from sklearn.ensemble import RandomForestClassifier
            from sklearn.neighbors import KNeighborsClassifier
            from sklearn.metrics import accuracy_score, classification_report

            # Select the model
            algorithm = self.config["model"]["algorithm"]
            if algorithm == "svm":
                model = SVC(probability=True)
            elif algorithm == "random_forest":
                model = RandomForestClassifier(n_estimators=100)
            elif algorithm == "knn":
                # k may not exceed the number of training samples
                model = KNeighborsClassifier(n_neighbors=min(5, len(train_images)))
            else:
                logger.error(f"Unsupported algorithm: {algorithm}")
                return {"success": False, "error": f"Unsupported algorithm: {algorithm}"}

            # Extract features from images
            logger.info("Extracting features from training images...")
            train_features = np.array([self.extract_features(img) for img in train_images])

            # Encode labels; LabelEncoder sorts, matching the sorted self.classes
            label_encoder = LabelEncoder()
            label_encoder.fit(self.classes)
            y_train = label_encoder.transform(train_labels)

            logger.info(f"Training {algorithm} classifier...")
            model.fit(train_features, y_train)

            # Evaluate on training set
            train_accuracy = float(accuracy_score(y_train, model.predict(train_features)))

            results = {
                "success": True,
                "model_type": "traditional",
                "algorithm": algorithm,
                "train_accuracy": train_accuracy,
                "classes": self.classes
            }

            # Evaluate on validation set if available
            if len(val_images) > 0:
                logger.info("Evaluating on validation set...")
                val_features = np.array([self.extract_features(img) for img in val_images])
                y_val = label_encoder.transform(val_labels)
                y_val_pred = model.predict(val_features)

                results["val_accuracy"] = float(accuracy_score(y_val, y_val_pred))
                # Pass every label explicitly so target_names stays aligned even
                # when some classes are absent from the validation set
                results["classification_report"] = classification_report(
                    y_val, y_val_pred,
                    labels=list(range(len(self.classes))),
                    target_names=self.classes,
                    output_dict=True,
                    zero_division=0)

            self.model = model
            self.label_encoder = label_encoder
            self.is_trained = True
            logger.info(f"Model training completed with train accuracy: {train_accuracy:.4f}")

            return results

        except Exception as e:
            logger.error(f"Error training traditional model: {str(e)}")
            return {"success": False, "error": str(e)}

    def _train_deep_learning(self, train_images: List[np.ndarray], train_labels: List[str],
                             val_images: List[np.ndarray], val_labels: List[str]) -> Dict[str, Any]:
        """
        Train a deep learning model

        Builds the configured Keras backbone (``include_top=False``) plus
        GlobalAveragePooling -> Dense(256, relu) -> Dense(n_classes, softmax).
        With ``pretrained`` set, the backbone uses ImageNet weights and its
        layers are frozen. Without an explicit validation set, Keras holds out
        10% of the training data (``validation_split=0.1``).

        NOTE(review): pretrained Keras backbones are usually paired with their
        own ``preprocess_input``. This method feeds them the output of
        ``preprocess_image`` instead -- confirm this is intended.

        Args:
            train_images: Training images
            train_labels: Training labels
            val_images: Validation images
            val_labels: Validation labels

        Returns:
            Dict[str, Any]: Training results
        """
        try:
            from tensorflow.keras.applications import ResNet50, VGG16, MobileNetV2
            from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
            from tensorflow.keras.models import Model
            from tensorflow.keras.optimizers import Adam
            from tensorflow.keras.utils import to_categorical
            from sklearn.preprocessing import LabelEncoder

            backbones = {"resnet50": ResNet50, "vgg16": VGG16, "mobilenetv2": MobileNetV2}
            architecture = self.config["model"]["architecture"]
            if architecture not in backbones:
                logger.error(f"Unsupported architecture: {architecture}")
                return {"success": False, "error": f"Unsupported architecture: {architecture}"}

            num_classes = len(self.classes)
            has_val = len(val_images) > 0

            # Encode labels
            label_encoder = LabelEncoder()
            label_encoder.fit(self.classes)
            y_train_cat = to_categorical(label_encoder.transform(train_labels), num_classes=num_classes)
            X_train = np.array(train_images)
            if has_val:
                y_val_cat = to_categorical(label_encoder.transform(val_labels), num_classes=num_classes)
                X_val = np.array(val_images)

            # Create model
            pretrained = self.config["model"]["pretrained"]
            input_shape = (*self.config["preprocessing"]["target_size"], 3)
            base_model = backbones[architecture](weights='imagenet' if pretrained else None,
                                                 include_top=False, input_shape=input_shape)

            # Add classification head
            x = GlobalAveragePooling2D()(base_model.output)
            x = Dense(256, activation='relu')(x)
            predictions = Dense(num_classes, activation='softmax')(x)
            model = Model(inputs=base_model.input, outputs=predictions)

            # Freeze base model layers if using pretrained weights
            if pretrained:
                for layer in base_model.layers:
                    layer.trainable = False

            model.compile(optimizer=Adam(learning_rate=self.config["model"]["learning_rate"]),
                          loss='categorical_crossentropy',
                          metrics=['accuracy'])

            # Train model
            logger.info(f"Training {architecture} model...")
            fit_kwargs = {
                "epochs": self.config["model"]["epochs"],
                "batch_size": self.config["model"]["batch_size"],
                "verbose": 1,
            }
            if has_val:
                history = model.fit(X_train, y_train_cat, validation_data=(X_val, y_val_cat), **fit_kwargs)
                val_loss, val_accuracy = model.evaluate(X_val, y_val_cat)
            else:
                history = model.fit(X_train, y_train_cat, validation_split=0.1, **fit_kwargs)
                val_loss, val_accuracy = None, None

            # Get training metrics
            train_loss, train_accuracy = model.evaluate(X_train, y_train_cat)

            # Commit state only once training and evaluation have succeeded
            self.model = model
            self.label_encoder = label_encoder
            self.is_trained = True

            results = {
                "success": True,
                "model_type": "deep_learning",
                "architecture": architecture,
                "train_accuracy": float(train_accuracy),
                "train_loss": float(train_loss),
                "classes": self.classes
            }

            if val_loss is not None and val_accuracy is not None:
                results["val_accuracy"] = float(val_accuracy)
                results["val_loss"] = float(val_loss)

            # Add training history
            results["history"] = {
                "accuracy": [float(acc) for acc in history.history['accuracy']],
                "loss": [float(loss) for loss in history.history['loss']]
            }

            if 'val_accuracy' in history.history:
                results["history"]["val_accuracy"] = [float(acc) for acc in history.history['val_accuracy']]
                results["history"]["val_loss"] = [float(loss) for loss in history.history['val_loss']]

            logger.info(f"Model training completed with train accuracy: {train_accuracy:.4f}")

            return results

        except Exception as e:
            logger.error(f"Error training deep learning model: {str(e)}")
            return {"success": False, "error": str(e)}

    def _format_prediction(self, pred_proba: np.ndarray, class_indices) -> Dict[str, Any]:
        """Build the prediction result dictionary.

        Args:
            pred_proba: 1-D array with one probability per model output column
            class_indices: For each column of ``pred_proba``, the encoded class
                index (position in ``self.classes``) that column represents

        Returns:
            Dict[str, Any]: ``success``, ``predicted_class``, ``confidence`` and
            ``probabilities``. The probabilities cover every class in
            ``self.classes``, with 0.0 for classes that have no output column.
        """
        column_classes = [self.classes[int(idx)] for idx in class_indices]
        probabilities = {name: 0.0 for name in self.classes}
        for name, prob in zip(column_classes, pred_proba):
            probabilities[name] = float(prob)

        best = int(np.argmax(pred_proba))
        return {
            "success": True,
            "predicted_class": column_classes[best],
            "confidence": float(pred_proba[best]),
            "probabilities": probabilities
        }

    def predict(self, image: np.ndarray) -> Dict[str, Any]:
        """
        Predict the class of an image

        Args:
            image: Input image

        Returns:
            Dict[str, Any]: Prediction results (see :meth:`_format_prediction`),
            or ``{"success": False, "error": ...}``
        """
        if not self.is_trained or self.model is None:
            logger.error("Model not trained")
            return {"success": False, "error": "Model not trained"}

        try:
            processed_image = self.preprocess_image(image)
            if processed_image.size == 0:
                return {"success": False, "error": "Cannot preprocess empty image"}

            if self.config["model"]["type"] == "traditional":
                features = self.extract_features(processed_image).reshape(1, -1)
                pred_proba = self.model.predict_proba(features)[0]
                # scikit-learn orders predict_proba columns by model.classes_
                # (the encoded labels seen during fit), which can be a subset
                # of self.classes
                class_indices = getattr(self.model, "classes_", None)
                if class_indices is None:
                    class_indices = range(len(pred_proba))
            else:  # deep_learning: softmax output has one column per class
                input_image = np.expand_dims(processed_image, axis=0)
                pred_proba = self.model.predict(input_image)[0]
                class_indices = range(len(pred_proba))

            return self._format_prediction(pred_proba, class_indices)

        except Exception as e:
            logger.error(f"Error during prediction: {str(e)}")
            return {"success": False, "error": str(e)}

    def save_model(self, model_path: str) -> bool:
        """
        Save the trained model to a file

        Writes ``<model_path without extension>_metadata.json`` (config,
        classes, model type). It also writes the model itself: a pickle of
        ``{"model", "label_encoder"}`` for traditional models, or Keras
        ``model.save`` otherwise.

        Args:
            model_path: Path to save the model

        Returns:
            bool: True if successful, False otherwise
        """
        if not self.is_trained or self.model is None:
            logger.error("Cannot save untrained model")
            return False

        try:
            # os.makedirs("") raises, so only create a directory if the path names one
            model_dir = os.path.dirname(model_path)
            if model_dir:
                os.makedirs(model_dir, exist_ok=True)

            metadata = {
                "config": self.config,
                "classes": self.classes,
                "model_type": self.config["model"]["type"]
            }

            metadata_path = os.path.splitext(model_path)[0] + "_metadata.json"
            with open(metadata_path, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            if self.config["model"]["type"] == "traditional":
                with open(model_path, 'wb') as f:
                    pickle.dump({
                        "model": self.model,
                        "label_encoder": self.label_encoder
                    }, f)
            else:
                self.model.save(model_path)

            logger.info(f"Model saved to {model_path}")
            return True

        except Exception as e:
            logger.error(f"Error saving model: {str(e)}")
            return False

    def load_model(self, model_path: str) -> bool:
        """
        Load a trained model from a file

        Reads the ``_metadata.json`` file written by :meth:`save_model` and
        then the model itself. Instance state is updated only if everything
        loads successfully.

        SECURITY: traditional models are restored with ``pickle.load``, which
        can execute arbitrary code -- only load model files from trusted sources.

        Args:
            model_path: Path to the saved model

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            metadata_path = os.path.splitext(model_path)[0] + "_metadata.json"
            if not os.path.isfile(metadata_path):
                logger.error(f"Model metadata not found: {metadata_path}")
                return False

            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            config = metadata["config"]
            classes = metadata["classes"]
            label_encoder = None

            if metadata["model_type"] == "traditional":
                with open(model_path, 'rb') as f:
                    model_data = pickle.load(f)  # trusted files only (see docstring)
                model = model_data["model"]
                label_encoder = model_data["label_encoder"]
            else:
                import tensorflow as tf
                model = tf.keras.models.load_model(model_path)

            self.config = config
            self.classes = classes
            self.model = model
            self.label_encoder = label_encoder
            self.is_trained = True
            logger.info(f"Model loaded from {model_path}")
            return True

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False

    def visualize_predictions(self, image: np.ndarray, save_path: Optional[str] = None,
                              show: bool = False) -> None:
        """
        Visualize the prediction results

        Draws the preprocessed image next to a horizontal bar chart of the
        class probabilities, sorted in descending order. The figure is always
        closed before returning.

        Args:
            image: Input image
            save_path: Optional path to save the visualization
            show: If True, call ``plt.show()`` before closing. With the default
                (False) and no ``save_path``, the figure is built and discarded.
        """
        if not self.is_trained or self.model is None:
            logger.error("Model not trained")
            return

        prediction = self.predict(image)
        if not prediction["success"]:
            logger.error(f"Prediction failed: {prediction.get('error', 'Unknown error')}")
            return

        processed_image = self.preprocess_image(image)

        class_names = list(prediction["probabilities"].keys())
        probs = list(prediction["probabilities"].values())
        order = np.argsort(probs)[::-1]
        class_names = [class_names[i] for i in order]
        probs = [probs[i] for i in order]

        fig = plt.figure(figsize=(10, 6))
        try:
            image_ax = fig.add_subplot(1, 2, 1)
            image_ax.set_title("Input Image")
            image_ax.imshow(processed_image, cmap='gray' if processed_image.ndim == 2 else None)
            image_ax.axis('off')

            prob_ax = fig.add_subplot(1, 2, 2)
            prob_ax.set_title(f"Prediction: {prediction['predicted_class']}\nConfidence: {prediction['confidence']:.2f}")
            prob_ax.barh(range(len(class_names)), probs, tick_label=class_names)
            prob_ax.set_xlim(0, 1)
            prob_ax.set_xlabel('Probability')

            fig.tight_layout()

            if save_path:
                try:
                    save_dir = os.path.dirname(save_path)
                    if save_dir:
                        os.makedirs(save_dir, exist_ok=True)
                    fig.savefig(save_path, dpi=300, bbox_inches='tight')
                    logger.info(f"Visualization saved to {save_path}")
                except Exception as e:
                    logger.error(f"Error saving visualization: {str(e)}")

            if show:
                plt.show()
        finally:
            plt.close(fig)


class ChestXRayClassifier(MedicalImageClassifier):
    """Specialized classifier for chest X-ray images.

    Adds an ``"xray"`` configuration section. When
    ``config["xray"]["lung_segmentation"]`` is enabled, CLAHE contrast
    enhancement is applied on top of the base preprocessing. Despite the
    option name, no lung segmentation model is used.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the chest X-ray classifier

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

        # Set default configuration for chest X-ray classification
        self.config.setdefault("xray", {})
        self.config["xray"].setdefault("lung_segmentation", True)
        self.config["xray"].setdefault("common_conditions", [
            "normal", "pneumonia", "tuberculosis", "covid-19", "lung cancer"
        ])
        # CLAHE parameters (previously hard-coded to these values)
        self.config["xray"].setdefault("clahe_clip_limit", 2.0)
        self.config["xray"].setdefault("clahe_tile_grid_size", [8, 8])

    @staticmethod
    def _to_gray_uint8(image: np.ndarray) -> np.ndarray:
        """Collapse an image to a contiguous single-channel uint8 array.

        Images with 3+ channels are converted with ``cv2.COLOR_RGB2GRAY``
        (extra channels are ignored). Depths OpenCV cannot convert are first
        cast to float32. Non-uint8 data with a maximum of at most 1 is scaled
        by 255; everything is then clipped to ``[0, 255]``.
        """
        if image.ndim == 3 and image.shape[2] >= 3:
            rgb = image[:, :, :3]
            if rgb.dtype not in (np.uint8, np.uint16, np.float32):
                rgb = rgb.astype(np.float32)
            gray = cv2.cvtColor(np.ascontiguousarray(rgb), cv2.COLOR_RGB2GRAY)
        elif image.ndim == 3:
            gray = image[:, :, 0]
        else:
            gray = image

        if gray.dtype != np.uint8:
            if gray.size and gray.max() <= 1.0:
                gray = gray * 255
            gray = np.clip(gray, 0, 255).astype(np.uint8)
        return np.ascontiguousarray(gray)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess a chest X-ray image

        Runs the base preprocessing. Then, if lung_segmentation is enabled,
        it converts to grayscale uint8, applies CLAHE and returns an
        (H, W, 3) float32 image scaled to ``[0, 1]``.

        Args:
            image: Input chest X-ray image

        Returns:
            np.ndarray: Preprocessed image (empty if the input was empty)
        """
        image = super().preprocess_image(image)

        # .get() so a config without an "xray" section (e.g. one restored by
        # load_model from a base-class model) does not raise KeyError
        xray_config = self.config.get("xray", {})
        if not xray_config.get("lung_segmentation", False) or image.size == 0:
            return image

        # This would ideally use a lung segmentation model; for simplicity we
        # only apply CLAHE contrast enhancement.
        gray = self._to_gray_uint8(image)
        tile_grid = tuple(int(v) for v in xray_config.get("clahe_tile_grid_size", (8, 8)))
        clahe = cv2.createCLAHE(clipLimit=float(xray_config.get("clahe_clip_limit", 2.0)),
                                tileGridSize=tile_grid)
        enhanced = clahe.apply(gray)

        # Always return 3 channels in [0, 1], matching the (H, W, 3) model input
        return cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0


class BrainMRIClassifier(MedicalImageClassifier):
    """Specialized classifier for brain MRI images.

    Adds an ``"mri"`` configuration section. When
    ``config["mri"]["brain_extraction"]`` is enabled, a rough Otsu-threshold
    foreground mask (cleaned by morphological close/open) is applied after
    the base preprocessing. No dedicated brain-extraction model is used.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the brain MRI classifier

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

        # Set default configuration for brain MRI classification
        self.config.setdefault("mri", {})
        self.config["mri"].setdefault("brain_extraction", True)
        self.config["mri"].setdefault("common_conditions", [
            "normal", "tumor", "alzheimer", "multiple sclerosis", "stroke"
        ])
        # Side length of the square morphology kernel (previously hard-coded to 5)
        self.config["mri"].setdefault("morph_kernel_size", 5)

    @staticmethod
    def _to_gray_uint8(image: np.ndarray) -> np.ndarray:
        """Collapse an image to a contiguous single-channel uint8 array.

        Images with 3+ channels are converted with ``cv2.COLOR_RGB2GRAY``
        (extra channels are ignored). Depths OpenCV cannot convert are first
        cast to float32. Non-uint8 data with a maximum of at most 1 is scaled
        by 255; everything is then clipped to ``[0, 255]``.
        """
        if image.ndim == 3 and image.shape[2] >= 3:
            rgb = image[:, :, :3]
            if rgb.dtype not in (np.uint8, np.uint16, np.float32):
                rgb = rgb.astype(np.float32)
            gray = cv2.cvtColor(np.ascontiguousarray(rgb), cv2.COLOR_RGB2GRAY)
        elif image.ndim == 3:
            gray = image[:, :, 0]
        else:
            gray = image

        if gray.dtype != np.uint8:
            if gray.size and gray.max() <= 1.0:
                gray = gray * 255
            gray = np.clip(gray, 0, 255).astype(np.uint8)
        return np.ascontiguousarray(gray)

    def preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess a brain MRI image

        Runs the base preprocessing, then (if brain_extraction is enabled):
        1. Converts the image to grayscale uint8.
        2. Builds an Otsu-threshold mask and cleans it with morphological
           close and open.
        3. Zeroes the pixels outside the mask.
        4. Returns an (H, W, 3) float32 image scaled to ``[0, 1]``.

        Args:
            image: Input brain MRI image

        Returns:
            np.ndarray: Preprocessed image (empty if the input was empty)
        """
        image = super().preprocess_image(image)

        # .get() so a config without an "mri" section (e.g. one restored by
        # load_model from a base-class model) does not raise KeyError
        mri_config = self.config.get("mri", {})
        if not mri_config.get("brain_extraction", False) or image.size == 0:
            return image

        # This would ideally use a brain extraction model; for simplicity a
        # global Otsu threshold roughly separates foreground from background.
        gray = self._to_gray_uint8(image)
        _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Close small holes, then remove small specks from the mask
        kernel_size = int(mri_config.get("morph_kernel_size", 5))
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        masked = cv2.bitwise_and(gray, gray, mask=mask)

        # Always return 3 channels in [0, 1], matching the (H, W, 3) model input
        return cv2.cvtColor(masked, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0


if __name__ == "__main__":
    # Smoke check when run as a script: build a classifier from the default config.
    demo_classifier = MedicalImageClassifier()
    print("Medical Image Classifier initialized")

modules\image_analysis\image_processor.py

"""
Medical Image Processor Module

This module provides classes and functions for processing and analyzing medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.

Key features:
- Image loading and standardization
- Image enhancement and filtering
- Feature extraction
- Segmentation
- Registration
"""

import os
import logging
import numpy as np
import cv2
import pydicom
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
from skimage import measure, morphology, segmentation, filters, feature

# Module-wide logging setup: INFO-level records with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

class MedicalImageProcessor:
    """Base class for medical image processing.

    Covers image loading (OpenCV-readable formats and DICOM), generic
    preprocessing (resize, CLAHE, denoising), threshold-based segmentation,
    statistical / texture / edge feature extraction, visualization and JSON
    export of results. Modality-specific subclasses extend ``preprocess`` and
    add their own segmentation helpers.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image processor

        Args:
            config_path: Path to a JSON configuration file. When it is omitted,
                missing, unreadable or does not contain a JSON object, the
                built-in defaults are used. Top-level sections absent from the
                file ("preprocessing", "segmentation") are filled in from the
                defaults, because the processing methods index them directly.
        """
        self.config: Dict[str, Any] = {}
        if config_path and os.path.isfile(config_path):
            try:
                import json
                # JSON is UTF-8 by specification; do not depend on the locale.
                with open(config_path, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                if isinstance(loaded, dict):
                    self.config = loaded
                    logger.info(f"Loaded configuration from {config_path}")
                else:
                    logger.error(
                        f"Error loading configuration: {config_path} does not contain a JSON object"
                    )
            except Exception as e:
                logger.error(f"Error loading configuration: {str(e)}")

        # Add every whole section the loaded file lacks (all of them when no
        # file was loaded); sections that are present are left untouched.
        for section, defaults in self._default_config().items():
            self.config.setdefault(section, defaults)

    @staticmethod
    def _default_config() -> Dict[str, Dict[str, Any]]:
        """Return a fresh copy of the built-in default configuration."""
        return {
            "preprocessing": {
                "normalize": True,
                "clahe": True,
                "clahe_clip_limit": 2.0,
                "clahe_grid_size": [8, 8],
                "denoise": True,
                "target_size": [512, 512]
            },
            "segmentation": {
                "threshold_method": "otsu",
                "morphology_operations": True,
                "remove_small_objects": True,
                "min_size": 100
            }
        }

    def load_image(self, image_path: str) -> Optional[np.ndarray]:
        """
        Load an image from file

        DICOM files (``.dcm`` / ``.dicom``) are decoded by ``_load_dicom``.
        Every other file is decoded by OpenCV as an 8-bit, 3-channel image and
        converted from OpenCV's BGR channel order to RGB, which is what the
        rest of this class assumes (``cv2.COLOR_RGB2GRAY``, matplotlib display).

        Args:
            image_path: Path to the image file (``str`` or path-like)

        Returns:
            np.ndarray: Loaded image or None if loading fails
        """
        image_path = os.fspath(image_path)
        if not os.path.isfile(image_path):
            logger.error(f"Image file not found: {image_path}")
            return None

        try:
            if image_path.lower().endswith(('.dcm', '.dicom')):
                return self._load_dicom(image_path)

            # np.fromfile + cv2.imdecode instead of cv2.imread: imread cannot
            # open non-ASCII paths on Windows and silently returns None there.
            buffer = np.fromfile(image_path, dtype=np.uint8)
            image = cv2.imdecode(buffer, cv2.IMREAD_COLOR) if buffer.size else None
            if image is None:
                logger.error(f"Failed to load image: {image_path}")
                return None

            # OpenCV decodes colour images in BGR order.
            if image.ndim == 3 and image.shape[2] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            return image

        except Exception as e:
            logger.error(f"Error loading image {image_path}: {str(e)}")
            return None

    def _load_dicom(self, dicom_path: str) -> Optional[np.ndarray]:
        """
        Load a DICOM image

        Stored pixel values are first mapped through the modality rescale
        (RescaleSlope / RescaleIntercept, e.g. to Hounsfield units for CT),
        because WindowCenter / WindowWidth are expressed in those rescaled
        units. The first valid window is then applied; without one, the image
        is min-max normalized (a constant image becomes all zeros).
        MONOCHROME1 data, where low values display as white, is inverted so
        that larger values are always brighter.

        Args:
            dicom_path: Path to the DICOM file

        Returns:
            np.ndarray: float32 image in [0, 1], or None if loading fails
        """
        try:
            dicom_data = pydicom.dcmread(dicom_path)
            image = dicom_data.pixel_array.astype(np.float32)

            # Modality rescale; absent or empty attributes mean identity.
            slope = float(getattr(dicom_data, 'RescaleSlope', None) or 1.0)
            intercept = float(getattr(dicom_data, 'RescaleIntercept', None) or 0.0)
            if slope != 1.0 or intercept != 0.0:
                image = image * slope + intercept

            window = self._dicom_window(dicom_data)
            if window is not None:
                image = self._apply_window_level(image, *window)
            else:
                # Simple normalization if no usable window settings exist
                min_val = float(np.min(image))
                max_val = float(np.max(image))
                if max_val > min_val:
                    image = (image - min_val) / (max_val - min_val)
                else:
                    image = np.zeros_like(image)

            if getattr(dicom_data, 'PhotometricInterpretation', '') == 'MONOCHROME1':
                image = 1.0 - image

            return image.astype(np.float32, copy=False)

        except Exception as e:
            logger.error(f"Error loading DICOM {dicom_path}: {str(e)}")
            return None

    @staticmethod
    def _dicom_window(dicom_data: Any) -> Optional[Tuple[float, float]]:
        """
        Return the first (center, width) window stored in a DICOM dataset

        Returns:
            Optional[Tuple[float, float]]: The window, or None when either
            attribute is absent or empty, or the width is not positive
        """
        center = getattr(dicom_data, 'WindowCenter', None)
        width = getattr(dicom_data, 'WindowWidth', None)

        # Multiple window settings may be stored; use the first one.
        if isinstance(center, pydicom.multival.MultiValue):
            center = center[0] if len(center) else None
        if isinstance(width, pydicom.multival.MultiValue):
            width = width[0] if len(width) else None

        if center is None or width is None or center == '' or width == '':
            return None
        center, width = float(center), float(width)
        return (center, width) if width > 0 else None

    def _apply_window_level(self, image: np.ndarray, center: float, width: float) -> np.ndarray:
        """
        Apply window center and width to the image

        Values at or below ``center - width / 2`` map to 0, values at or above
        ``center + width / 2`` map to 1, with a linear ramp in between.

        Args:
            image: Input image
            center: Window center
            width: Window width (must be positive)

        Returns:
            np.ndarray: Windowed image (floating point, in [0, 1])

        Raises:
            ValueError: If ``width`` is not positive
        """
        if width <= 0:
            raise ValueError(f"Window width must be positive, got {width}")

        image = np.asarray(image)
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)

        min_val = center - width / 2
        max_val = center + width / 2

        image = np.clip(image, min_val, max_val)
        return (image - min_val) / (max_val - min_val)

    @staticmethod
    def _to_uint8(image: np.ndarray) -> np.ndarray:
        """
        Convert a [0, 1] intensity image to uint8 for OpenCV / skimage routines

        uint8 input is returned unchanged. Other input is scaled by 255 and
        clipped to [0, 255] before casting, so out-of-range values saturate
        instead of wrapping around.
        """
        if image.dtype == np.uint8:
            return image
        if not np.issubdtype(image.dtype, np.floating):
            image = image.astype(np.float32)
        return np.clip(image * 255.0, 0, 255).astype(np.uint8)

    @staticmethod
    def _to_grayscale(image: np.ndarray) -> np.ndarray:
        """
        Return a 2-D grayscale view of ``image``

        2-D input is returned unchanged, and single-channel 3-D input is
        squeezed. RGB(A) input is converted with ``cv2.COLOR_RGB2GRAY``.
        """
        if image.ndim <= 2:
            return image
        if image.shape[2] == 1:
            return image[:, :, 0]
        # cv2.cvtColor accepts only 8-bit, 16-bit and float32 data.
        if image.dtype not in (np.uint8, np.uint16, np.float32):
            image = image.astype(np.float32)
        return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of ``path`` if it has one."""
        parent = os.path.dirname(path)
        # A bare filename has no directory part; os.makedirs('') would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json`` fallback encoder that converts NumPy scalars and arrays to native types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the medical image

        Each step is controlled by ``self.config["preprocessing"]``:
        non-float input is converted to float32 and divided by 255 when its
        maximum exceeds 1; the image is resized to ``target_size``
        ([rows, cols]); CLAHE and non-local-means denoising are applied, both
        only to 2-D (grayscale) images.

        Args:
            image: Input image

        Returns:
            np.ndarray: Preprocessed image, or an empty array if ``image`` is None
        """
        if image is None:
            logger.error("Cannot preprocess None image")
            return np.array([])

        settings = self.config["preprocessing"]

        # Convert to float and ensure range [0, 1]
        if image.dtype != np.float32 and image.dtype != np.float64:
            image = image.astype(np.float32)
            if np.max(image) > 1.0:
                image = image / 255.0

        # Resize if needed; cv2.resize takes (width, height).
        target_size = settings.get("target_size")
        if target_size:
            if image.shape[0] != target_size[0] or image.shape[1] != target_size[1]:
                image = cv2.resize(image, (target_size[1], target_size[0]))

        # CLAHE (Contrast Limited Adaptive Histogram Equalization) needs uint8.
        if settings.get("clahe", False) and image.ndim == 2:
            clip_limit = settings.get("clahe_clip_limit", 2.0)
            grid_size = tuple(settings.get("clahe_grid_size", [8, 8]))
            clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid_size)
            image = clahe.apply(self._to_uint8(image)).astype(np.float32) / 255.0

        # Non-local-means denoising, also on uint8 data.
        if settings.get("denoise", False) and image.ndim == 2:
            denoised = cv2.fastNlMeansDenoising(self._to_uint8(image), None, 10, 7, 21)
            image = denoised.astype(np.float32) / 255.0

        return image

    def segment(self, image: np.ndarray) -> Tuple[np.ndarray, List[Dict[str, Any]]]:
        """
        Segment the medical image

        The image is thresholded (Otsu, adaptive Gaussian, or a fixed 0.5),
        optionally cleaned with binary opening/closing and small-object
        removal, and then labelled into connected regions.

        Args:
            image: Preprocessed input image

        Returns:
            Tuple[np.ndarray, List[Dict]]: Label image (0 = background) and one
            property dict per labelled region
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([]), []

        image = self._to_grayscale(image)
        settings = self.config["segmentation"]

        threshold_method = settings.get("threshold_method", "otsu")
        if threshold_method == "otsu":
            binary = image > filters.threshold_otsu(image)
        elif threshold_method == "adaptive":
            # cv2.adaptiveThreshold requires 8-bit input.
            binary = cv2.adaptiveThreshold(
                self._to_uint8(image), 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY, 11, 2
            ) > 0
        else:
            # Default to simple thresholding
            binary = image > 0.5

        if settings.get("morphology_operations", True):
            binary = morphology.binary_opening(binary)
            binary = morphology.binary_closing(binary)

        if settings.get("remove_small_objects", True):
            min_size = settings.get("min_size", 100)
            binary = morphology.remove_small_objects(binary, min_size=min_size)

        labeled_mask, num_labels = measure.label(binary, return_num=True)

        regions = []
        if num_labels > 0:
            regions = [
                {
                    "label": prop.label,
                    "area": prop.area,
                    "centroid": prop.centroid,
                    "bbox": prop.bbox,
                    "eccentricity": prop.eccentricity,
                    "equivalent_diameter": prop.equivalent_diameter,
                    "orientation": prop.orientation,
                    "perimeter": prop.perimeter
                }
                for prop in measure.regionprops(labeled_mask)
            ]

        return labeled_mask, regions

    def extract_features(self, image: np.ndarray, mask: Optional[np.ndarray] = None) -> Dict[str, Any]:
        """
        Extract features from the medical image

        First-order statistics are computed over ``mask > 0`` when a mask is
        given, otherwise over the whole image. GLCM texture and edge features
        are always computed over the whole image.

        Args:
            image: Preprocessed input image
            mask: Optional segmentation mask with the same shape as the image

        Returns:
            Dict[str, Any]: Extracted features (float values)
        """
        if image is None or image.size == 0:
            logger.error("Cannot extract features from empty image")
            return {}

        features: Dict[str, Any] = {}
        gray_image = self._to_grayscale(image)

        # Basic statistical features
        values = gray_image[mask > 0] if mask is not None else gray_image.ravel()
        if values.size > 0:
            features["mean"] = float(np.mean(values))
            features["std"] = float(np.std(values))
            features["min"] = float(np.min(values))
            features["max"] = float(np.max(values))
            features["median"] = float(np.median(values))
            features["skewness"] = float(self._calculate_skewness(values))
            features["kurtosis"] = float(self._calculate_kurtosis(values))

        # Texture features (GLCM)
        try:
            from skimage.feature import graycomatrix, graycoprops

            img_uint8 = self._to_uint8(gray_image)
            distances = [1]
            angles = [0, np.pi/4, np.pi/2, 3*np.pi/4]
            glcm = graycomatrix(img_uint8, distances, angles, levels=256, symmetric=True, normed=True)

            # Average each property over the four directions.
            for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation']:
                features[f"glcm_{prop}"] = float(np.mean(graycoprops(glcm, prop)[0]))

        except Exception as e:
            logger.warning(f"Error calculating GLCM features: {str(e)}")

        # Edge features
        try:
            edges = feature.canny(gray_image)
            features["edge_ratio"] = float(np.sum(edges) / edges.size)

            sobel = filters.sobel(gray_image)
            features["gradient_mean"] = float(np.mean(sobel))
            features["gradient_std"] = float(np.std(sobel))

        except Exception as e:
            logger.warning(f"Error calculating edge features: {str(e)}")

        return features

    def _calculate_skewness(self, data: np.ndarray) -> float:
        """
        Calculate the (population) skewness of the data

        Multi-dimensional input is treated as one flat sample.

        Args:
            data: Input data

        Returns:
            float: Skewness value (0.0 for empty or constant data)
        """
        values = np.asarray(data, dtype=np.float64).ravel()
        if values.size == 0:
            return 0.0
        std = values.std()
        if std == 0:
            return 0.0
        return float(np.mean((values - values.mean()) ** 3) / std ** 3)

    def _calculate_kurtosis(self, data: np.ndarray) -> float:
        """
        Calculate the excess (Fisher) kurtosis of the data

        Multi-dimensional input is treated as one flat sample.

        Args:
            data: Input data

        Returns:
            float: Kurtosis value minus 3 (0.0 for empty or constant data)
        """
        values = np.asarray(data, dtype=np.float64).ravel()
        if values.size == 0:
            return 0.0
        std = values.std()
        if std == 0:
            return 0.0
        return float(np.mean((values - values.mean()) ** 4) / std ** 4 - 3)

    def visualize(self, image: np.ndarray, mask: Optional[np.ndarray] = None,
                 regions: Optional[List[Dict[str, Any]]] = None,
                 save_path: Optional[str] = None) -> None:
        """
        Visualize the image, segmentation mask, and regions

        The figure is only written to disk when ``save_path`` is given, and it
        is always closed afterwards (nothing is shown interactively).

        Args:
            image: Input image
            mask: Optional segmentation mask
            regions: Optional list of region properties (needs ``mask`` too)
            save_path: Optional path to save the visualization
        """
        if image is None or image.size == 0:
            logger.error("Cannot visualize empty image")
            return

        plt.figure(figsize=(15, 10))

        plt.subplot(1, 3, 1)
        plt.title("Original Image")
        plt.imshow(image, cmap='gray')
        plt.axis('off')

        if mask is not None:
            plt.subplot(1, 3, 2)
            plt.title("Segmentation Mask")
            plt.imshow(mask, cmap='viridis')
            plt.axis('off')

        if regions is not None and mask is not None:
            plt.subplot(1, 3, 3)
            plt.title("Regions")

            # Create RGB image for visualization
            if len(image.shape) == 2:
                rgb_img = np.stack([image] * 3, axis=2)
            else:
                rgb_img = image.copy()

            # Normalize to [0, 1] if needed
            if rgb_img.max() > 1.0:
                rgb_img = rgb_img / 255.0

            from skimage.segmentation import mark_boundaries
            boundary_img = mark_boundaries(rgb_img, mask, color=(1, 0, 0), mode='thick')

            # Draw centroids and labels (centroid is (row, col) -> (y, x))
            for region in regions:
                y, x = region["centroid"]
                plt.plot(x, y, 'ro', markersize=5)
                plt.text(x, y, str(region["label"]), color='white',
                         bbox=dict(facecolor='red', alpha=0.5))

            plt.imshow(boundary_img)
            plt.axis('off')

        plt.tight_layout()

        if save_path:
            try:
                self._ensure_parent_dir(save_path)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"Visualization saved to {save_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {str(e)}")

        plt.close()

    def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
        """
        Save processing results to a JSON file

        NumPy scalars and arrays inside ``results`` are converted to native
        JSON types. The payload is serialized before the file is opened, so a
        serialization error does not leave a truncated file behind.

        Args:
            results: Processing results
            output_path: Path to save the results (a bare filename is allowed)

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            import json
            payload = json.dumps(results, indent=2, default=self._json_default)

            self._ensure_parent_dir(output_path)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(payload)

            logger.info(f"Results saved to {output_path}")
            return True

        except Exception as e:
            logger.error(f"Error saving results: {str(e)}")
            return False


class XRayProcessor(MedicalImageProcessor):
    """Specialized processor for X-ray images

    Extends the base preprocessing with unsharp-mask enhancement of edges and
    bone structures, and adds a threshold + morphology lung segmentation for
    chest X-rays.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the X-ray processor

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

        # X-ray specific configuration; values from the config file win.
        self.config.setdefault("xray", {})
        self.config["xray"].setdefault("bone_enhancement", True)
        # NOTE(review): "lung_segmentation" is not read by this class;
        # segment_lungs() has to be called explicitly.
        self.config["xray"].setdefault("lung_segmentation", True)

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the X-ray image

        Args:
            image: Input X-ray image

        Returns:
            np.ndarray: Preprocessed image (empty if the input was None)
        """
        image = super().preprocess(image)

        # The base class returns an empty array for None input; GaussianBlur
        # cannot process an empty array.
        if image.size and self.config["xray"].get("bone_enhancement", False):
            image = self._enhance_bone_structures(image)

        return image

    def _enhance_bone_structures(self, image: np.ndarray) -> np.ndarray:
        """
        Enhance bone structures in X-ray images

        Unsharp masking: 1.5 * I - 0.5 * blur(I) = I + 0.5 * (I - blur(I)),
        with a Gaussian blur of sigma 3.

        Args:
            image: Input X-ray image

        Returns:
            np.ndarray: Enhanced image. Floating-point output is clipped to
            [min(0, image.min()), max(1, image.max())].
        """
        blurred = cv2.GaussianBlur(image, (0, 0), 3)
        enhanced = cv2.addWeighted(image, 1.5, blurred, -0.5, 0)

        # Unsharp masking overshoots around edges. Clamp float output so later
        # "* 255 -> uint8" conversions do not wrap negative or >1 values.
        # (Integer output is already saturated by cv2.addWeighted.)
        if np.issubdtype(enhanced.dtype, np.floating):
            low = min(0.0, float(image.min()))
            high = max(1.0, float(image.max()))
            enhanced = np.clip(enhanced, low, high)

        return enhanced

    def segment_lungs(self, image: np.ndarray) -> np.ndarray:
        """
        Segment lungs in chest X-ray images

        Lungs are assumed darker than the surrounding tissue. The image is
        inverted-Otsu thresholded, cleaned morphologically and stripped of
        border-touching components, and then up to the two largest remaining
        components are kept.

        Args:
            image: Preprocessed X-ray image

        Returns:
            np.ndarray: Boolean lung segmentation mask (all False if nothing
            qualifies)
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment lungs from empty image")
            return np.array([])

        # Ensure image is grayscale; cv2.cvtColor rejects float64 input.
        if image.ndim > 2:
            if image.dtype not in (np.uint8, np.uint16, np.float32):
                image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        thresh_val = filters.threshold_otsu(image)
        binary = image < thresh_val  # Invert for lung segmentation (lungs are darker)

        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))

        cleared = segmentation.clear_border(binary)

        # Keep the (up to) two largest regions: left and right lung. A single
        # region, for example merged lungs, is kept rather than discarded.
        label_image = measure.label(cleared)
        largest = sorted(measure.regionprops(label_image), key=lambda r: r.area, reverse=True)[:2]

        return np.isin(label_image, [region.label for region in largest])


class CTProcessor(MedicalImageProcessor):
    """Specialized processor for CT scan images

    Adds Hounsfield-unit (HU) windowing with named presets to the base
    preprocessing, and a coarse intensity-band organ segmentation.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the CT processor

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

        # CT specific configuration; values from the config file win.
        # Window presets are given as center/width in HU.
        self.config.setdefault("ct", {})
        self.config["ct"].setdefault("hounsfield_normalization", True)
        self.config["ct"].setdefault("window_presets", {
            "brain": {"center": 40, "width": 80},
            "lung": {"center": -600, "width": 1500},
            "bone": {"center": 400, "width": 1800},
            "soft_tissue": {"center": 50, "width": 400}
        })

    @staticmethod
    def _is_hounsfield_like(image: np.ndarray) -> bool:
        """
        Heuristically decide whether ``image`` still holds raw HU values

        uint8 data (8-bit exports) and data already inside [0, 1] are treated
        as display intensities, not HU.
        NOTE(review): raw unsigned DICOM values that still need a rescale
        intercept cannot be told apart from HU here.
        """
        if image.size == 0 or image.dtype == np.uint8:
            return False
        return bool(np.min(image) < 0.0 or np.max(image) > 1.0)

    def preprocess(self, image: np.ndarray, window_preset: str = "soft_tissue") -> np.ndarray:
        """
        Preprocess the CT image with specific window settings

        The HU window is applied *before* the base preprocessing. The base step
        rescales intensities to [0, 1], and applying an HU window such as
        [-150, 250] afterwards squeezes every pixel into a narrow band near the
        window's zero point. Windowing is applied only to input that still
        looks like HU data; other input is passed to the base preprocessing
        unchanged, with a warning.

        Args:
            image: Input CT image (ideally HU values)
            window_preset: Window preset name (a key of
                ``config["ct"]["window_presets"]``)

        Returns:
            np.ndarray: Preprocessed image
        """
        ct_config = self.config["ct"]

        if image is not None and ct_config.get("hounsfield_normalization", False):
            preset = ct_config.get("window_presets", {}).get(window_preset)
            if preset is None:
                logger.warning(f"Unknown CT window preset '{window_preset}'; skipping windowing")
            elif self._is_hounsfield_like(image):
                image = self._apply_window_level(
                    image.astype(np.float32), preset["center"], preset["width"]
                )
            else:
                logger.warning("CT image does not look like raw HU data; skipping HU windowing")

        # The windowed image is float32 in [0, 1], so the base step does not rescale it again.
        return super().preprocess(image)

    def segment_organs(self, image: np.ndarray, target_organ: str = "liver") -> np.ndarray:
        """
        Segment specific organs in CT images

        Thresholds apply to window-normalized intensities in [0, 1] (the
        output of ``preprocess``), not to raw HU. With the default
        "soft_tissue" preset (center 50, width 400, so [-150, 250] HU) the
        bands are roughly: liver 0.5-0.7 ~ 50-130 HU, lung < 0.3 ~ below
        -30 HU, bone > 0.7 ~ above 130 HU.

        Args:
            image: Preprocessed CT image
            target_organ: "liver", "lung" or "bone"; anything else uses > 0.5

        Returns:
            np.ndarray: Boolean organ segmentation mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment organs from empty image")
            return np.array([])

        # Ensure image is grayscale; cv2.cvtColor rejects float64 input.
        if image.ndim > 2:
            if image.dtype not in (np.uint8, np.uint16, np.float32):
                image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        if target_organ == "liver":
            binary = (image > 0.5) & (image < 0.7)
        elif target_organ == "lung":
            binary = image < 0.3
        elif target_organ == "bone":
            binary = image > 0.7
        else:
            # Default segmentation
            binary = image > 0.5

        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(5))

        binary = morphology.remove_small_objects(binary, min_size=100)

        return binary


class MRIProcessor(MedicalImageProcessor):
    """Specialized processor for MRI images

    Adds percentile-based intensity normalization to the base preprocessing
    and a largest-connected-component brain segmentation.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the MRI processor

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

        # MRI specific configuration; values from the config file win.
        self.config.setdefault("mri", {})
        # NOTE(review): "bias_field_correction" is not read by this class; no
        # bias-field (e.g. N4) correction is implemented here.
        self.config["mri"].setdefault("bias_field_correction", True)
        self.config["mri"].setdefault("intensity_normalization", True)

    def preprocess(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess the MRI image

        Args:
            image: Input MRI image

        Returns:
            np.ndarray: Preprocessed image (empty if the input was None)
        """
        image = super().preprocess(image)

        # np.percentile cannot process the empty array returned for None input.
        if image.size and self.config["mri"].get("intensity_normalization", False):
            image = self._normalize_intensity(image)

        return image

    def _normalize_intensity(self, image: np.ndarray) -> np.ndarray:
        """
        Normalize intensity in MRI images

        Clips to the 1st-99th percentile range and linearly rescales that
        range to [0, 1]. No bias-field correction is performed.

        Args:
            image: Input MRI image

        Returns:
            np.ndarray: Normalized image (all zeros if the percentile range is empty)
        """
        if image.size == 0:
            return image

        p1, p99 = np.percentile(image, (1, 99))
        if p99 <= p1:
            # Degenerate distribution (e.g. a constant image): there is no
            # range to stretch, and dividing by p99 - p1 would give NaN/inf.
            return np.zeros_like(image, dtype=np.result_type(image.dtype, np.float32))

        image = np.clip(image, p1, p99)
        return (image - p1) / (p99 - p1)

    def segment_brain(self, image: np.ndarray) -> np.ndarray:
        """
        Segment brain in MRI images

        Otsu threshold, morphological opening/closing, small-hole filling,
        then the largest connected foreground component is kept.

        Args:
            image: Preprocessed MRI image

        Returns:
            np.ndarray: Boolean brain mask (all False if no foreground is found)
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment brain from empty image")
            return np.array([])

        # Ensure image is grayscale; cv2.cvtColor rejects float64 input.
        if image.ndim > 2:
            if image.dtype not in (np.uint8, np.uint16, np.float32):
                image = image.astype(np.float32)
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        binary = image > filters.threshold_otsu(image)

        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))

        binary = morphology.remove_small_holes(binary, area_threshold=100)

        labels = measure.label(binary)
        if labels.max() == 0:
            # No foreground at all; argmax over an empty count list would raise.
            logger.warning("No foreground found while segmenting brain")
            return np.zeros(labels.shape, dtype=bool)

        # Largest non-background component (label 0 is background).
        counts = np.bincount(labels.ravel())
        counts[0] = 0
        return labels == int(np.argmax(counts))


if __name__ == "__main__":
    # Smoke test: build a processor from its built-in default configuration.
    demo_processor = MedicalImageProcessor()
    print("Medical Image Processor initialized")

modules\image_analysis\image_segmentation.py

"""
Medical Image Segmentation Module

This module provides specialized segmentation algorithms for medical images,
including X-rays, CT scans, MRIs, and other medical imaging modalities.

Key features:
- Thresholding-based segmentation
- Region growing
- Watershed segmentation
- Deep learning-based segmentation
- Organ-specific segmentation (lungs, brain, etc.)
"""

import os
import logging
import numpy as np
import cv2
from typing import Dict, List, Tuple, Any, Optional, Union
from pathlib import Path
import matplotlib.pyplot as plt
from skimage import measure, morphology, segmentation, filters

# Configure logging
# Module-level logging: configure the root logger when this module is
# imported, then create the module's own named logger.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class MedicalImageSegmenter:
    """Base class for medical image segmentation"""
    
    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image segmenter
        
        Args:
            config_path: Path to configuration file
        """
        self.config = {}
        if config_path and os.path.isfile(config_path):
            try:
                import json
                with open(config_path, 'r') as f:
                    self.config = json.load(f)
                logger.info(f"Loaded configuration from {config_path}")
            except Exception as e:
                logger.error(f"Error loading configuration: {str(e)}")
        
        # Default configuration
        if not self.config:
            self.config = {
                "segmentation": {
                    "threshold_method": "otsu",
                    "morphology_operations": True,
                    "remove_small_objects": True,
                    "min_size": 100
                }
            }
    
    def threshold_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Perform threshold-based segmentation
        
        Args:
            image: Input image (grayscale)
            
        Returns:
            np.ndarray: Binary segmentation mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        
        # Apply thresholding
        threshold_method = self.config.get("segmentation", {}).get("threshold_method", "otsu")
        
        if threshold_method == "otsu":
            thresh_val = filters.threshold_otsu(image)
            binary = image > thresh_val
        elif threshold_method == "adaptive":
            # Convert to uint8 for adaptive thresholding
            img_uint8 = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)
            binary = cv2.adaptiveThreshold(
                img_uint8, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                cv2.THRESH_BINARY, 11, 2
            )
            binary = binary > 0
        else:
            # Default to simple thresholding
            binary = image > 0.5
        
        # Apply morphological operations
        if self.config.get("segmentation", {}).get("morphology_operations", True):
            binary = morphology.binary_opening(binary)
            binary = morphology.binary_closing(binary)
        
        # Remove small objects
        if self.config.get("segmentation", {}).get("remove_small_objects", True):
            min_size = self.config.get("segmentation", {}).get("min_size", 100)
            binary = morphology.remove_small_objects(binary, min_size=min_size)
        
        return binary
    
    def region_growing(self, image: np.ndarray, seed_point: Tuple[int, int], 
                      threshold: float = 0.1) -> np.ndarray:
        """
        Perform region growing segmentation
        
        Args:
            image: Input image (grayscale)
            seed_point: Starting point for region growing (y, x)
            threshold: Intensity threshold for region growing
            
        Returns:
            np.ndarray: Binary segmentation mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        
        # Create output mask
        mask = np.zeros_like(image, dtype=bool)
        
        # Get image dimensions
        height, width = image.shape
        
        # Check if seed point is valid
        if not (0 <= seed_point[0] < height and 0 <= seed_point[1] < width):
            logger.error(f"Invalid seed point: {seed_point}")
            return mask
        
        # Get seed value
        seed_value = image[seed_point]
        
        # Initialize queue with seed point
        queue = [seed_point]
        mask[seed_point] = True
        
        # Define 4-connected neighbors
        neighbors = [(0, 1), (1, 0), (0, -1), (-1, 0)]
        
        # Perform region growing
        while queue:
            y, x = queue.pop(0)
            
            for dy, dx in neighbors:
                ny, nx = y + dy, x + dx
                
                # Check if neighbor is valid
                if (0 <= ny < height and 0 <= nx < width and
                    not mask[ny, nx] and
                    abs(image[ny, nx] - seed_value) <= threshold):
                    
                    mask[ny, nx] = True
                    queue.append((ny, nx))
        
        return mask
    
    def watershed_segmentation(self, image: np.ndarray) -> np.ndarray:
        """
        Perform watershed segmentation
        
        Args:
            image: Input image (grayscale)
            
        Returns:
            np.ndarray: Labeled segmentation mask
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])
        
        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        
        # Convert to uint8 for watershed
        if image.dtype != np.uint8:
            if np.max(image) <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        
        # Apply gradient
        gradient = filters.sobel(image)
        
        # Apply threshold to create markers
        thresh_val = filters.threshold_otsu(image)
        markers = np.zeros_like(image, dtype=np.int32)
        markers[image < thresh_val * 0.6] = 1  # Background
        markers[image > thresh_val * 1.2] = 2  # Foreground
        
        # Apply watershed
        segmented = segmentation.watershed(gradient, markers)
        
        # Convert to binary mask (foreground only)
        binary_mask = segmented == 2
        
        return binary_mask
    
    def extract_region_properties(self, mask: np.ndarray) -> List[Dict[str, Any]]:
        """
        Extract properties of segmented regions
        
        Args:
            mask: Binary or labeled segmentation mask
            
        Returns:
            List[Dict[str, Any]]: List of region properties
        """
        if mask is None or mask.size == 0:
            logger.error("Cannot extract properties from empty mask")
            return []
        
        # Label connected regions if mask is binary
        if mask.dtype == bool or np.array_equal(np.unique(mask), [0, 1]):
            labeled_mask, num_labels = measure.label(mask, return_num=True)
        else:
            labeled_mask = mask
            num_labels = np.max(mask)
        
        # Extract region properties
        regions = []
        if num_labels > 0:
            props = measure.regionprops(labeled_mask)
            
            for prop in props:
                region = {
                    "label": prop.label,
                    "area": prop.area,
                    "centroid": prop.centroid,
                    "bbox": prop.bbox,
                    "eccentricity": prop.eccentricity,
                    "equivalent_diameter": prop.equivalent_diameter,
                    "orientation": prop.orientation,
                    "perimeter": prop.perimeter
                }
                regions.append(region)
        
        return regions
    
    def visualize_segmentation(self, image: np.ndarray, mask: np.ndarray, 
                              regions: Optional[List[Dict[str, Any]]] = None,
                              save_path: Optional[str] = None) -> None:
        """
        Visualize the segmentation results
        
        Args:
            image: Original image
            mask: Segmentation mask
            regions: Optional list of region properties
            save_path: Optional path to save the visualization
        """
        if image is None or mask is None:
            logger.error("Cannot visualize empty image or mask")
            return
        
        plt.figure(figsize=(15, 5))
        
        # Display original image
        plt.subplot(1, 3, 1)
        plt.title("Original Image")
        plt.imshow(image, cmap='gray')
        plt.axis('off')
        
        # Display segmentation mask
        plt.subplot(1, 3, 2)
        plt.title("Segmentation Mask")
        plt.imshow(mask, cmap='viridis')
        plt.axis('off')
        
        # Display overlay
        plt.subplot(1, 3, 3)
        plt.title("Overlay")
        
        # Create RGB image for visualization
        if len(image.shape) == 2:
            rgb_img = np.stack([image] * 3, axis=2)
        else:
            rgb_img = image.copy()
        
        # Normalize to [0, 1] if needed
        if rgb_img.max() > 1.0:
            rgb_img = rgb_img / 255.0
        
        # Draw region boundaries
        from skimage.segmentation import mark_boundaries
        boundary_img = mark_boundaries(rgb_img, mask, color=(1, 0, 0), mode='thick')
        
        # Draw centroids and labels if regions are provided
        if regions:
            for region in regions:
                y, x = region["centroid"]
                plt.plot(x, y, 'ro', markersize=5)
                plt.text(x, y, str(region["label"]), color='white', 
                         bbox=dict(facecolor='red', alpha=0.5))
        
        plt.imshow(boundary_img)
        plt.axis('off')
        
        plt.tight_layout()
        
        if save_path:
            try:
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                logger.info(f"Visualization saved to {save_path}")
            except Exception as e:
                logger.error(f"Error saving visualization: {str(e)}")
        
        plt.close()


class LungSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for lung segmentation in chest X-rays"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the lung segmenter

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    def segment_lungs(self, image: np.ndarray) -> np.ndarray:
        """
        Segment lungs in chest X-ray images

        Pixels darker than the Otsu threshold are kept, cleaned with
        morphological opening/closing, and border-touching components are
        removed. The (up to) two largest remaining components form the mask.

        Args:
            image: Input chest X-ray image

        Returns:
            np.ndarray: Boolean lung mask (all False if no candidate region
            remains), or an empty array if the input is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment lungs from empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Apply Otsu's thresholding
        thresh_val = filters.threshold_otsu(image)
        binary = image < thresh_val  # Invert for lung segmentation (lungs are darker)

        # Apply morphological operations
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))

        # Remove border objects
        cleared = segmentation.clear_border(binary)

        # Rank connected components by area, largest first
        label_image = measure.label(cleared)
        regions = sorted(measure.regionprops(label_image),
                         key=lambda r: r.area, reverse=True)

        # Keep up to the two largest regions (normally left and right lung).
        # A single surviving region (e.g. lungs merged into one component)
        # is kept rather than discarded.
        lung_mask = np.zeros_like(image, dtype=bool)
        for region in regions[:2]:
            lung_mask[label_image == region.label] = True

        return lung_mask

    def enhance_lung_structures(self, image: np.ndarray, lung_mask: np.ndarray) -> np.ndarray:
        """
        Enhance lung structures in chest X-ray images

        Applies CLAHE to the lung region only and pastes the enhanced pixels
        back into an 8-bit copy of the original image.

        Args:
            image: Input chest X-ray image
            lung_mask: Lung segmentation mask (bool or 0/1), same height and
                width as ``image``

        Returns:
            np.ndarray: Enhanced uint8 image, or an empty array on invalid
            input
        """
        if image is None or image.size == 0 or lung_mask is None or lung_mask.size == 0:
            logger.error("Cannot enhance empty image or mask")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # `~` is a logical NOT only for bool arrays; on a 0/1 uint8 mask it
        # yields 254/255 and would turn the indexing below into fancy indexing.
        lung_mask = np.asarray(lung_mask).astype(bool)
        if lung_mask.shape != image.shape:
            logger.error(f"Lung mask shape {lung_mask.shape} does not match image shape {image.shape}")
            return np.array([])

        # CLAHE needs 8-bit input: scale [0, 1] data, cast [0, 255] data, and
        # rescale anything wider (e.g. 16-bit X-rays) instead of wrapping it.
        if image.dtype != np.uint8:
            max_val = float(np.max(image))
            if max_val <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                min_val = float(np.min(image))
                if min_val >= 0.0 and max_val <= 255.0:
                    image = image.astype(np.uint8)
                else:
                    span = (max_val - min_val) or 1.0
                    image = ((image.astype(np.float64) - min_val) / span * 255.0).astype(np.uint8)

        # Zero out everything outside the lungs before enhancing
        masked_image = image.copy()
        masked_image[~lung_mask] = 0

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(masked_image)

        # Blend enhanced lungs with original image
        result = image.copy()
        result[lung_mask] = enhanced[lung_mask]

        return result


class BrainSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for brain segmentation in MRI images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the brain segmenter

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    def segment_brain(self, image: np.ndarray) -> np.ndarray:
        """
        Segment brain in MRI images

        Otsu foreground, morphological cleanup and hole filling. Only the
        largest connected component is kept.

        Args:
            image: Input MRI image

        Returns:
            np.ndarray: Boolean brain mask, or an empty array if the input
            is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment brain from empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Apply Otsu's thresholding
        thresh_val = filters.threshold_otsu(image)
        binary = image > thresh_val

        # Apply morphological operations
        binary = morphology.binary_opening(binary, morphology.disk(2))
        binary = morphology.binary_closing(binary, morphology.disk(10))

        # Fill holes
        binary = morphology.remove_small_holes(binary, area_threshold=100)

        # Keep only the largest connected component (brain)
        labels = measure.label(binary)
        if np.max(labels) > 0:
            # bincount index 0 is background, hence [1:] and the +1 offset
            largest_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
            binary = labels == largest_label

        return binary

    def segment_tumor(self, image: np.ndarray, brain_mask: np.ndarray) -> np.ndarray:
        """
        Segment potential tumor regions in brain MRI

        Flags pixels inside the brain whose intensity exceeds the 95th
        percentile of brain intensities, then cleans the result
        morphologically. This is a crude hyperintensity heuristic, not a
        validated tumor detector.

        Args:
            image: Input MRI image
            brain_mask: Brain segmentation mask (bool or 0/1), same height
                and width as ``image``

        Returns:
            np.ndarray: Boolean tumor mask (all False if the brain mask is
            empty), or an empty array on invalid input
        """
        if image is None or image.size == 0 or brain_mask is None or brain_mask.size == 0:
            logger.error("Cannot segment tumor from empty image or mask")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Accept 0/1 integer masks as well as bool masks
        brain_mask = np.asarray(brain_mask).astype(bool)
        if brain_mask.shape != image.shape:
            logger.error(f"Brain mask shape {brain_mask.shape} does not match image shape {image.shape}")
            return np.array([])

        # np.percentile on an empty selection fails; nothing to segment anyway
        if not brain_mask.any():
            logger.warning("Brain mask is empty; no tumor regions can be segmented")
            return np.zeros(image.shape, dtype=bool)

        # Tumors are typically hyperintense (brighter) in T2-weighted MRI
        thresh_val = np.percentile(image[brain_mask], 95)
        # Restrict to the brain explicitly so pixels outside it can never be
        # flagged (e.g. when intensities, and thus the threshold, are negative)
        tumor_mask = (image > thresh_val) & brain_mask

        # Apply morphological operations
        tumor_mask = morphology.binary_opening(tumor_mask, morphology.disk(2))
        tumor_mask = morphology.binary_closing(tumor_mask, morphology.disk(5))

        # Remove small objects
        tumor_mask = morphology.remove_small_objects(tumor_mask, min_size=50)

        return tumor_mask


class OrganSegmenter(MedicalImageSegmenter):
    """Specialized segmenter for organ segmentation in CT images"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the organ segmenter

        Args:
            config_path: Path to configuration file
        """
        super().__init__(config_path)

    @staticmethod
    def _to_unit_range(image: np.ndarray) -> np.ndarray:
        """
        Map 8-bit images to [0, 1]; return any other dtype unchanged.

        The fixed intensity bands below are expressed on a [0, 1] scale.
        Without this, uint8 input (e.g. from ``cv2.imread``) would make the
        liver band empty and mark every non-zero pixel as bone.
        """
        if image.dtype == np.uint8:
            return image.astype(np.float64) / 255.0
        return image

    def segment_liver(self, image: np.ndarray) -> np.ndarray:
        """
        Segment liver in CT images

        Args:
            image: Input CT slice, either uint8 or already scaled to [0, 1]

        Returns:
            np.ndarray: Boolean mask of the largest component in the liver
            intensity band, or an empty array if the input is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment liver from empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = self._to_unit_range(image)

        # Liver intensity band on a [0, 1] scale. How this relates to
        # Hounsfield units depends on the CT window applied upstream, which
        # is not known here; treat these values as tunable.
        liver_mask = (image > 0.5) & (image < 0.7)

        # Apply morphological operations
        liver_mask = morphology.binary_opening(liver_mask, morphology.disk(2))
        liver_mask = morphology.binary_closing(liver_mask, morphology.disk(10))

        # Remove small objects
        liver_mask = morphology.remove_small_objects(liver_mask, min_size=500)

        # Keep only the largest connected component (liver)
        labels = measure.label(liver_mask)
        if np.max(labels) > 0:
            # bincount index 0 is background, hence [1:] and the +1 offset
            largest_label = np.argmax(np.bincount(labels.flat)[1:]) + 1
            liver_mask = labels == largest_label

        return liver_mask

    def segment_kidneys(self, image: np.ndarray) -> np.ndarray:
        """
        Segment kidneys in CT images

        Args:
            image: Input CT slice, either uint8 or already scaled to [0, 1]

        Returns:
            np.ndarray: Boolean mask of pixels in the kidney intensity band
            after cleanup, or an empty array if the input is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment kidneys from empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = self._to_unit_range(image)

        # Kidney intensity band on a [0, 1] scale (window-dependent; tunable)
        kidney_mask = (image > 0.3) & (image < 0.5)

        # Apply morphological operations
        kidney_mask = morphology.binary_opening(kidney_mask, morphology.disk(2))
        kidney_mask = morphology.binary_closing(kidney_mask, morphology.disk(5))

        # Remove small objects
        kidney_mask = morphology.remove_small_objects(kidney_mask, min_size=200)

        return kidney_mask

    def segment_bones(self, image: np.ndarray) -> np.ndarray:
        """
        Segment bones in CT images

        Args:
            image: Input CT slice, either uint8 or already scaled to [0, 1]

        Returns:
            np.ndarray: Boolean mask of bright (bone-like) pixels after
            cleanup, or an empty array if the input is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment bones from empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        image = self._to_unit_range(image)

        # Bone is the brightest tissue in CT; threshold on a [0, 1] scale
        # (window-dependent; tunable)
        bone_mask = image > 0.7

        # Apply morphological operations
        bone_mask = morphology.binary_opening(bone_mask, morphology.disk(1))
        bone_mask = morphology.binary_closing(bone_mask, morphology.disk(3))

        # Remove small objects
        bone_mask = morphology.remove_small_objects(bone_mask, min_size=100)

        return bone_mask


class DeepLearningSegmenter(MedicalImageSegmenter):
    """Deep learning-based segmentation for medical images"""

    def __init__(self, config_path: Optional[str] = None, model_path: Optional[str] = None):
        """
        Initialize the deep learning segmenter

        Args:
            config_path: Path to configuration file
            model_path: Path to pre-trained model; loaded immediately if it
                exists (a missing path is logged and leaves ``self.model``
                as None)
        """
        super().__init__(config_path)
        self.model = None
        self.model_path = model_path

        # Load model if path is provided
        if model_path:
            if os.path.exists(model_path):
                self.load_model(model_path)
            else:
                logger.warning(f"Model path does not exist: {model_path}")

    def load_model(self, model_path: str) -> bool:
        """
        Load a pre-trained segmentation model

        Uses ``tensorflow.keras.models.load_model`` with ``compile=False``
        (inference only). TensorFlow is imported lazily so the rest of the
        module works without it.

        Args:
            model_path: Path to pre-trained model

        Returns:
            bool: True if model loaded successfully, False otherwise
        """
        try:
            from tensorflow.keras.models import load_model

            self.model = load_model(model_path, compile=False)
            logger.info(f"Loaded model from {model_path}")
            return True
        except ImportError:
            logger.error("TensorFlow not available. Cannot load deep learning model.")
            return False
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False

    def preprocess_for_model(self, image: np.ndarray) -> np.ndarray:
        """
        Preprocess image for model input

        Converts to grayscale, resizes to
        ``config["image_processing"]["target_size"]`` (default 256x256,
        given as [height, width]), scales to [0, 1] when values exceed 1,
        and adds batch and channel axes.

        Args:
            image: Input image

        Returns:
            np.ndarray: Array of shape (1, height, width, 1), or an empty
            array if the input is empty
        """
        if image is None or image.size == 0:
            logger.error("Cannot preprocess empty image")
            return np.array([])

        # Ensure image is grayscale
        if len(image.shape) > 2:
            image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

        # Resize to model input size (cv2.resize takes (width, height))
        target_size = self.config.get("image_processing", {}).get("target_size", [256, 256])
        resized = cv2.resize(image, (target_size[1], target_size[0]))

        # Normalize to [0, 1]
        if resized.max() > 1.0:
            resized = resized / 255.0

        # Add batch and channel dimensions
        preprocessed = np.expand_dims(np.expand_dims(resized, axis=0), axis=-1)

        return preprocessed

    def segment_with_model(self, image: np.ndarray) -> np.ndarray:
        """
        Segment image using deep learning model

        Args:
            image: Input image

        Returns:
            np.ndarray: 2-D mask at the input's height and width; class
            indices for multi-channel outputs, 0/1 uint8 for single-channel
            outputs. Empty array on error or when no model is loaded.
        """
        if image is None or image.size == 0:
            logger.error("Cannot segment empty image")
            return np.array([])

        if self.model is None:
            logger.error("No model loaded. Call load_model() first.")
            return np.array([])

        # Preprocess image
        preprocessed = self.preprocess_for_model(image)

        try:
            prediction = self.model.predict(preprocessed)[0]

            if prediction.ndim == 3 and prediction.shape[-1] > 1:
                # Multi-class segmentation: one channel per class
                mask = np.argmax(prediction, axis=-1)
            else:
                # Binary segmentation: drop a trailing singleton channel so
                # the mask is always 2-D (cv2.resize would drop it anyway,
                # but only when a resize happens). A channel-less (H, W)
                # output must not be argmax'ed over its width axis.
                if prediction.ndim == 3:
                    prediction = prediction[..., 0]
                mask = (prediction > 0.5).astype(np.uint8)

            # Resize mask to original image size
            if mask.shape[:2] != image.shape[:2]:
                mask = cv2.resize(mask.astype(np.uint8),
                                  (image.shape[1], image.shape[0]),
                                  interpolation=cv2.INTER_NEAREST)

            return mask
        except Exception as e:
            logger.error(f"Error during model prediction: {str(e)}")
            return np.array([])


if __name__ == "__main__":
    # Command-line entry point: segment one image and save masks/visualization.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description='Medical Image Segmentation')
    parser.add_argument('--image', type=str, required=True, help='Path to input image')
    parser.add_argument('--type', type=str, choices=['xray', 'ct', 'mri'], required=True,
                        help='Type of medical image')
    parser.add_argument('--output', type=str, default='./output', help='Output directory')
    parser.add_argument('--config', type=str, default=None, help='Path to configuration file')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output, exist_ok=True)

    # Load image (cv2.imread returns None when the file cannot be read)
    image = cv2.imread(args.image)

    if image is None:
        print(f"Error: Could not load image {args.image}")
        sys.exit(1)

    # Select appropriate segmenter based on image type. Every branch must
    # define `segmenter`, `mask` and `regions` for the common code below.
    if args.type == 'xray':
        segmenter = LungSegmenter(args.config)
        mask = segmenter.segment_lungs(image)
        regions = segmenter.extract_region_properties(mask)
        enhanced = segmenter.enhance_lung_structures(image, mask)

        # Save enhanced image
        enhanced_path = os.path.join(args.output, 'enhanced_lungs.png')
        cv2.imwrite(enhanced_path, enhanced)
        print(f"Enhanced image saved to {enhanced_path}")

    elif args.type == 'mri':
        segmenter = BrainSegmenter(args.config)
        brain_mask = segmenter.segment_brain(image)
        mask = brain_mask
        regions = segmenter.extract_region_properties(brain_mask)

        # Try to segment tumor
        tumor_mask = segmenter.segment_tumor(image, brain_mask)

        # Save tumor mask if found
        if np.any(tumor_mask):
            tumor_path = os.path.join(args.output, 'tumor_mask.png')
            cv2.imwrite(tumor_path, tumor_mask.astype(np.uint8) * 255)
            print(f"Tumor mask saved to {tumor_path}")

    elif args.type == 'ct':
        segmenter = OrganSegmenter(args.config)
        liver_mask = segmenter.segment_liver(image)
        kidney_mask = segmenter.segment_kidneys(image)
        bone_mask = segmenter.segment_bones(image)

        # Create combined label mask; later assignments win where masks overlap
        combined_mask = np.zeros_like(image[:, :, 0], dtype=np.uint8)
        combined_mask[liver_mask] = 1  # Liver
        combined_mask[kidney_mask] = 2  # Kidneys
        combined_mask[bone_mask] = 3    # Bones
        mask = combined_mask

        regions = segmenter.extract_region_properties(combined_mask)

        # Save individual masks
        cv2.imwrite(os.path.join(args.output, 'liver_mask.png'), liver_mask.astype(np.uint8) * 255)
        cv2.imwrite(os.path.join(args.output, 'kidney_mask.png'), kidney_mask.astype(np.uint8) * 255)
        cv2.imwrite(os.path.join(args.output, 'bone_mask.png'), bone_mask.astype(np.uint8) * 255)
        print(f"Organ masks saved to {args.output}")

    # Save segmentation mask as binary foreground; `> 0` also handles the
    # labeled CT mask, whose values 2/3 would overflow uint8 when * 255
    mask_path = os.path.join(args.output, 'segmentation_mask.png')
    cv2.imwrite(mask_path, (mask > 0).astype(np.uint8) * 255)
    print(f"Segmentation mask saved to {mask_path}")

    # Visualize segmentation
    vis_path = os.path.join(args.output, 'segmentation_visualization.png')
    segmenter.visualize_segmentation(image, mask, regions, vis_path)
    print(f"Visualization saved to {vis_path}")

    print(f"Detected {len(regions)} regions in the image")
    # Show the 5 largest regions
    largest_regions = sorted(regions, key=lambda r: r['area'], reverse=True)[:5]
    for i, region in enumerate(largest_regions):
        print(f"Region {i+1}: Area = {region['area']}, Centroid = {region['centroid']}")

modules\image_analysis\__init__.py


modules\preprocessing\data_preprocessor.py

"""
Data Preprocessor Module for Medical Diagnosis System

This module handles the preprocessing of various types of medical data including:
- Medical images (DICOM, JPEG, PNG)
- Clinical text data (reports, notes)
- Laboratory results

The module provides standardization, normalization, and transformation functions.
"""

import os
import json
import logging
import numpy as np
import pandas as pd
import cv2
import pydicom
from typing import Dict, List, Union, Optional, Any, Tuple
from pathlib import Path
import re
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# Module logging: configure the root logger (no-op if it already has handlers)
# and obtain a logger named after this module.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

class DataPreprocessor:
    """Base class for data preprocessing operations"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the data preprocessor with optional configuration

        Args:
            config_path: Path to a UTF-8 JSON configuration file; ignored if
                it does not exist (invalid JSON raises)
        """
        self.config = {}
        if config_path and os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)

        logger.info("DataPreprocessor initialized")

    def save_preprocessed_data(self, data: Any, output_path: str) -> bool:
        """
        Save preprocessed data to file

        The format is chosen from the file extension: ``.npy`` (NumPy),
        ``.json``, ``.csv`` (via pandas), or ``.jpg``/``.jpeg``/``.png``
        (via OpenCV). Parent directories are created as needed.

        Args:
            data: Data to save
            output_path: Path to save the data

        Returns:
            bool: True if successful, False otherwise (errors are logged,
            not raised)
        """
        try:
            # dirname() is '' for a bare filename and os.makedirs('') raises,
            # so only create a directory when the path actually has one.
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            # Determine the file type from extension
            ext = os.path.splitext(output_path)[1].lower()

            if ext == '.npy':
                np.save(output_path, data)
            elif ext == '.json':
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(data, f, indent=2)
            elif ext == '.csv':
                if isinstance(data, pd.DataFrame):
                    data.to_csv(output_path, index=False)
                else:
                    pd.DataFrame(data).to_csv(output_path, index=False)
            elif ext in ['.jpg', '.jpeg', '.png']:
                # cv2.imwrite signals failure via its return value, not an exception
                if not cv2.imwrite(output_path, data):
                    logger.error(f"Error saving preprocessed data: cv2.imwrite failed for {output_path}")
                    return False
            else:
                logger.warning(f"Unsupported file extension: {ext}")
                return False

            logger.info(f"Saved preprocessed data to {output_path}")
            return True

        except Exception as e:
            logger.error(f"Error saving preprocessed data: {str(e)}")
            return False


class MedicalImagePreprocessor(DataPreprocessor):
    """
    Class for preprocessing medical image data (DICOM and regular image files).

    Every image goes through the same pipeline:
        1. optional CLAHE contrast enhancement
        2. resize to ``target_size``
        3. (DICOM only) single-channel images are replicated to 3 channels
        4. optional scaling from [0, 255] to float32 in [0, 1]
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the medical image preprocessor

        Args:
            config_path: Optional path to a JSON configuration file. Recognized
                keys: ``target_size`` (passed to ``cv2.resize`` as ``dsize``,
                i.e. (width, height); an int means a square size), ``normalize``,
                ``clahe``, ``clahe_clip_limit`` and ``clahe_grid_size``.
        """
        super().__init__(config_path)

        # Default preprocessing parameters. JSON has no tuple type, so sizes read
        # from a config file arrive as lists; coerce them to int tuples so they
        # behave exactly like the built-in tuple defaults.
        self.target_size = self._as_size(self.config.get('target_size', (224, 224)))
        self.normalize = self.config.get('normalize', True)
        self.clahe = self.config.get('clahe', True)
        self.clahe_clip_limit = self.config.get('clahe_clip_limit', 2.0)
        self.clahe_grid_size = self._as_size(self.config.get('clahe_grid_size', (8, 8)))

        logger.info(f"MedicalImagePreprocessor initialized with target size: {self.target_size}")

    @staticmethod
    def _as_size(value: Any) -> tuple:
        """Coerce a size from config (int, list or tuple of two numbers) to an ``(int, int)`` tuple."""
        if isinstance(value, (int, float)):
            return int(value), int(value)
        first, second = value
        return int(first), int(second)

    def _apply_clahe(self, image: np.ndarray) -> np.ndarray:
        """
        Apply CLAHE contrast enhancement.

        2-D (single-channel) images are enhanced directly. 3-channel RGB images
        are enhanced on the L channel of LAB space so colour balance is kept.
        Any other layout is returned unchanged.
        """
        clahe = cv2.createCLAHE(clipLimit=self.clahe_clip_limit, tileGridSize=self.clahe_grid_size)
        if image.ndim == 2:
            return clahe.apply(image)
        if image.ndim == 3 and image.shape[2] == 3:
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            lab[:, :, 0] = clahe.apply(lab[:, :, 0])
            return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return image

    def _resize(self, image: np.ndarray) -> np.ndarray:
        """Resize ``image`` to ``self.target_size`` (width, height) when it differs."""
        width, height = self.target_size
        # numpy shapes are (rows, cols) == (height, width), while cv2.resize's
        # dsize is (width, height); compare in the matching order.
        if image.shape[:2] != (height, width):
            image = cv2.resize(image, (width, height), interpolation=cv2.INTER_CUBIC)
        return image

    def _scale(self, image: np.ndarray) -> np.ndarray:
        """Convert a [0, 255] image to float32 in [0, 1] when normalization is enabled."""
        if self.normalize:
            return image.astype(np.float32) / 255.0
        return image

    def _build_result(self, image: np.ndarray, source: Dict[str, Any], file_path: str) -> Dict[str, Any]:
        """Assemble the output record shared by the DICOM and regular-image paths."""
        return {
            'preprocessed_image': image,
            'metadata': source.get('metadata', {}),
            'original_file_path': file_path,
            'preprocessing_info': {
                'target_size': self.target_size,
                'normalize': self.normalize,
                'clahe': self.clahe
            }
        }

    def preprocess_dicom(self, dicom_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess DICOM image data

        Args:
            dicom_data: DICOM record from the collector. ``file_path`` is
                required; ``metadata`` is copied into the result when present.

        Returns:
            Dict: Preprocessed DICOM data, or an empty dict if the file is
            missing or any step fails (the error is logged)
        """
        try:
            file_path = dicom_data.get('file_path')
            if not file_path or not os.path.exists(file_path):
                logger.error(f"DICOM file not found: {file_path}")
                return {}

            dicom = pydicom.dcmread(file_path)
            pixel_array = dicom.pixel_array

            # MONOCHROME1 displays the minimum value as white; invert so every image
            # follows the MONOCHROME2 convention. A missing tag is treated as
            # "not inverted" instead of aborting the whole preprocessing step.
            if getattr(dicom, 'PhotometricInterpretation', None) == "MONOCHROME1":
                pixel_array = np.max(pixel_array) - pixel_array

            # Stretch non-8-bit data into the uint8 range used by the rest of the pipeline
            if pixel_array.dtype != np.uint8:
                pixel_array = cv2.normalize(pixel_array, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)

            if self.clahe:
                pixel_array = self._apply_clahe(pixel_array)

            pixel_array = self._resize(pixel_array)

            # Ensure 3 channels for model compatibility
            if pixel_array.ndim == 2:
                pixel_array = np.stack([pixel_array] * 3, axis=2)

            result = self._build_result(self._scale(pixel_array), dicom_data, file_path)

            logger.info(f"Successfully preprocessed DICOM image from {file_path}")
            return result

        except Exception as e:
            logger.error(f"Error preprocessing DICOM image: {str(e)}")
            return {}

    def preprocess_image(self, image_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess regular image data (JPEG, PNG)

        Args:
            image_data: Image record from the collector. ``file_path`` is
                required; ``metadata`` is copied into the result when present.

        Returns:
            Dict: Preprocessed image data, or an empty dict if the file is
            missing or unreadable, or any step fails (the error is logged)
        """
        try:
            file_path = image_data.get('file_path')
            if not file_path or not os.path.exists(file_path):
                logger.error(f"Image file not found: {file_path}")
                return {}

            # cv2.imread returns None (it does not raise) for unreadable files
            image = cv2.imread(file_path)
            if image is None:
                logger.error(f"Could not read image file: {file_path}")
                return {}
            # OpenCV decodes to BGR channel order
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if self.clahe:
                image = self._apply_clahe(image)

            image = self._resize(image)
            result = self._build_result(self._scale(image), image_data, file_path)

            logger.info(f"Successfully preprocessed image from {file_path}")
            return result

        except Exception as e:
            logger.error(f"Error preprocessing image: {str(e)}")
            return {}

    def preprocess_batch(self, image_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of images

        Records whose ``type`` is ``'dicom'`` go through ``preprocess_dicom``;
        all others go through ``preprocess_image``. Failed items are dropped.

        Args:
            image_data_list: List of image data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed image data
        """
        results = []

        for image_data in image_data_list:
            if image_data.get('type') == 'dicom':
                result = self.preprocess_dicom(image_data)
            else:
                result = self.preprocess_image(image_data)

            if result:
                results.append(result)

        logger.info(f"Preprocessed {len(results)} images out of {len(image_data_list)} total")
        return results


class ClinicalTextPreprocessor(DataPreprocessor):
    """
    Class for preprocessing clinical text data.

    Text is optionally lower-cased and stripped of punctuation, then tokenized
    with NLTK. Tokens are filtered against general and medical stopwords and
    scanned for a small built-in dictionary of medical entities.
    """

    # (path checked with nltk.data.find, package name for nltk.download).
    # Recent NLTK releases load the 'punkt_tab' tables in word_tokenize instead
    # of the pickled 'punkt' model, so both are ensured.
    _NLTK_RESOURCES = (
        ('tokenizers/punkt', 'punkt'),
        ('tokenizers/punkt_tab', 'punkt_tab'),
        ('corpora/stopwords', 'stopwords'),
    )

    # Any character that is neither a word character nor whitespace
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')

    # Demonstration-only entity dictionaries (lower-case). A real system would
    # use a medical NER model or terminology lookup instead.
    _ENTITY_LEXICON = {
        'symptoms': frozenset({'pain', 'fever', 'cough', 'headache', 'nausea', 'fatigue', 'dizziness'}),
        'diseases': frozenset({'diabetes', 'hypertension', 'cancer', 'asthma', 'pneumonia', 'arthritis'}),
        'medications': frozenset({'aspirin', 'ibuprofen', 'paracetamol', 'insulin', 'antibiotics'}),
    }

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the clinical text preprocessor

        Args:
            config_path: Optional path to a JSON configuration file. Recognized
                keys: ``remove_stopwords``, ``remove_punctuation``, ``lowercase``,
                ``language`` (NLTK stopword language) and ``medical_stopwords``.
        """
        super().__init__(config_path)

        # Download NLTK resources if needed
        for resource_path, package in self._NLTK_RESOURCES:
            self._ensure_nltk_resource(resource_path, package)

        # Default preprocessing parameters
        self.remove_stopwords = self.config.get('remove_stopwords', True)
        self.remove_punctuation = self.config.get('remove_punctuation', True)
        self.lowercase = self.config.get('lowercase', True)
        self.language = self.config.get('language', 'english')

        # Load stopwords
        self.stopwords = set(stopwords.words(self.language)) if self.remove_stopwords else set()

        # Medical specific stopwords (common words in medical texts that aren't informative)
        medical_stopwords = self.config.get('medical_stopwords', [
            'patient', 'doctor', 'hospital', 'treatment', 'medicine',
            'dose', 'report', 'medical', 'clinical', 'health'
        ])
        self.stopwords.update(medical_stopwords)

        logger.info(f"ClinicalTextPreprocessor initialized with {len(self.stopwords)} stopwords")

    @staticmethod
    def _ensure_nltk_resource(resource_path: str, package: str) -> None:
        """Download NLTK data ``package`` if ``nltk.data.find`` cannot locate ``resource_path``."""
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)

    def preprocess_text(self, text_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess clinical text data

        Args:
            text_data: Text record from the collector. ``content`` (str, list or
                dict) is required; ``metadata`` is copied into the result.

        Returns:
            Dict: Preprocessed text data, or an empty dict when there is no
            content or processing fails (the error is logged)
        """
        try:
            content = text_data.get('content')
            if not content:
                logger.error("No content found in text data")
                return {}

            if isinstance(content, (list, dict)):
                # Structured data is flattened to JSON text. ensure_ascii=False keeps
                # non-ASCII text (e.g. Chinese notes) intact: the default \uXXXX
                # escapes would be turned by the punctuation filter into
                # meaningless 'uXXXX' tokens.
                content_str = json.dumps(content, ensure_ascii=False)
            else:
                content_str = str(content)

            if self.lowercase:
                content_str = content_str.lower()

            if self.remove_punctuation:
                content_str = self._PUNCTUATION_RE.sub(' ', content_str)

            tokens = word_tokenize(content_str)

            if self.remove_stopwords:
                tokens = [token for token in tokens if token not in self.stopwords]

            preprocessed_text = ' '.join(tokens)

            # Extract medical entities (simplified)
            medical_entities = self._extract_medical_entities(preprocessed_text)

            result = {
                'original_content': content,
                'preprocessed_text': preprocessed_text,
                'tokens': tokens,
                'medical_entities': medical_entities,
                'metadata': text_data.get('metadata', {}),
                'preprocessing_info': {
                    'remove_stopwords': self.remove_stopwords,
                    'remove_punctuation': self.remove_punctuation,
                    'lowercase': self.lowercase
                }
            }

            logger.info(f"Successfully preprocessed text from {text_data.get('file_path', 'unknown')}")
            return result

        except Exception as e:
            logger.error(f"Error preprocessing text: {str(e)}")
            return {}

    def _extract_medical_entities(self, text: str) -> Dict[str, List[str]]:
        """
        Extract medical entities from text (simplified dictionary lookup)

        Matching is case-insensitive, so entities are found even when the
        ``lowercase`` option is disabled. Tokens are reported as they appear
        in ``text``.

        Args:
            text: Preprocessed, whitespace-separated text

        Returns:
            Dict: Lists of matched tokens under 'symptoms', 'diseases' and 'medications'
        """
        found_entities: Dict[str, List[str]] = {category: [] for category in self._ENTITY_LEXICON}

        for token in text.split():
            normalized = token.lower()
            for category, terms in self._ENTITY_LEXICON.items():
                if normalized in terms:
                    found_entities[category].append(token)
                    break

        return found_entities

    def preprocess_batch(self, text_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of clinical texts

        Args:
            text_data_list: List of text data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed text data (failures are dropped)
        """
        results = []

        for text_data in text_data_list:
            result = self.preprocess_text(text_data)
            if result:
                results.append(result)

        logger.info(f"Preprocessed {len(results)} texts out of {len(text_data_list)} total")
        return results


class LabResultPreprocessor(DataPreprocessor):
    """
    Class for preprocessing laboratory result data.

    Records are loaded into a DataFrame. Missing values are filled (or rows
    dropped), and numeric outliers are replaced by the column mean. Numeric
    columns are then optionally scaled with a scikit-learn scaler.
    """

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize the lab result preprocessor

        Args:
            config_path: Optional path to a JSON configuration file. Recognized
                keys: ``normalization``, ``handle_missing``,
                ``outlier_detection`` and ``outlier_threshold``.
        """
        super().__init__(config_path)

        # Default preprocessing parameters
        self.normalization = self.config.get('normalization', 'standard')  # 'standard', 'minmax', or None
        self.handle_missing = self.config.get('handle_missing', 'mean')  # 'mean', 'median', 'zero', or 'drop'
        self.outlier_detection = self.config.get('outlier_detection', True)
        self.outlier_threshold = self.config.get('outlier_threshold', 3.0)  # Z-score threshold

        # Initialize scalers
        self.standard_scaler = StandardScaler() if self.normalization == 'standard' else None
        self.minmax_scaler = MinMaxScaler() if self.normalization == 'minmax' else None

        logger.info(f"LabResultPreprocessor initialized with normalization: {self.normalization}")

    def _fill_missing(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fill (or drop) missing values according to ``self.handle_missing``."""
        if self.handle_missing == 'mean':
            return df.fillna(df.mean(numeric_only=True))
        if self.handle_missing == 'median':
            return df.fillna(df.median(numeric_only=True))
        if self.handle_missing == 'zero':
            return df.fillna(0)
        if self.handle_missing == 'drop':
            return df.dropna()
        return df

    def _replace_outliers(self, df: pd.DataFrame, numeric_cols) -> pd.DataFrame:
        """Replace values whose |z-score| exceeds ``outlier_threshold`` with the column mean."""
        df = df.copy()
        for col in numeric_cols:
            column = df[col]
            mean = column.mean()
            std = column.std()
            # std is NaN for a single row and 0 for a constant column; no value
            # can be flagged in either case.
            if pd.isna(std) or std == 0:
                continue
            outliers = ((column - mean) / std).abs() > self.outlier_threshold
            if outliers.any():
                # The replacement mean is usually fractional, so widen integer
                # columns first: writing a float into an int64 column is
                # deprecated in pandas 2.x.
                df[col] = column.astype(float)
                df.loc[outliers, col] = mean
        return df

    def _scale_numeric(self, numeric_df: pd.DataFrame) -> pd.DataFrame:
        """Scale numeric columns with the configured scikit-learn scaler, if any."""
        if self.normalization == 'standard':
            scaler = self.standard_scaler
        elif self.normalization == 'minmax':
            scaler = self.minmax_scaler
        else:
            scaler = None
        # scikit-learn scalers raise on inputs with zero rows or zero columns
        # (all rows dropped by handle_missing='drop', or a purely textual record),
        # so such frames are passed through unchanged.
        if scaler is None or numeric_df.empty:
            return numeric_df
        scaled = scaler.fit_transform(numeric_df)
        return pd.DataFrame(scaled, columns=numeric_df.columns, index=numeric_df.index)

    @staticmethod
    def _column_statistics(column: pd.Series) -> Dict[str, float]:
        """Summary statistics of one numeric column."""
        return {
            'mean': float(column.mean()),
            'median': float(column.median()),
            'std': float(column.std()),
            'min': float(column.min()),
            'max': float(column.max())
        }

    def preprocess_lab_data(self, lab_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Preprocess laboratory result data

        Args:
            lab_data: Lab record from the collector. ``data`` holds one record
                (dict) or a list of records; ``metadata`` is copied into the result.

        Returns:
            Dict: Original and preprocessed records, per-column statistics of the
            cleaned (unscaled) numeric values, and the settings used. Returns an
            empty dict when there is no data or processing fails (the error is logged).
        """
        try:
            data = lab_data.get('data')
            if not data:
                logger.error("No data found in lab results")
                return {}

            if isinstance(data, list):
                df = pd.DataFrame(data)
            elif isinstance(data, dict):
                df = pd.DataFrame([data])  # a single record becomes a one-row frame
            else:
                logger.error(f"Unsupported data type: {type(data)}")
                return {}

            original_df = df.copy()
            df = self._fill_missing(df)

            numeric_cols = df.select_dtypes(include=['number']).columns
            non_numeric_cols = [col for col in df.columns if col not in numeric_cols]

            if self.outlier_detection:
                df = self._replace_outliers(df, numeric_cols)

            normalized_df = self._scale_numeric(df[numeric_cols])

            # Combine normalized numeric columns with the untouched non-numeric ones
            if non_numeric_cols:
                result_df = pd.concat([normalized_df, df[non_numeric_cols]], axis=1)
            else:
                result_df = normalized_df

            stats = {col: self._column_statistics(df[col]) for col in numeric_cols}

            result = {
                'original_data': original_df.to_dict(orient='records'),
                'preprocessed_data': result_df.to_dict(orient='records'),
                'statistics': stats,
                'metadata': lab_data.get('metadata', {}),
                'preprocessing_info': {
                    'normalization': self.normalization,
                    'handle_missing': self.handle_missing,
                    'outlier_detection': self.outlier_detection,
                    'outlier_threshold': self.outlier_threshold
                }
            }

            logger.info(f"Successfully preprocessed lab data from {lab_data.get('file_path', 'unknown')}")
            return result

        except Exception as e:
            logger.error(f"Error preprocessing lab data: {str(e)}")
            return {}

    def preprocess_batch(self, lab_data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Preprocess a batch of lab results

        Args:
            lab_data_list: List of lab data from the collector

        Returns:
            List[Dict]: List of successfully preprocessed lab data (failures are dropped)
        """
        results = []

        for lab_data in lab_data_list:
            result = self.preprocess_lab_data(lab_data)
            if result:
                results.append(result)

        logger.info(f"Preprocessed {len(results)} lab results out of {len(lab_data_list)} total")
        return results


# Composite preprocessor that uses all specialized preprocessors
class MedicalDataPreprocessor:
    """Composite class that routes each data type to its specialized preprocessor"""

    def __init__(self, config_path: Optional[str] = None):
        """
        Initialize all data preprocessors

        Args:
            config_path: Optional path to a JSON configuration file; the same
                path is passed to every specialized preprocessor.
        """
        self.image_preprocessor = MedicalImagePreprocessor(config_path)
        self.text_preprocessor = ClinicalTextPreprocessor(config_path)
        self.lab_preprocessor = LabResultPreprocessor(config_path)
        logger.info("MedicalDataPreprocessor initialized with all specialized preprocessors")

    def preprocess_all(self, collection: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
        """
        Preprocess all types of medical data in a collection

        Args:
            collection: Collection of data from the collector. The keys
                'images', 'clinical_texts' and 'lab_results' are processed;
                other keys are ignored.

        Returns:
            Dict: For each processed key present in ``collection``, the list of
            preprocessed items
        """
        routes = (
            ('images', self.image_preprocessor),
            ('clinical_texts', self.text_preprocessor),
            ('lab_results', self.lab_preprocessor),
        )
        preprocessed = {
            key: preprocessor.preprocess_batch(collection[key])
            for key, preprocessor in routes
            if key in collection
        }

        logger.info("Preprocessed all data in collection")

        return preprocessed

    @staticmethod
    def _write_json(path: str, payload: Any) -> None:
        """Write ``payload`` as indented UTF-8 JSON, keeping non-ASCII text readable."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    def _save_item_payload(self, data_type: str, item: Dict[str, Any], type_dir: str, index: int) -> None:
        """Write the type-specific payload of one item (npy / json / csv); unknown types are skipped."""
        if data_type == 'images':
            np.save(os.path.join(type_dir, f'item_{index}_image.npy'), item.get('preprocessed_image'))

        elif data_type == 'clinical_texts':
            self._write_json(os.path.join(type_dir, f'item_{index}_text.json'), {
                'preprocessed_text': item.get('preprocessed_text', ''),
                'tokens': item.get('tokens', []),
                'medical_entities': item.get('medical_entities', {})
            })

        elif data_type == 'lab_results':
            # Distinct name: previously this rebound the caller's preprocessed_data argument
            lab_records = item.get('preprocessed_data', [])
            pd.DataFrame(lab_records).to_csv(os.path.join(type_dir, f'item_{index}_lab.csv'), index=False)

    def save_preprocessed_data(self, preprocessed_data: Dict[str, List[Dict[str, Any]]], output_dir: str) -> None:
        """
        Save all preprocessed data to files

        Layout::

            output_dir/preprocessing_summary.json        item counts per data type
            output_dir/<data_type>/item_<i>_meta.json    metadata + preprocessing info
            output_dir/images/item_<i>_image.npy
            output_dir/clinical_texts/item_<i>_text.json
            output_dir/lab_results/item_<i>_lab.csv

        Args:
            preprocessed_data: Preprocessed data collection
            output_dir: Directory to save the data (created if missing)
        """
        os.makedirs(output_dir, exist_ok=True)

        summary = {
            'total_items': sum(len(items) for items in preprocessed_data.values()),
            'counts': {data_type: len(items) for data_type, items in preprocessed_data.items()}
        }
        self._write_json(os.path.join(output_dir, 'preprocessing_summary.json'), summary)

        for data_type, items in preprocessed_data.items():
            type_dir = os.path.join(output_dir, data_type)
            os.makedirs(type_dir, exist_ok=True)

            for i, item in enumerate(items):
                self._write_json(os.path.join(type_dir, f'item_{i}_meta.json'), {
                    'metadata': item.get('metadata', {}),
                    'preprocessing_info': item.get('preprocessing_info', {})
                })
                self._save_item_payload(data_type, item, type_dir, i)

        logger.info(f"Saved all preprocessed data to {output_dir}")


if __name__ == "__main__":
    # Demo run: gather raw records from ./sample_data, preprocess every data
    # type, then persist the results under ./data/preprocessed.
    from data_collector import MedicalDataCollector

    raw_collection = MedicalDataCollector().collect_all_from_directory("./sample_data")

    pipeline = MedicalDataPreprocessor()
    results = pipeline.preprocess_all(raw_collection)
    pipeline.save_preprocessed_data(results, "./data/preprocessed")

modules\preprocessing\__init__.py


modules\text_analysis\clinical_text_analyzer.py

"""
Clinical Text Analyzer Module

This module integrates all text analysis components (preprocessing, NER, classification,
summarization, and relation extraction) to provide a comprehensive clinical text analysis.
"""

import os
import json
import logging
from typing import List, Dict, Any, Optional, Union, Tuple

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import seaborn as sns

from .text_preprocessor import ClinicalTextPreprocessor
from .medical_ner import MedicalNamedEntityRecognizer
from .text_classifier import ClinicalTextClassifier
from .text_summarizer import ClinicalTextSummarizer
from .relation_extractor import MedicalRelationExtractor

# Module logging setup: INFO level, sent both to "clinical_text_analysis.log"
# (relative to the current working directory) and to the console.
# NOTE: the FileHandler is created (and the file opened) on import even though
# logging.basicConfig is a no-op when the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    handlers=[
        logging.FileHandler("clinical_text_analysis.log"),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

class ClinicalTextAnalyzer:
    """
    Class for comprehensive clinical text analysis
    """
    
    def __init__(self, config_path: Optional[str] = None):
        """
        Build the analyzer together with all of its pipeline components.

        Args:
            config_path: Optional path to a JSON configuration file. It is read
                here for the analyzer's own settings and forwarded unchanged to
                every component.
        """
        self.config = self._load_config(config_path)

        # One component per pipeline step, created in pipeline order
        self.preprocessor = ClinicalTextPreprocessor(config_path)        # "preprocess"
        self.ner = MedicalNamedEntityRecognizer(config_path)             # "ner"
        self.classifier = ClinicalTextClassifier(config_path)            # "classify"
        self.summarizer = ClinicalTextSummarizer(config_path)            # "summarize"
        self.relation_extractor = MedicalRelationExtractor(config_path)  # "extract_relations"

        logger.info("Initialized ClinicalTextAnalyzer")
    
    def _load_config(self, config_path: Optional[str]) -> Dict[str, Any]:
        """
        Load configuration from file, layered over the built-in defaults.

        The same config file is shared with every pipeline component, so it may
        lack an ``analysis`` section or some of its keys. Top-level keys from
        the file override the defaults. The ``analysis`` section is merged key
        by key, so the returned config always has a complete ``analysis``
        section.

        Args:
            config_path: Path to the configuration file (UTF-8 JSON object)

        Returns:
            Dict[str, Any]: Configuration dictionary; the defaults alone when the
            path is missing or the file cannot be loaded (the error is logged)
        """
        default_analysis = {
            "default_pipeline": ["preprocess", "ner", "classify", "summarize", "extract_relations"],
            "output_format": "json",
            "visualization": True
        }
        config: Dict[str, Any] = {"analysis": dict(default_analysis)}

        if not (config_path and os.path.exists(config_path)):
            logger.info("Using default configuration")
            return config

        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            logger.error(f"Error loading configuration from {config_path}: {e}")
            return config

        if not isinstance(loaded, dict):
            logger.error(f"Error loading configuration from {config_path}: expected a JSON object")
            return config

        loaded_analysis = loaded.get("analysis")
        config.update(loaded)
        # Merge the analysis section so keys the file omits keep their defaults
        config["analysis"] = {
            **default_analysis,
            **(loaded_analysis if isinstance(loaded_analysis, dict) else {}),
        }
        logger.info(f"Loaded configuration from {config_path}")
        return config
    
    def analyze_text(self, text: str, pipeline: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Analyze clinical text using the specified pipeline

        Steps are selected by name from "preprocess", "ner", "classify",
        "summarize" and "extract_relations". "extract_relations" uses the
        entities produced by "ner", so it is skipped (with a warning) when "ner"
        is not in the pipeline.

        Args:
            text: Clinical text to analyze
            pipeline: List of analysis steps to perform; defaults to the
                configured ``analysis.default_pipeline``

        Returns:
            Dict[str, Any]: Analysis results. ``success`` is False (with an
            ``error`` message) for invalid input or when a step raises
            unexpectedly. Classification and summarization errors are recorded
            in their own sections instead.
        """
        if not text or not isinstance(text, str):
            logger.warning(f"Invalid input text: {text}")
            return {"success": False, "error": "Invalid input text"}

        # A user-supplied config may lack the "analysis" section or some of its
        # keys; fall back to the built-in defaults rather than raising KeyError.
        analysis_config = self.config.get("analysis")
        if not isinstance(analysis_config, dict):
            analysis_config = {}

        if pipeline is None:
            pipeline = analysis_config.get(
                "default_pipeline",
                ["preprocess", "ner", "classify", "summarize", "extract_relations"],
            )

        results = {
            "success": True,
            "text": text,
            "pipeline": pipeline
        }

        try:
            if "preprocess" in pipeline:
                logger.info("Preprocessing text")
                preprocessed_text = self.preprocessor.preprocess_text(text)
                sentences = self.preprocessor.extract_sentences(text)
                medical_terms = self.preprocessor.extract_medical_terms(text)

                results["preprocessing"] = {
                    "preprocessed_text": preprocessed_text,
                    "sentences": sentences,
                    "sentence_count": len(sentences),
                    "medical_terms": medical_terms,
                    "medical_term_count": len(medical_terms)
                }

                # Demographics are only reported when the extractor finds any
                demographics = self.preprocessor.extract_demographics(text)
                if demographics:
                    results["demographics"] = demographics

            if "ner" in pipeline:
                logger.info("Performing Named Entity Recognition")
                entities = self.ner.extract_entities(text, method="transformer")

                # Group entity texts by their label
                entities_by_type: Dict[str, List[Any]] = {}
                for entity in entities:
                    entities_by_type.setdefault(entity["label"], []).append(entity["text"])

                results["entities"] = {
                    "all_entities": entities,
                    "entity_count": len(entities),
                    "entities_by_type": entities_by_type
                }

            if "classify" in pipeline:
                logger.info("Performing text classification")
                # Classification relies on a pre-trained model; record failures
                # instead of aborting the remaining steps.
                try:
                    classification_result = self.classifier.predict(text, method="transformer")
                    results["classification"] = classification_result
                except Exception as e:
                    logger.error(f"Error in classification: {e}")
                    results["classification"] = {"error": str(e)}

            if "summarize" in pipeline:
                logger.info("Generating text summary")
                summary_result = self.summarizer.summarize(text, method="hybrid")

                if summary_result.get("success"):
                    results["summary"] = summary_result
                else:
                    error = summary_result.get("error", "Unknown summarization error")
                    logger.error(f"Error in summarization: {error}")
                    results["summary"] = {"error": error}

                findings_result = self.summarizer.key_findings_extract(text)
                if findings_result.get("success"):
                    results["key_findings"] = findings_result.get("findings", [])

            if "extract_relations" in pipeline:
                if "entities" in results:
                    logger.info("Extracting relations between entities")
                    relations = self.relation_extractor.extract_relations(
                        text,
                        results["entities"]["all_entities"],
                        method="hybrid"
                    )

                    knowledge_graph = self.relation_extractor.build_knowledge_graph(relations)

                    results["relations"] = {
                        "all_relations": relations,
                        "relation_count": len(relations),
                        "knowledge_graph": knowledge_graph
                    }
                else:
                    logger.warning("Skipping relation extraction: it requires the 'ner' step")

            if analysis_config.get("visualization", True):
                logger.info("Generating visualizations")
                results["visualizations"] = self._generate_visualizations(results)

            return results
        except Exception as e:
            logger.error(f"Error in text analysis: {e}")
            return {"success": False, "error": str(e)}
    
    def analyze_document(self, document_path: str, pipeline: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Read a UTF-8 clinical document from disk and run ``analyze_text`` on it.

        Args:
            document_path: Path to the document
            pipeline: Analysis steps to run (``None`` selects the configured default)

        Returns:
            Dict[str, Any]: The ``analyze_text`` result with ``document_path``
            added, or ``{"success": False, "error": ...}`` when the file is
            missing or cannot be read or analyzed.
        """
        if not (document_path and os.path.exists(document_path)):
            logger.error(f"Document not found: {document_path}")
            return {"success": False, "error": f"Document not found: {document_path}"}

        try:
            with open(document_path, 'r', encoding='utf-8') as handle:
                content = handle.read()

            analysis = self.analyze_text(content, pipeline)
            analysis["document_path"] = document_path
            return analysis
        except Exception as exc:
            logger.error(f"Error analyzing document: {exc}")
            return {"success": False, "error": str(exc)}
    
    def analyze_batch(self, texts: List[str], pipeline: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """
        Run ``analyze_text`` over each text in a batch, in order.

        Args:
            texts: Clinical texts to analyze. An empty or ``None`` batch is
                logged as a warning and yields an empty list.
            pipeline: Analysis steps, forwarded unchanged to every call.

        Returns:
            List[Dict[str, Any]]: One ``analyze_text`` result per input text,
            in the same order as ``texts``.
        """
        if texts:
            batch_results: List[Dict[str, Any]] = []
            for position, text in enumerate(texts, start=1):
                logger.info(f"Analyzing text {position}/{len(texts)}")
                batch_results.append(self.analyze_text(text, pipeline))
            return batch_results

        logger.warning("Empty batch of texts")
        return []
    
    def _generate_visualizations(self, results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Generate visualizations for analysis results
        
        Args:
            results: Analysis results
            
        Returns:
            Dict[str, Any]: Visualization data
        """
        visualizations = {}
        
        # Entity distribution visualization
        if "entities" in results and results["entities"]["entity_count"] > 0:
            try:
                entities_by_type = results["entities"]["entities_by_type"]
                entity_counts = {entity_type: len(entities) for entity_type, entities in entities_by_type.items()}
                
                # Create entity distribution chart data
                visualizations["entity_distribution"] = {
                    "type": "bar_chart",
                    "data": entity_counts
                }
            except Exception as e:
                logger.error(f"Error generating entity visualization: {e}")
        
        # Relation visualization
        if "relations" in results and results["relations"]["relation_count"] > 0:
            try:
                # Use knowledge graph for relation visualization
                visualizations["relations"] = {
                    "type": "knowledge_graph",
                    "data": results["relations"]["knowledge_graph"]
                }
            except Exception as e:
                logger.error(f"Error generating relation visualization: {e}")
        
        # Text summary visualization
        if "summary" in results and "summary" in results["summary"]:
            try:
                original_text = results["text"]
                summary = results["summary"]["summary"]
                
                visualizations["summary_comparison"] = {
                    "type": "text_comparison",
                    "data": {
                        "original": {
                            "text": original_text[:1000] + ("..." if len(original_text) > 1000 else ""),
                            "word_count": len(original_text.split())
                        },
                        "summary": {
                            "text": summary,
                            "word_count": len(summary.split())
                        },
                        "compression_ratio": len(summary.split()) / len(original_text.split()) if len(original_text.split()) > 0 else 0
                    }
                }
            except Exception as e:
                logger.error(f"Error generating summary visualization: {e}")
        
        return visualizations
    
    def generate_report(self, analysis_results: Dict[str, Any], 
                       output_path: Optional[str] = None, 
                       format: str = "html") -> str:
        """
        Render analysis results as a report, optionally saving it to disk.

        Args:
            analysis_results: Analysis results; must be non-empty and carry a
                truthy ``"success"`` entry, otherwise an error string is returned.
            output_path: When non-empty, the report is written there as UTF-8
                and the path is returned instead of the report text.
            format: Report format: ``'html'``, ``'json'`` or ``'text'``.

        Returns:
            str: The report text, ``output_path`` after a successful save, or
            an ``"Error..."`` string describing what went wrong.
        """
        if not (analysis_results and analysis_results.get("success", False)):
            logger.warning("Invalid analysis results for report generation")
            return "Error: Invalid analysis results"

        # Thunks, so only the renderer for the requested format is executed.
        renderers = (
            ("html", lambda: self._generate_html_report(analysis_results)),
            ("json", lambda: json.dumps(analysis_results, indent=2)),
            ("text", lambda: self._generate_text_report(analysis_results)),
        )

        try:
            render = next((fn for name, fn in renderers if format == name), None)
            if render is None:
                logger.error(f"Unsupported report format: {format}")
                return f"Error: Unsupported report format: {format}"

            report = render()
            if not output_path:
                return report

            with open(output_path, 'w', encoding='utf-8') as out_file:
                out_file.write(report)
            logger.info(f"Report saved to {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Error generating report: {e}")
            return f"Error generating report: {e}"
    
    def _generate_html_report(self, results: Dict[str, Any]) -> str:
        """
        Generate an HTML report from analysis results
        
        Args:
            results: Analysis results
            
        Returns:
            str: HTML report
        """
        html = f"""
        <!DOCTYPE html>
        <html lang="en">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>Clinical Text Analysis Report</title>
            <style>
                body {{
                    font-family: Arial, sans-serif;
                    line-height: 1.6;
                    color: #333;
                    max-width: 1200px;
                    margin: 0 auto;
                    padding: 20px;
                }}
                h1, h2, h3 {{
                    color: #2c3e50;
                }}
                .section {{
                    margin-bottom: 30px;
                    padding: 20px;
                    border-radius: 5px;
                    background-color: #f9f9f9;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                }}
                .entity {{
                    display: inline-block;
                    margin: 5px;
                    padding: 5px 10px;
                    border-radius: 15px;
                    font-size: 14px;
                }}
                .entity-DISEASE {{
                    background-color: #ffcccc;
                }}
                .entity-SYMPTOM {{
                    background-color: #ffffcc;
                }}
                .entity-MEDICATION {{
                    background-color: #ccffcc;
                }}
                .entity-PROCEDURE {{
                    background-color: #ccccff;
                }}
                .entity-TEST {{
                    background-color: #ffccff;
                }}
                table {{
                    width: 100%;
                    border-collapse: collapse;
                    margin-top: 10px;
                }}
                th, td {{
                    padding: 8px;
                    text-align: left;
                    border-bottom: 1px solid #ddd;
                }}
                th {{
                    background-color: #f2f2f2;
                }}
                .text-container {{
                    max-height: 300px;
                    overflow-y: auto;
                    padding: 10px;
                    border: 1px solid #ddd;
                    border-radius: 5px;
                    background-color: #fff;
                }}
            </style>
        </head>
        <body>
            <h1>Clinical Text Analysis Report</h1>
            
            <div class="section">
                <h2>Overview</h2>
                <p><strong>Analysis Pipeline:</strong> {', '.join(results['pipeline'])}</p>
                <p><strong>Text Length:</strong> {len(results['text'])} characters, {len(results['text'].split())} words</p>
        """
        
        # Add preprocessing section if available
        if "preprocessing" in results:
            html += f"""
                <div class="section">
                    <h2>Text Preprocessing</h2>
                    <p><strong>Sentence Count:</strong> {results['preprocessing']['sentence_count']}</p>
                    <p><strong>Medical Terms:</strong> {results['preprocessing']['medical_term_count']}</p>
                    
                    <h3>Medical Terms</h3>
                    <div class="text-container">
                        {', '.join(results['preprocessing']['medical_terms'][:100])}
                        {'...' if len(results['preprocessing']['medical_terms']) > 100 else ''}
                    </div>
                </div>
            """
        
        # Add demographics section if available
        if "demographics" in results:
            html += f"""
                <div class="section">
                    <h2>Patient Demographics</h2>
                    <table>
                        <tr>
                            <th>Attribute</th>
                            <th>Value</th>
                        </tr>
            """
            
            for key, value in results["demographics"].items():
                html += f"""
                        <tr>
                            <td>{key.replace('_', ' ').title()}</td>
                            <td>{value}</td>
                        </tr>
                """
            
            html += """
                    </table>
                </div>
            """
        
        # Add entities section if available
        if "entities" in results:
            html += f"""
                <div class="section">
                    <h2>Named Entity Recognition</h2>
                    <p><strong>Total Entities:</strong> {results['entities']['entity_count']}</p>
                    
                    <h3>Entities by Type</h3>
            """
            
            for entity_type, entities in results["entities"]["entities_by_type"].items():
                html += f"""
                    <h4>{entity_type} ({len(entities)})</h4>
                    <div>
                """
                
                for entity in entities[:30]:
                    html += f"""
                        <span class="entity entity-{entity_type}">{entity}</span>
                    """
                
                if len(entities) > 30:
                    html += "..."
                
                html += """
                    </div>
                """
            
            html += """
                </div>
            """
        
        # Add classification section if available
        if "classification" in results and "error" not in results["classification"]:
            html += f"""
                <div class="section">
                    <h2>Text Classification</h2>
                    <table>
                        <tr>
                            <th>Label</th>
                            <th>Confidence</th>
                        </tr>
            """
            
            for label, score in results["classification"]["predictions"].items():
                html += f"""
                        <tr>
                            <td>{label}</td>
                            <td>{score:.4f}</td>
                        </tr>
                """
            
            html += """
                    </table>
                </div>
            """
        
        # Add summary section if available
        if "summary" in results and "error" not in results["summary"]:
            html += f"""
                <div class="section">
                    <h2>Text Summarization</h2>
                    <p><strong>Method:</strong> {results['summary']['method']}</p>
                    <p><strong>Compression Ratio:</strong> {results['summary']['compression_ratio']:.2%}</p>
                    
                    <h3>Summary</h3>
                    <div class="text-container">
                        {results['summary']['summary']}
                    </div>
                </div>
            """
        
        # Add key findings section if available
        if "key_findings" in results:
            html += f"""
                <div class="section">
                    <h2>Key Findings</h2>
                    <table>
                        <tr>
                            <th>Category</th>
                            <th>Findings</th>
                        </tr>
            """
            
            for category, findings in results["key_findings"].items():
                if findings:
                    html += f"""
                            <tr>
                                <td>{category.replace('_', ' ').title()}</td>
                                <td>{', '.join(findings)}</td>
                            </tr>
                    """
            
            html += """
                    </table>
                </div>
            """
        
        # Add relations section if available
        if "relations" in results and results["relations"]["relation_count"] > 0:
            html += f"""
                <div class="section">
                    <h2>Relation Extraction</h2>
                    <p><strong>Total Relations:</strong> {results['relations']['relation_count']}</p>
                    
                    <h3>Extracted Relations</h3>
                    <table>
                        <tr>
                            <th>Relation Type</th>
                            <th>Entity 1</th>
                            <th>Entity 2</th>
                            <th>Confidence</th>
                        </tr>
            """
            
            for relation in results["relations"]["all_relations"][:30]:
                html += f"""
                        <tr>
                            <td>{relation['relation_type']}</td>
                            <td>{relation['entity1']['text']} ({relation['entity1']['label']})</td>
                            <td>{relation['entity2']['text']} ({relation['entity2']['label']})</td>
                            <td>{relation['confidence']:.2f}</td>
                        </tr>
                """
            
            if len(results["relations"]["all_relations"]) > 30:
                html += """
                        <tr>
                            <td colspan="4">...</td>
                        </tr>
                """
            
            html += """
                    </table>
                </div>
            """
        
        # Close HTML tags
        html += """
            </div>
        </body>
        </html>
        """
        
        return html
    
    def _generate_text_report(self, results: Dict[str, Any]) -> str:
        """
        Generate a plain text report from analysis results
        
        Args:
            results: Analysis results
            
        Returns:
            str: Text report
        """
        report = "CLINICAL TEXT ANALYSIS REPORT\n"
        report += "=" * 30 + "\n\n"
        
        # Overview
        report += "OVERVIEW\n"
        report += "-" * 8 + "\n"
        report += f"Analysis Pipeline: {', '.join(results['pipeline'])}\n"
        report += f"Text Length: {len(results['text'])} characters, {len(results['text'].split())} words\n\n"
        
        # Preprocessing
        if "preprocessing" in results:
            report += "TEXT PREPROCESSING\n"
            report += "-" * 17 + "\n"
            report += f"Sentence Count: {results['preprocessing']['sentence_count']}\n"
            report += f"Medical Terms: {results['preprocessing']['medical_term_count']}\n"
            
            # List some medical terms
            report += "Sample Medical Terms: "
            report += ", ".join(results['preprocessing']['medical_terms'][:20])
            if len(results['preprocessing']['medical_terms']) > 20:
                report += "..."
            report += "\n\n"
        
        # Demographics
        if "demographics" in results:
            report += "PATIENT DEMOGRAPHICS\n"
            report += "-" * 20 + "\n"
            for key, value in results["demographics"].items():
                report += f"{key.replace('_', ' ').title()}: {value}\n"
            report += "\n"
        
        # Entities
        if "entities" in results:
            report += "NAMED ENTITY RECOGNITION\n"
            report += "-" * 24 + "\n"
            report += f"Total Entities: {results['entities']['entity_count']}\n\n"
            
            report += "Entities by Type:\n"
            for entity_type, entities in results["entities"]["entities_by_type"].items():
                report += f"{entity_type} ({len(entities)}): "
                report += ", ".join(entities[:15])
                if len(entities) > 15:
                    report += "..."
                report += "\n"
            report += "\n"
        
        # Classification
        if "classification" in results and "error" not in results["classification"]:
            report += "TEXT CLASSIFICATION\n"
            report += "-" * 18 + "\n"
            
            for label, score in results["classification"]["predictions"].items():
                report += f"{label}: {score:.4f}\n"
            report += "\n"
        
        # Summary
        if "summary" in results and "error" not in results["summary"]:
            report += "TEXT SUMMARIZATION\n"
            report += "-" * 17 + "\n"
            report += f"Method: {results['summary']['method']}\n"
            report += f"Compression Ratio: {results['summary']['compression_ratio']:.2%}\n\n"
            
            report += "Summary:\n"
            report += results['summary']['summary'] + "\n\n"
        
        # Key Findings
        if "key_findings" in results:
            report += "KEY FINDINGS\n"
            report += "-" * 12 + "\n"
            
            for category, findings in results["key_findings"].items():
                if findings:
                    report += f"{category.replace('_', ' ').title()}:\n"
                    for finding in findings:
                        report += f"- {finding}\n"
                    report += "\n"
        
        # Relations
        if "relations" in results and results["relations"]["relation_count"] > 0:
            report += "RELATION EXTRACTION\n"
            report += "-" * 19 + "\n"
            report += f"Total Relations: {results['relations']['relation_count']}\n\n"
            
            report += "Sample Relations:\n"
            for relation in results["relations"]["all_relations"][:15]:
                report += f"{relation['entity1']['text']} --[{relation['relation_type']}]--> {relation['entity2']['text']} (Confidence: {relation['confidence']:.2f})\n"
            
            if len(results["relations"]["all_relations"]) > 15:
                report += "...\n"
            report += "\n"
        
        return report
    
    def save_results(self, results: Dict[str, Any], output_path: str) -> bool:
        """
        Save analysis results to ``output_path`` as indented JSON (UTF-8).

        The results are serialized in memory before the file is opened. A
        value that ``json`` cannot encode therefore makes the call fail
        without creating a partial file or truncating an existing one.

        Args:
            results: Analysis results; must be JSON-serializable.
            output_path: Destination file path.

        Returns:
            bool: True if the file was written, False if serialization or
            writing failed (the error is logged).
        """
        try:
            # Serialize first: json.dump into an open file would leave a
            # truncated, invalid JSON file behind on a mid-stream TypeError.
            payload = json.dumps(results, indent=2)
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(payload)
        except Exception as e:
            logger.error(f"Error saving results: {e}")
            return False

        logger.info(f"Results saved to {output_path}")
        return True

modules\text_analysis\medical_ner.py

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

天天进步2015

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值