Skip to content

Feature Extraction

The feature extraction module extracts meaningful visual representations from images using pre-trained SimCLR models. These embeddings capture semantic information about mosquito breeding spots and can be used for downstream tasks like clustering and classification.

Overview

The feature extraction process:

  1. Model Loading: Loads a pre-trained or fine-tuned SimCLR model
  2. Data Processing: Applies standardized transforms to images
  3. Feature Extraction: Generates dense feature vectors (embeddings)
  4. Storage: Saves embeddings and metadata for further analysis

Quick Start

from prismh.core.extract_embeddings import extract_embeddings_main

# Extract embeddings using default configuration
extract_embeddings_main()

Command Line Usage

# Basic feature extraction
python -m prismh.core.extract_embeddings \
    --input_dir /path/to/clean/images \
    --output_dir /path/to/embeddings

# With specific model and device
python -m prismh.core.extract_embeddings \
    --input_dir results/clean \
    --output_dir results/embeddings \
    --model_path models/simclr_finetuned.pt \
    --device cuda \
    --batch_size 64

API Reference

SimCLRModel

Bases: Module

SimCLR model with encoder and projection head

Source code in src/prismh/core/extract_embeddings.py
class SimCLRModel(nn.Module):
    """ResNet backbone (``encoder``) combined with a SimCLR projection head.

    The backbone's ``fc`` layer is swapped for an identity, so ``encoder``
    returns backbone features of size ``encoder_dim`` (2048 for resnet50,
    512 for resnet18). The projection head maps them to ``output_dim``.
    """

    def __init__(self, base_model='resnet50', pretrained=True, output_dim=128):
        super().__init__()

        # The ``weights=`` keyword is the torchvision >= 0.13 API.
        if base_model == 'resnet18':
            backbone_weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.resnet18(weights=backbone_weights)
            feature_dim = 512
        elif base_model == 'resnet50':
            backbone_weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            backbone = models.resnet50(weights=backbone_weights)
            feature_dim = 2048
        else:
            raise ValueError(f"Unsupported base model: {base_model}")

        # Drop the classification layer so the encoder emits raw features.
        backbone.fc = nn.Identity()

        self.encoder = backbone
        self.encoder_dim = feature_dim
        self.projection_head = SimCLRProjectionHead(
            input_dim=feature_dim,
            output_dim=output_dim,
        )

    def forward(self, x):
        """Return ``(features, projections)``; embedding extraction uses only ``features``."""
        encoded = self.encoder(x)
        return encoded, self.projection_head(encoded)

PathBasedDataset

Bases: Dataset

Source code in src/prismh/core/extract_embeddings.py
class PathBasedDataset(Dataset):
    """Minimal dataset that loads images from a list of paths for extraction.

    Each item is ``(image, path)``. When an image cannot be opened, converted
    or transformed, a warning is printed and ``(None, path)`` is returned
    instead of raising. Consumers must filter out ``None`` images before
    batching them.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        try:
            # Close the underlying file as soon as the pixels are read;
            # convert() returns an independent in-memory image.
            with Image.open(img_path) as source:
                image = source.convert('RGB')
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort a
            # whole extraction run, so report it and hand back a placeholder.
            print(f"Warning: Error loading image {img_path}: {e}")
            return None, img_path
        return image, img_path

extract_embeddings_main

extract_embeddings_main()
Source code in src/prismh/core/extract_embeddings.py
def _collate_skip_failed(batch):
    """Collate ``(image, path)`` samples, dropping samples whose image is None.

    Returns ``(images, paths, failed_paths)``. ``images`` is a stacked tensor
    of the successfully loaded images, or None when every sample failed.
    """
    loaded = [(img, path) for img, path in batch if img is not None]
    failed_paths = [path for img, path in batch if img is None]
    if not loaded:
        return None, [], failed_paths
    images, paths = zip(*loaded)
    return torch.stack(images), list(paths), failed_paths


def extract_embeddings_main():
    """Extract encoder embeddings for the saved data splits and store them as .npz files.

    Reads ``simclr_finetuned/data_splits.pkl`` and a checkpoint from
    ``simclr_finetuned/checkpoints`` (``best_model.pt`` is preferred, otherwise
    the highest-numbered ``checkpoint_epoch_*.pt``). Writes
    ``train_embeddings.npz``, ``val_embeddings.npz``, ``test_embeddings.npz``
    and ``all_embeddings.npz`` into ``simclr_finetuned``; each archive holds an
    ``embeddings`` array and a matching ``file_paths`` array.

    Missing or unreadable splits/checkpoints are reported on stdout and the
    function returns early (None) without raising.
    """
    # --- Configuration ---
    output_dir = Path("simclr_finetuned")
    checkpoint_dir = output_dir / "checkpoints"
    splits_file = output_dir / "data_splits.pkl"
    batch_size = 64 # Can be larger for inference
    num_workers = 0 # Set to 0 for macOS compatibility if needed
    # --- End Configuration ---

    # Device configuration
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
    elif hasattr(torch, 'backends') and hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("Using Apple Silicon GPU (MPS)")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    # Load data splits
    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
    # use splits files produced by a trusted training run.
    try:
        with splits_file.open('rb') as f:
            splits = pickle.load(f)
        train_paths = splits['train']
        val_paths = splits['val']
        test_paths = splits['test']
        print(f"Loaded data splits from {splits_file}")
    except FileNotFoundError:
        print(f"Error: Data splits file not found at {splits_file}")
        print("Please ensure 'simclr.py' has been run successfully to generate the splits file.")
        return
    except Exception as e:
        print(f"Error loading splits file {splits_file}: {e}")
        return

    # Initialize model architecture
    # Set pretrained=False as we are loading specific weights
    model = SimCLRModel(base_model='resnet50', pretrained=False, output_dim=128)

    # Check if checkpoint directory exists
    if not checkpoint_dir.is_dir():
        print(f"Error: Checkpoint directory not found at {checkpoint_dir}")
        print("Please ensure 'simclr.py' has been run and checkpoints are saved.")
        return

    # Find the best checkpoint to load
    best_model_path = checkpoint_dir / 'best_model.pt'
    latest_epoch_checkpoint = None
    checkpoints = list(checkpoint_dir.glob('checkpoint_epoch_*.pt'))
    if checkpoints:
        # Sort by epoch number extracted from filename (Path.stem drops the suffix)
        try:
            checkpoints.sort(key=lambda p: int(p.stem.split('_')[-1]), reverse=True)
            latest_epoch_checkpoint = checkpoints[0]
        except (ValueError, IndexError):
            print("Warning: Could not parse epoch number from checkpoint filenames.")
            latest_epoch_checkpoint = None # Reset if parsing fails

    if best_model_path.exists():
        checkpoint_to_load = best_model_path
        print(f"Found best model checkpoint: {best_model_path}")
    elif latest_epoch_checkpoint:
        checkpoint_to_load = latest_epoch_checkpoint
        print(f"Using latest epoch checkpoint: {latest_epoch_checkpoint}")
    else:
        print(f"Error: No suitable checkpoint (.pt file starting with 'checkpoint_epoch_' or 'best_model.pt') found in {checkpoint_dir}")
        print("Please ensure 'simclr.py' has run and saved checkpoints.")
        return

    # Load the checkpoint
    # SECURITY NOTE: torch.load is pickle-based; only load trusted checkpoints.
    try:
        checkpoint = torch.load(checkpoint_to_load, map_location=device)
        # Ensure the checkpoint contains the model state dict
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Successfully loaded weights from {checkpoint_to_load}")
        else:
            print(f"Error: Checkpoint {checkpoint_to_load} does not contain 'model_state_dict'.")
            return
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_to_load}")
        return
    except Exception as e:
        print(f"Error loading checkpoint {checkpoint_to_load}: {e}")
        return

    model = model.to(device)
    model.eval() # Set to evaluation mode

    # Define evaluation transform (consistent preprocessing)
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def save_embeddings(embeddings_np, filepaths, output_filename):
        """Write one ``.npz`` archive with the embeddings and their file paths."""
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / output_filename
        # Store paths as a plain string array so loading needs no pickle support
        np.savez_compressed(
            output_path,
            embeddings=embeddings_np,
            file_paths=np.array([str(p) for p in filepaths], dtype='str'),
        )
        print(f"Saved {len(embeddings_np)} embeddings to {output_path}")

    def run_extraction(paths, output_filename):
        """Embed ``paths``, save them, and return ``(embeddings or None, embedded_paths)``."""
        dataset = PathBasedDataset(paths, transform=eval_transform)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=device.type == 'cuda', # pin_memory only works with CUDA
            drop_last=False,
            # The dataset yields None for unreadable images, which the default
            # collate function cannot batch; drop those samples here instead.
            collate_fn=_collate_skip_failed,
        )

        embeddings_list = []
        filepaths_list = []

        with torch.no_grad():
            for images, loaded_paths, failed_paths in tqdm(dataloader, desc=f"Extracting for {output_filename}"):
                if images is None: # Every image in this batch failed to load
                    print(f"Warning: Skipping batch, failed to load images: {failed_paths}")
                    continue

                # Get features from the encoder (projections are not needed)
                features, _ = model(images.to(device))

                embeddings_list.append(features.cpu().numpy())
                filepaths_list.extend(loaded_paths)

        if not embeddings_list:
            print(f"No embeddings extracted for {output_filename}. This might happen if all images in the split failed to load.")
            return None, []

        embeddings_np = np.vstack(embeddings_list)
        save_embeddings(embeddings_np, filepaths_list, output_filename)
        return embeddings_np, filepaths_list

    # Run extraction for all required splits
    print("Starting embedding extraction...")
    split_results = [
        run_extraction(train_paths, 'train_embeddings.npz'),
        run_extraction(val_paths, 'val_embeddings.npz'),
        run_extraction(test_paths, 'test_embeddings.npz'),
    ]

    # The combined file is train + val + test in that order; reuse the arrays
    # computed above rather than running every image through the model again.
    extracted = [(emb, split_paths) for emb, split_paths in split_results if emb is not None]
    if extracted:
        save_embeddings(
            np.vstack([emb for emb, _ in extracted]),
            [p for _, split_paths in extracted for p in split_paths],
            'all_embeddings.npz',
        )
    else:
        print("No embeddings extracted for all_embeddings.npz. This might happen if all images in the split failed to load.")

    print("Embedding extraction complete.")

Configuration

Model Configuration

| Parameter | Default | Description |
|---|---|---|
| model_path | Auto-detect | Path to SimCLR checkpoint |
| base_model | resnet50 | Backbone architecture |
| output_dim | 128 | Projection head output dimension |

Processing Configuration

| Parameter | Default | Description |
|---|---|---|
| batch_size | 64 | Batch size for inference |
| num_workers | 0 | DataLoader worker processes |
| device | Auto-detect | Device (cpu/cuda/mps) |

Data Configuration

| Parameter | Default | Description |
|---|---|---|
| input_dir | preprocess_results/clean | Clean images directory |
| output_dir | simclr_finetuned | Output directory |
| image_size | 224 | Input image size |

Output Format

Embeddings File

The extraction process generates all_embeddings.npz containing:

# Load embeddings. Both arrays are stored as plain numeric/string data, so
# pickle support (allow_pickle) is not needed; the context manager closes the archive.
with np.load('all_embeddings.npz') as data:
    embeddings = data['embeddings']      # Shape: (N, feature_dim)
    file_paths = data['file_paths']      # Shape: (N,) - corresponding file paths

File Structure

output_dir/
├── all_embeddings.npz              # Main embeddings file
├── train_embeddings.npz            # Training set embeddings
├── val_embeddings.npz              # Validation set embeddings
├── test_embeddings.npz             # Test set embeddings
└── extraction_metadata.json        # Extraction configuration (note: not written by the extract_embeddings_main source shown above)

Usage Examples

Basic Extraction

from prismh.core.extract_embeddings import extract_embeddings_main
from pathlib import Path
import numpy as np

def basic_extraction():
    """Run extraction with the default settings and print a short summary.

    Prints the number of embeddings and their dimension when
    ``simclr_finetuned/all_embeddings.npz`` exists, otherwise a hint to check
    the configuration.
    """
    # Run extraction with default settings
    extract_embeddings_main()

    # Load and examine results
    embeddings_file = Path("simclr_finetuned/all_embeddings.npz")
    if not embeddings_file.exists():
        print("No embeddings found. Check configuration.")
        return

    # The archive contains only plain arrays, so allow_pickle is unnecessary.
    # Read the array once and let the context manager close the file.
    with np.load(embeddings_file) as data:
        embeddings = data['embeddings']

    print(f"Extracted {len(embeddings)} embeddings")
    print(f"Feature dimension: {embeddings.shape[1]}")

basic_extraction()

Custom Model Path

from prismh.core.extract_embeddings import extract_embeddings_main
import os

def extract_with_custom_model():
    """Export a custom checkpoint path, then run the default extraction.

    NOTE(review): the ``extract_embeddings_main`` source shown on this page
    picks its checkpoint from ``simclr_finetuned/checkpoints`` and does not
    read ``SIMCLR_MODEL_PATH`` -- confirm this variable has any effect.
    """
    custom_checkpoint = 'models/custom_simclr.pt'
    os.environ['SIMCLR_MODEL_PATH'] = custom_checkpoint

    # Run the standard extraction entry point
    extract_embeddings_main()

    print("Extraction completed with custom model")

extract_with_custom_model()

Batch Processing with Custom Configuration

from prismh.core.extract_embeddings import SimCLRModel, PathBasedDataset
from torch.utils.data import DataLoader
import torch
import numpy as np
from pathlib import Path

def custom_extraction(image_dir, model_path, output_file, batch_size=32):
    """Extract encoder features for every ``*.jpg`` in a directory and save them.

    Args:
        image_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        model_path: Checkpoint containing a ``model_state_dict`` entry.
        output_file: Destination ``.npz`` file; it receives ``embeddings`` and
            ``file_paths`` arrays.
        batch_size: Number of images per inference batch.

    Raises:
        ValueError: If no image could be embedded (empty directory or every
            image failed to load).
    """
    # Local import keeps this example self-contained.
    from torchvision import transforms

    # Setup device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model
    # SECURITY NOTE: torch.load is pickle-based; only load trusted checkpoints.
    model = SimCLRModel(base_model='resnet50', pretrained=False)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    # Same evaluation preprocessing as extract_embeddings_main
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Prepare data. Paths are sorted because glob order is unspecified, and
    # converted to str because Path objects cannot be batched by the loader.
    image_paths = sorted(str(p) for p in Path(image_dir).glob("*.jpg"))
    dataset = PathBasedDataset(image_paths, transform=eval_transform)

    def collate_loaded(batch):
        # PathBasedDataset yields (None, path) for unreadable images; drop them.
        loaded = [(img, path) for img, path in batch if img is not None]
        if not loaded:
            return None, []
        images, paths = zip(*loaded)
        return torch.stack(images), list(paths)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_loaded,
    )

    # Extract embeddings
    all_embeddings = []
    all_paths = []

    with torch.no_grad():
        for batch_images, batch_paths in dataloader:
            if batch_images is None:  # every image in this batch failed to load
                continue
            features, _ = model(batch_images.to(device))
            all_embeddings.append(features.cpu().numpy())
            all_paths.extend(batch_paths)

    if not all_embeddings:
        raise ValueError(f"No images could be embedded from {image_dir}")

    # Save results
    embeddings = np.vstack(all_embeddings)
    np.savez_compressed(
        output_file,
        embeddings=embeddings,
        file_paths=np.array(all_paths),
    )

    print(f"Saved {len(embeddings)} embeddings to {output_file}")

# Usage
custom_extraction(
    image_dir="data/clean_images",
    model_path="models/simclr_best.pt",
    output_file="custom_embeddings.npz"
)

Performance Optimization

GPU Optimization

import torch

# Optimize for GPU
if torch.cuda.is_available():
    # Let cuDNN benchmark candidate convolution algorithms and reuse the
    # fastest one. This is a speed setting (most useful when input shapes
    # stay fixed), not a memory-saving one.
    torch.backends.cudnn.benchmark = True

    # Use larger batch sizes
    batch_size = 128

    # Enable pin memory
    pin_memory = True
else:
    # CPU fallback: smaller batches, no pinned memory
    batch_size = 32
    pin_memory = False

Memory Management

import gc
import torch

def memory_efficient_extraction():
    """Sketch of a memory-conscious extraction run.

    Shows the clean-up pattern only: the model and embeddings are
    placeholders to be replaced with the real objects.
    """

    # Clear GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Process in smaller batches
    batch_size = 32

    # Placeholders: replace with the loaded model and the extracted embeddings.
    # They must be bound before `del`, otherwise Python raises UnboundLocalError.
    model = None
    embeddings = None

    # Clear variables after use
    del model, embeddings
    gc.collect()

    # Release cached GPU blocks now that the references are gone
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

Parallel Processing

from concurrent.futures import ThreadPoolExecutor
import numpy as np

def parallel_extraction(image_dirs, output_dir):
    """Extract embeddings from multiple directories in parallel.

    Args:
        image_dirs: Iterable of image directories. Output names are derived
            from each directory's basename, so basenames must be unique.
        output_dir: Directory that receives ``<name>_embeddings.npz`` files;
            created if it does not exist.

    Returns:
        List of output file paths, in the order of ``image_dirs``.
    """
    # The imports of this example do not include pathlib, so import it here.
    from pathlib import Path

    output_root = Path(output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    def extract_single_dir(image_dir):
        dir_name = Path(image_dir).name
        output_file = output_root / f"{dir_name}_embeddings.npz"

        # Run extraction for this directory
        custom_extraction(image_dir, "models/simclr.pt", output_file)
        return output_file

    # Process directories in parallel
    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(extract_single_dir, dir_path)
                   for dir_path in image_dirs]

        # result() re-raises any exception from the worker
        results = [future.result() for future in futures]

    print(f"Completed parallel extraction: {results}")
    return results

Integration with Pipeline

After Preprocessing

from prismh.core.preprocess import ImagePreprocessor
from prismh.core.extract_embeddings import extract_embeddings_main

def preprocess_and_extract():
    """Run image preprocessing followed by embedding extraction.

    NOTE(review): ``CLEAN_IMAGES_DIR`` is exported before extraction, but the
    ``extract_embeddings_main`` source shown on this page takes its image list
    from ``data_splits.pkl`` and does not read that variable -- confirm that
    it has an effect.
    """
    import os

    # Stage 1: clean the raw images
    ImagePreprocessor(
        input_dir="raw_images",
        output_dir="processed",
    ).run_preprocessing()

    # Stage 2: point the extraction step at the cleaned images and run it
    os.environ['CLEAN_IMAGES_DIR'] = 'processed/clean'
    extract_embeddings_main()

    print("Preprocessing and feature extraction completed")

preprocess_and_extract()

Before Clustering

from prismh.core.extract_embeddings import extract_embeddings_main
from prismh.core.cluster_embeddings import cluster_main

def extract_and_cluster():
    """Run embedding extraction, then clustering, and report completion."""
    # Stage 1: embeddings
    extract_embeddings_main()
    # Stage 2: clustering
    cluster_main()

    print("Feature extraction and clustering completed")

extract_and_cluster()

Model Compatibility

Supported Architectures

| Model | Backbone | Feature Dim | Use Case |
|---|---|---|---|
| SimCLR-ResNet18 | ResNet-18 | 512 | Fast inference |
| SimCLR-ResNet50 | ResNet-50 | 2048 | Best performance |
| Custom SimCLR | Various | Configurable | Domain-specific |

Loading Different Models

# Load ImageNet pretrained
model = SimCLRModel(base_model='resnet50', pretrained=True)

# Load custom checkpoint (a dict with a 'model_state_dict' entry).
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('custom_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Load fine-tuned model (here the file holds the bare state dict)
model = SimCLRModel(base_model='resnet50', pretrained=False)
model.load_state_dict(torch.load('finetuned_simclr.pt', map_location='cpu'))

Quality Assessment

Embedding Quality Metrics

from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
import numpy as np

def assess_embedding_quality(embeddings_file, n_clusters=5, random_state=42):
    """Assess the quality of extracted embeddings.

    Clusters the embeddings with KMeans and reports the silhouette score
    together with basic norm statistics.

    Args:
        embeddings_file: ``.npz`` archive containing an ``embeddings`` array.
        n_clusters: Number of KMeans clusters used for the silhouette score.
        random_state: Seed passed to KMeans for reproducible clustering.

    Returns:
        Dict with ``silhouette_score``, ``mean_embedding_norm``,
        ``std_embedding_norm``, ``num_embeddings`` and ``embedding_dim``.

    Raises:
        ValueError: If there are not more embeddings than clusters, in which
            case a silhouette score cannot be computed.
    """
    # Read the array and close the archive right away
    with np.load(embeddings_file) as data:
        embeddings = data['embeddings']

    if len(embeddings) <= n_clusters:
        raise ValueError(
            f"Need more than {n_clusters} embeddings to assess quality, "
            f"got {len(embeddings)}"
        )

    # Clustering-based quality assessment
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    cluster_labels = kmeans.fit_predict(embeddings)

    # Silhouette score (higher is better)
    silhouette = silhouette_score(embeddings, cluster_labels)

    # Embedding statistics (norms computed once and reused)
    norms = np.linalg.norm(embeddings, axis=1)

    return {
        'silhouette_score': silhouette,
        'mean_embedding_norm': np.mean(norms),
        'std_embedding_norm': np.std(norms),
        'num_embeddings': len(embeddings),
        'embedding_dim': embeddings.shape[1],
    }

# Assess quality
quality = assess_embedding_quality('all_embeddings.npz')
print(f"Embedding quality metrics: {quality}")

Troubleshooting

Common Issues

Model not found:

# Check model path and ensure it exists
model_path = Path("models/simclr_model.pt")
if not model_path.exists():
    print(f"Model not found at {model_path}")
    print("Train a SimCLR model first or download a pretrained one")

Out of memory:

# Reduce batch size
batch_size = 16  # Instead of 64

# Clear GPU cache
torch.cuda.empty_cache()

# Use CPU if necessary
device = torch.device('cpu')

Inconsistent image sizes:

# Ensure all images are properly resized
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

Advanced Usage

Custom Feature Extractors

class CustomFeatureExtractor:
    """Reusable feature extractor built on a SimCLR checkpoint.

    Loads the model once and exposes ``extract_features`` to embed arbitrary
    lists of image paths with the encoder.
    """

    def __init__(self, model_path, device='auto'):
        self.device = self._setup_device(device)
        self.model = self._load_model(model_path)
        self.transform = self._get_transform()

    @staticmethod
    def _setup_device(device):
        """Resolve ``device``; ``'auto'`` prefers CUDA, then MPS, then CPU."""
        import torch

        if device != 'auto':
            return torch.device(device)
        if torch.cuda.is_available():
            return torch.device('cuda')
        mps = getattr(torch.backends, 'mps', None)
        if mps is not None and mps.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    def _load_model(self, model_path):
        """Load a resnet50 SimCLR checkpoint and put it in evaluation mode."""
        import torch
        from prismh.core.extract_embeddings import SimCLRModel

        model = SimCLRModel(base_model='resnet50', pretrained=False)
        # SECURITY NOTE: torch.load is pickle-based; only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model = model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _get_transform():
        """Return the evaluation preprocessing used by extract_embeddings_main."""
        from torchvision import transforms

        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

    def extract_features(self, image_paths, batch_size=32):
        """Extract features from a list of image paths.

        Returns an array with one row per path, in input order. Every image
        must load successfully: ``PathBasedDataset`` yields ``None`` for an
        unreadable image, which cannot be batched. Raises ``ValueError`` when
        ``image_paths`` is empty.
        """
        import numpy as np
        import torch
        from torch.utils.data import DataLoader
        from prismh.core.extract_embeddings import PathBasedDataset

        # str paths: Path objects cannot be batched by the default collate
        dataset = PathBasedDataset([str(p) for p in image_paths], self.transform)
        dataloader = DataLoader(dataset, batch_size=batch_size)

        features = []
        with torch.no_grad():
            # Each batch is (images, paths); only the images go to the model
            for images, _paths in dataloader:
                batch_features = self.model.encoder(images.to(self.device))
                features.append(batch_features.cpu().numpy())

        return np.vstack(features)

# Usage
extractor = CustomFeatureExtractor('models/custom_simclr.pt')
features = extractor.extract_features(image_paths)