Skip to content

SimCLR Module

The SimCLR module implements self-supervised contrastive learning for extracting meaningful visual representations from mosquito breeding spot images without requiring extensive labeled data.

Overview

SimCLR (Simple Contrastive Learning of Visual Representations) learns visual features by maximizing agreement between differently augmented views of the same image. This approach is particularly valuable for our use case where labeled data is limited or unreliable.

Key Components

  • SimCLRModel: Main model architecture with ResNet backbone
  • SimCLRDataset: Dataset class for contrastive learning
  • NTXentLoss: Normalized Temperature-scaled Cross Entropy loss
  • Training Pipeline: Complete training workflow with checkpointing

Quick Start

from prismh.models.simclr import SimCLRModel, NTXentLoss, train_simclr
import torch

# train_simclr moves each batch to `device` but never moves the model,
# so the model has to be placed there up front.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
model = SimCLRModel(base_model='resnet50', pretrained=True).to(device)

# Optimizer and contrastive loss; keep the loss's batch_size equal to the DataLoader batch size
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = NTXentLoss(temperature=0.5, batch_size=32)

# Train on your data
# (train_loader: a DataLoader over SimCLRDataset yielding (view1, view2) batches of 32 images)
train_simclr(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=100
)

Command Line Usage

# Train SimCLR model
python -m prismh.models.simclr \
    --data_dir /path/to/clean/images \
    --output_dir simclr_results \
    --epochs 100 \
    --batch_size 32

# Resume training from checkpoint
python -m prismh.models.simclr \
    --data_dir /path/to/clean/images \
    --output_dir simclr_results \
    --resume_from simclr_results/checkpoints/checkpoint_epoch_50.pt

API Reference

SimCLRModel

Bases: Module

SimCLR model with encoder and projection head

Source code in src/prismh/models/simclr.py
class SimCLRModel(nn.Module):
    """SimCLR network: a torchvision ResNet encoder followed by a projection head.

    Args:
        base_model: Backbone name, ``'resnet50'`` (2048-d features) or ``'resnet18'`` (512-d).
        pretrained: Forwarded unchanged to the torchvision constructor as ``pretrained=``.
        output_dim: Size of the projection-head output.

    Raises:
        ValueError: If ``base_model`` is not one of the supported names.
    """

    def __init__(self, base_model='resnet50', pretrained=True, output_dim=128):
        super(SimCLRModel, self).__init__()

        # name -> (torchvision constructor, width of the pooled feature vector)
        backbones = {
            'resnet50': (models.resnet50, 2048),
            'resnet18': (models.resnet18, 512),
        }
        if base_model not in backbones:
            raise ValueError(f"Unsupported base model: {base_model}")
        build_encoder, feature_dim = backbones[base_model]

        self.encoder = build_encoder(pretrained=pretrained)
        self.encoder_dim = feature_dim
        # Swap out the classifier so the encoder returns encoder_dim-sized features, not logits
        self.encoder.fc = nn.Identity()

        self.projection_head = SimCLRProjectionHead(
            input_dim=self.encoder_dim,
            output_dim=output_dim,
        )

    def forward(self, x):
        """Return ``(features, projections)``; projections are L2-normalized along dim 1."""
        features = self.encoder(x)
        return features, F.normalize(self.projection_head(features), dim=1)

SimCLRDataset

Bases: Dataset

Dataset for SimCLR that returns two augmented views of each image

Source code in src/prismh/models/simclr.py
class SimCLRDataset(Dataset):
    """Dataset for SimCLR that returns two augmented views of each image.

    Only regular files located directly in ``root_dir`` whose suffix is ``.png``,
    ``.jpg`` or ``.jpeg`` (case-insensitive) are used; subdirectories are skipped.
    Paths are sorted so that a given index always maps to the same image,
    independent of the filesystem's directory-listing order.
    """

    # Lower-case file suffixes accepted as images
    _IMAGE_SUFFIXES = frozenset({'.png', '.jpg', '.jpeg'})

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string or Path): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
                It is called twice per item, so a random transform yields two different views.
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        # sorted(): iterdir() order is filesystem-dependent, which made indices non-reproducible
        self.image_paths = sorted(
            entry for entry in self.root_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in self._IMAGE_SUFFIXES
        )

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(view1, view2)``; without a transform both are the same RGB PIL image."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() fully loads the pixels first
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            # Apply the same transform twice to get two different augmented views
            view1 = self.transform(image)
            view2 = self.transform(image)
            return view1, view2

        return image, image

__init__

__init__(root_dir, transform=None)

Parameters:

Name Type Description Default
root_dir string or Path

Directory with all the images.

required
transform callable

Optional transform to be applied on a sample.

None
Source code in src/prismh/models/simclr.py
def __init__(self, root_dir, transform=None):
    """
    Args:
        root_dir (string or Path): Directory with all the images.
        transform (callable, optional): Optional transform to be applied on a sample.
    """
    self.root_dir = Path(root_dir)
    self.transform = transform

    # Collect image paths directly inside root_dir; sorted so that indices are
    # reproducible regardless of the filesystem's listing order
    self.image_paths = sorted(
        entry for entry in self.root_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in ('.png', '.jpg', '.jpeg')
    )

SimCLRProjectionHead

Bases: Module

Projection head for SimCLR

Source code in src/prismh/models/simclr.py
class SimCLRProjectionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping encoder features into the contrastive space.

    Args:
        input_dim: Size of the incoming feature vectors.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the projected vectors.
    """

    def __init__(self, input_dim, hidden_dim=2048, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # One Sequential named `projection` keeps state_dict keys
        # (projection.0.*, projection.2.*) compatible with saved checkpoints.
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.projection(x)
        return projected

NTXentLoss

Bases: Module

Normalized Temperature-scaled Cross Entropy Loss from SimCLR paper

Source code in src/prismh/models/simclr.py
class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss from SimCLR paper.

    Args:
        temperature: Divisor applied to the cosine similarities before the softmax.
        batch_size: Expected number of samples per view. A negatives mask for this
            size is precomputed and registered as the ``mask`` buffer; batches of any
            other size (e.g. the last partial batch of a DataLoader with
            ``drop_last=False``) get a mask built on the fly.
    """
    def __init__(self, temperature=0.5, batch_size=32):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.batch_size = batch_size
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)
        # Mask that keeps only negatives (drops self-similarity and the positive pair)
        self.register_buffer("mask", self._build_mask(batch_size))

    @staticmethod
    def _build_mask(batch_size, device=None):
        """Return a (2N, 2N) bool mask, False on the diagonal and at (i, i+N) / (i+N, i)."""
        mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        idx = torch.arange(batch_size, device=device)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
        return mask

    def forward(self, z_i, z_j):
        """
        Calculate NT-Xent loss
        Args:
            z_i, z_j: Normalized projection vectors from the two augmented views,
                one row per sample; both must have the same number of rows.

        Returns:
            Scalar tensor: cross-entropy summed over all 2N anchors, divided by 2N.

        Raises:
            ValueError: If the two views have different batch sizes.
        """
        # Use the actual batch size so partial batches work, not the configured one
        batch_size = z_i.shape[0]
        if z_j.shape[0] != batch_size:
            raise ValueError(
                f"z_i and z_j must have the same batch size, got {batch_size} and {z_j.shape[0]}"
            )

        # Calculate cosine similarity between every pair of the 2N representations
        representations = torch.cat([z_i, z_j], dim=0)
        similarity_matrix = self.similarity_f(representations.unsqueeze(1), representations.unsqueeze(0)) / self.temperature

        # Positives sit on the +/-N off-diagonals: sample i in view 1 vs sample i in view 2
        sim_ij = torch.diag(similarity_matrix, batch_size)
        sim_ji = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        # Mask out the diagonal (self-similarity) and the positives; each row keeps 2N-2 negatives
        if batch_size == self.batch_size:
            mask = self.mask.to(similarity_matrix.device)
        else:
            mask = self._build_mask(batch_size, similarity_matrix.device)
        negatives = similarity_matrix[mask].reshape(2 * batch_size, -1)

        # Create labels - the positive is placed in column 0 of every row of logits
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=positives.device)

        # Calculate loss
        logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
        loss = self.criterion(logits, labels)
        return loss / (2 * batch_size)

forward

forward(z_i, z_j)

Calculate the NT-Xent loss. Args: z_i, z_j — normalized projection vectors from the two augmented views.

Source code in src/prismh/models/simclr.py
def forward(self, z_i, z_j):
    """
    Calculate NT-Xent loss
    Args:
        z_i, z_j: Normalized projection vectors from the two augmented views,
            one row per sample; both must have the same number of rows.

    Returns:
        Scalar tensor: cross-entropy summed over all 2N anchors, divided by 2N.
    """
    # Use the actual batch size so a partial final batch does not break the mask/diagonals
    batch_size = z_i.shape[0]

    # Calculate cosine similarity between every pair of the 2N representations
    representations = torch.cat([z_i, z_j], dim=0)
    similarity_matrix = self.similarity_f(representations.unsqueeze(1), representations.unsqueeze(0)) / self.temperature

    # Positives sit on the +/-N off-diagonals
    sim_ij = torch.diag(similarity_matrix, batch_size)
    sim_ji = torch.diag(similarity_matrix, -batch_size)
    positives = torch.cat([sim_ij, sim_ji], dim=0)

    # Mask out the diagonal (self-similarity) and the positives
    device = similarity_matrix.device
    if batch_size == self.batch_size:
        mask = self.mask.to(device)
    else:
        # Precomputed mask has the wrong size: rebuild it for this batch
        mask = ~torch.eye(2 * batch_size, dtype=torch.bool, device=device)
        idx = torch.arange(batch_size, device=device)
        mask[idx, idx + batch_size] = False
        mask[idx + batch_size, idx] = False
    negatives = similarity_matrix[mask].reshape(2 * batch_size, -1)

    # Create labels - positives are the "correct" predictions (column 0)
    labels = torch.zeros(2 * batch_size, dtype=torch.long, device=positives.device)

    # Calculate loss
    logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
    loss = self.criterion(logits, labels)
    loss = loss / (2 * batch_size)

    return loss

train_simclr

train_simclr(model, train_loader, optimizer, criterion, device, epochs=100, checkpoint_dir='checkpoints', resume_from=None, early_stopping_patience=10, validation_loader=None, writer=None)

Train the SimCLR model with proper analytics and early stopping

Source code in src/prismh/models/simclr.py
def _save_training_checkpoint(path, model, optimizer, epoch, loss, best_loss,
                              train_losses, val_losses, lr_history, global_step):
    """Save model/optimizer state plus training history to ``path``.

    Uses the same keys that ``train_simclr`` reads back through ``resume_from``.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'best_loss': best_loss,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'lr_history': lr_history,
        'global_step': global_step,
    }, path)


def _restore_best_model(model, best_model_path, device):
    """Load the weights saved at ``best_model_path`` into ``model``, if that file exists."""
    if best_model_path.exists():
        checkpoint = torch.load(best_model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Happens when best_loss came from a resumed checkpoint and this run never beat it
        print(f"Warning: {best_model_path} not found; keeping the current model weights")


def _to_display_image(image):
    """Turn one ImageNet-normalized CHW tensor into an HWC numpy array clipped to [0, 1]."""
    img = image.cpu().permute(1, 2, 0).numpy()
    # Denormalize image (mean/std of the documented augmentation pipeline)
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    return np.clip(img, 0, 1)


def _log_augmented_pairs(writer, views1, views2):
    """Log up to four images as (view 1 on top, view 2 below) columns in one figure."""
    fig = plt.figure(figsize=(12, 6))
    for i in range(min(4, len(views1))):
        for row, views in enumerate((views1, views2)):
            ax = fig.add_subplot(2, 4, i + 1 + 4 * row)
            ax.imshow(_to_display_image(views[i]))
            ax.set_title(f"View {row + 1} - img {i}")
            ax.axis('off')
    plt.tight_layout()
    writer.add_figure('Example Augmented Pairs', fig, global_step=0)


def _save_training_curves(checkpoint_dir, train_losses, val_losses, lr_history):
    """Write ``training_curves.png`` and ``training_history.csv`` into ``checkpoint_dir``."""
    fig = plt.figure(figsize=(15, 5))

    # Plot loss curves
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Training Loss')
    if val_losses:
        plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('SimCLR Training Loss')
    plt.legend()

    # Plot learning rate
    plt.subplot(1, 2, 2)
    plt.plot(lr_history)
    plt.xlabel('Epoch')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate Schedule')
    plt.yscale('log')

    plt.tight_layout()
    plt.savefig(checkpoint_dir / 'training_curves.png')
    # Release the figure; it was previously left open for the life of the process
    plt.close(fig)

    # Save training history as CSV for further analysis
    history = pd.DataFrame({
        'epoch': list(range(1, len(train_losses) + 1)),
        'train_loss': train_losses,
        'val_loss': val_losses if val_losses else [None] * len(train_losses),
        'learning_rate': lr_history
    })
    history.to_csv(checkpoint_dir / 'training_history.csv', index=False)


def train_simclr(model, train_loader, optimizer, criterion, device, epochs=100, checkpoint_dir="checkpoints", 
                 resume_from=None, early_stopping_patience=10, validation_loader=None, writer=None):
    """Train the SimCLR model with proper analytics and early stopping.

    Every batch from ``train_loader`` / ``validation_loader`` must be a pair
    ``(views1, views2)``. Both views go through ``model``, and the second element of
    its output is passed to ``criterion(z1, z2)``.

    Args:
        model: Module returning ``(features, projections)``. This function does not
            move it; it must already be on ``device``.
        train_loader: Loader of ``(views1, views2)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Contrastive loss called as ``criterion(z1, z2)``.
        device: Device each batch is moved to.
        epochs: Total epoch count; when resuming, training continues up to this number.
        checkpoint_dir: Directory receiving ``best_model.pt``, ``checkpoint_epoch_<N>.pt``
            (every 5 epochs), ``training_curves.png`` and ``training_history.csv``.
        resume_from: Optional checkpoint path restoring model, optimizer and history.
        early_stopping_patience: Epochs without improvement before training stops.
        validation_loader: Optional loader. When it is truthy (e.g. a non-empty
            DataLoader), its loss drives best-model selection, the LR scheduler and
            early stopping; otherwise the training loss is used.
        writer: Optional TensorBoard ``SummaryWriter`` for graphs, scalars and embeddings.

    Returns:
        ``(model, history)``. ``history`` maps ``'train_losses'``, ``'val_losses'`` and
        ``'lr_history'`` to per-epoch lists. If early stopping triggers, ``model`` holds
        the weights from ``best_model.pt`` (when that file exists).
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Initialize trackers
    best_loss = float('inf')
    best_model_path = checkpoint_dir / 'best_model.pt'
    patience_counter = 0
    start_epoch = 0
    train_losses = []
    val_losses = []
    lr_history = []
    # NOTE: counts logged batches (one per 10 batches), not optimizer steps
    global_step = 0

    # Resume from checkpoint if available
    if resume_from:
        resume_path = Path(resume_from)
        if resume_path.exists():
            print(f"Resuming training from checkpoint: {resume_path}")
            checkpoint = torch.load(resume_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch'] + 1
            if 'best_loss' in checkpoint:
                best_loss = checkpoint['best_loss']
            if 'train_losses' in checkpoint:
                train_losses = checkpoint['train_losses']
            if 'val_losses' in checkpoint:
                val_losses = checkpoint['val_losses']
            if 'lr_history' in checkpoint:
                lr_history = checkpoint['lr_history']
            if 'global_step' in checkpoint:
                global_step = checkpoint['global_step']
            print(f"Resuming from epoch {start_epoch}, best loss: {best_loss:.4f}, global step: {global_step}")
        else:
            # Previously ignored silently, which made a typo restart training from scratch unnoticed
            print(f"Warning: checkpoint {resume_path} not found; starting training from scratch")

    # Create a learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    # Log model graph and example augmented pairs to TensorBoard if available
    if writer is not None:
        # Take both views from the SAME batch so each column really is one image's two views
        views1, views2 = next(iter(train_loader))
        writer.add_graph(model, views1.to(device))
        _log_augmented_pairs(writer, views1, views2)

    # Function to evaluate on validation set
    def evaluate():
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for images1, images2 in validation_loader:
                images1, images2 = images1.to(device), images2.to(device)
                _, z1 = model(images1)
                _, z2 = model(images2)
                loss = criterion(z1, z2)
                total_val_loss += loss.item()
        return total_val_loss / len(validation_loader)

    # Function to log embeddings
    def log_embeddings(step):
        if writer is None:
            return

        # Without a validation set, fall back to training images (iterating None used to crash)
        loader = validation_loader if validation_loader is not None else train_loader

        # Extract embeddings for visualization
        model.eval()
        embeddings = []
        imgs = []
        with torch.no_grad():
            for i, (images1, _) in enumerate(loader):
                if i >= 2:  # Limit to a few batches for visualization
                    break
                images1 = images1.to(device)
                features, _ = model(images1)
                embeddings.append(features.cpu().numpy())
                imgs.append(images1.cpu())

        if not embeddings:
            return

        embeddings = np.vstack(embeddings)
        imgs = torch.cat(imgs, dim=0)

        # Use t-SNE for dimensionality reduction
        if len(embeddings) > 10:  # Need enough samples for meaningful t-SNE
            # scikit-learn requires perplexity < n_samples; its default is 30
            tsne = TSNE(n_components=2, random_state=42,
                        perplexity=min(30.0, len(embeddings) - 1))
            embeddings_2d = tsne.fit_transform(embeddings)

            # Plot t-SNE visualization
            fig = plt.figure(figsize=(10, 10))
            plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.5)
            plt.title('t-SNE of Feature Embeddings')
            writer.add_figure('Embeddings/t-SNE', fig, global_step=step)

        # Log embeddings with images
        writer.add_embedding(
            mat=torch.from_numpy(embeddings),
            label_img=imgs,
            global_step=step,
            tag='features' # Explicit tag for embeddings
        )

    print(f"Starting training from epoch {start_epoch+1}/{epochs}")
    for epoch in range(start_epoch, epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for i, (images1, images2) in enumerate(progress_bar):
            # Move images to device
            images1 = images1.to(device)
            images2 = images2.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass for both augmented views
            _, z1 = model(images1)
            _, z2 = model(images2)

            # Calculate loss
            loss = criterion(z1, z2)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # Update statistics
            running_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})

            # Log to TensorBoard (every 10 batches)
            if writer is not None and i % 10 == 0:
                writer.add_scalar('Batch/train_loss', loss.item(), global_step)
                global_step += 1

        # Record average loss for the epoch
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # Track current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        lr_history.append(current_lr)

        # Validation phase (if validation loader provided)
        val_loss = None
        if validation_loader:
            val_loss = evaluate()
            val_losses.append(val_loss)
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, LR: {current_lr:.6f}")

            # Log metrics to TensorBoard
            if writer is not None:
                writer.add_scalar('Epoch/train_loss', epoch_loss, epoch)
                writer.add_scalar('Epoch/val_loss', val_loss, epoch)
                writer.add_scalar('Epoch/learning_rate', current_lr, epoch)

                # Log embeddings periodically
                if epoch % 5 == 0 or epoch == epochs - 1:
                    log_embeddings(epoch)

            # Update scheduler based on validation loss
            scheduler.step(val_loss)

            # Early stopping check
            if val_loss < best_loss:
                best_loss = val_loss
                _save_training_checkpoint(best_model_path, model, optimizer, epoch, val_loss, best_loss,
                                          train_losses, val_losses, lr_history, global_step)
                patience_counter = 0
                print(f"New best model saved with validation loss: {best_loss:.4f}")
            else:
                patience_counter += 1
                print(f"Validation loss did not improve. Patience: {patience_counter}/{early_stopping_patience}")

                if patience_counter >= early_stopping_patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    _restore_best_model(model, best_model_path, device)
                    break
        else:
            # If no validation set, use training loss
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {epoch_loss:.4f}, LR: {current_lr:.6f}")

            # Log to TensorBoard
            if writer is not None:
                writer.add_scalar('Epoch/train_loss', epoch_loss, epoch)
                writer.add_scalar('Epoch/learning_rate', current_lr, epoch)

                # Log embeddings periodically
                if epoch % 5 == 0 or epoch == epochs - 1:
                    log_embeddings(epoch)

            # Save if better than best so far
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                _save_training_checkpoint(best_model_path, model, optimizer, epoch, epoch_loss, best_loss,
                                          train_losses, val_losses, lr_history, global_step)
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    _restore_best_model(model, best_model_path, device)
                    break

            # Update scheduler based on training loss
            scheduler.step(epoch_loss)

        # Regular checkpoint (every 5 epochs)
        if (epoch + 1) % 5 == 0:
            _save_training_checkpoint(
                checkpoint_dir / f'checkpoint_epoch_{epoch+1}.pt', model, optimizer, epoch,
                val_loss if validation_loader else epoch_loss, best_loss,
                train_losses, val_losses, lr_history, global_step,
            )

    _save_training_curves(checkpoint_dir, train_losses, val_losses, lr_history)

    print(f"Training completed. Best loss: {best_loss:.4f}")
    print(f"Training analytics saved to {checkpoint_dir}/")

    return model, {'train_losses': train_losses, 'val_losses': val_losses, 'lr_history': lr_history}

Architecture Details

Model Architecture

graph TD
    A[Input Images] --> B[Data Augmentation]
    B --> C[Two Augmented Views]
    C --> D[ResNet Encoder]
    D --> E[Feature Representations]
    E --> F[Projection Head]
    F --> G[Normalized Projections]
    G --> H[Contrastive Loss]

Backbone Options

Backbone Output Dim Parameters Use Case
ResNet-18 512 11.7M Fast experimentation
ResNet-50 2048 25.6M Recommended for production

Data Augmentation Pipeline

from torchvision import transforms

# Named simclr_transform so the `transforms` module itself is not shadowed
# (later snippets still call transforms.RandomApply / transforms.ColorJitter).
simclr_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    # ImageNet statistics, matching the pretrained ResNet backbones
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

Training Configuration

Hyperparameters

Parameter Default Description
batch_size 32 Batch size for training
learning_rate 0.001 Initial learning rate
temperature 0.5 Temperature for contrastive loss
epochs 100 Number of training epochs
weight_decay 1e-4 L2 regularization

Training Features

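  • Automatic Mixed Precision: Faster training with reduced memory (requires a custom training loop, see Mixed Precision Training below; train_simclr itself runs in full precision)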
  • Early Stopping: Prevents overfitting
  • Learning Rate Scheduling: Adaptive learning rate adjustment
  • Checkpointing: Resume training from interruptions
  • TensorBoard Logging: Training metrics visualization

Usage Examples

Basic Training

from prismh.models.simclr import SimCLRModel, SimCLRDataset, NTXentLoss, train_simclr
from torch.utils.data import DataLoader
import torch

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One batch size shared by the DataLoader and the loss
batch_size = 32

# Create dataset and dataloader.
# get_simclr_transforms() stands for your augmentation pipeline (see "Data Augmentation
# Pipeline"); it is not imported above, so define or import it before running this.
dataset = SimCLRDataset(
    root_dir='data/clean_images',
    transform=get_simclr_transforms()
)
# drop_last=True keeps every batch at exactly batch_size, the size NTXentLoss is configured for
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Initialize model
model = SimCLRModel(base_model='resnet50').to(device)

# Setup training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = NTXentLoss(temperature=0.5, batch_size=batch_size)

# Train model
train_simclr(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=100
)

Advanced Training with Validation

from torch.utils.tensorboard import SummaryWriter

# Setup validation; drop_last=True keeps every validation batch the size NTXentLoss expects
val_dataset = SimCLRDataset('data/val_images', transform=get_simclr_transforms())
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, drop_last=True)

# Setup logging
writer = SummaryWriter('runs/simclr_experiment')

# Train with validation
train_simclr(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=100,
    validation_loader=val_loader,
    writer=writer,
    early_stopping_patience=10
)

# Flush pending TensorBoard events to disk
writer.close()

Fine-tuning Pretrained Model

# Build the architecture; pretrained=False because the checkpoint below overwrites every weight
model = SimCLRModel(base_model='resnet50', pretrained=False)

# Load domain-specific checkpoint
checkpoint = torch.load('pretrained_simclr.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# Recreate the optimizer for THIS model: an optimizer built for a previous model
# instance would keep updating that model's parameters, not these
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Fine-tune on your data
train_simclr(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=50,  # Fewer epochs for fine-tuning
    checkpoint_dir='finetuned_checkpoints'
)

Feature Extraction

After training, use the encoder for downstream tasks:

# Load trained model. pretrained=False: the checkpoint replaces all weights anyway.
model = SimCLRModel(base_model='resnet50', pretrained=False)
# train_simclr writes best_model.pt into its checkpoint_dir ('checkpoints' by default);
# map_location='cpu' lets a GPU-trained checkpoint load on any machine
checkpoint = torch.load('checkpoints/best_model.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Extract features (images: a normalized image batch, e.g. one batch from a DataLoader)
model.eval()
with torch.no_grad():
    features, _ = model(images)
    # Use features for classification, clustering, etc.

Performance Optimization

Multi-GPU Training

# Replicate the model across all visible GPUs when there is more than one.
# Afterwards the wrapped model is reached through model.module (e.g. model.module.encoder),
# and its state_dict keys carry a 'module.' prefix.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    model = torch.nn.DataParallel(model)

Memory Optimization

# Gradient accumulation: average gradients over several small batches per optimizer step.
# NOTE: this lowers memory use but does NOT add negatives to NT-Xent; each loss term
# still only contrasts samples inside its own small batch.
accumulation_steps = 4
optimizer.zero_grad()  # start from clean gradients
for i, (images1, images2) in enumerate(train_loader):
    # Forward pass; scale so the accumulated gradient is an average, not a sum
    loss = compute_loss(images1, images2) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

Mixed Precision Training

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for images1, images2 in train_loader:
    # Clear the previous step's gradients; without this they accumulate every iteration
    optimizer.zero_grad()

    with autocast():
        loss = compute_loss(images1, images2)

    # Scale the loss to avoid fp16 gradient underflow; scaler.step unscales before updating
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Monitoring and Visualization

TensorBoard Integration

# Per-epoch scalars: current loss and the learning rate of the first parameter group
current_lr = optimizer.param_groups[0]['lr']
writer.add_scalar('Loss/Train', loss.item(), epoch)
writer.add_scalar('Learning_Rate', current_lr, epoch)

# Feature vectors (one row per sample) with matching metadata labels for the Projector tab
writer.add_embedding(features, metadata=labels, tag='SimCLR_Features')

t-SNE Visualization

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Extract features
features = extract_features(model, dataloader)

# Reduce dimensionality; perplexity must be smaller than the number of samples (default 30)
tsne = TSNE(n_components=2, perplexity=min(30, len(features) - 1), random_state=42)
features_2d = tsne.fit_transform(features)

# Plot
plt.scatter(features_2d[:, 0], features_2d[:, 1])
plt.title('SimCLR Features t-SNE')
plt.show()

Model Evaluation

Downstream Task Performance

# Evaluate on classification task (linear probe on frozen SimCLR features)
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Extract features for the labeled training split (labels: one class id per row)
features = extract_features(model, dataloader)

# Train classifier; raise max_iter above the default 100 so lbfgs can converge on
# high-dimensional (e.g. 2048-d) features
classifier = LogisticRegression(max_iter=1000)
classifier.fit(features, labels)

# Evaluate on a held-out split extracted the same way (test_features / test_labels)
predictions = classifier.predict(test_features)
accuracy = accuracy_score(test_labels, predictions)
print(f'Classification accuracy: {accuracy:.3f}')

Clustering Quality

from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Cluster features; explicit n_init and random_state make results reproducible
# and independent of the n_init default, which changed across scikit-learn versions
kmeans = KMeans(n_clusters=5, n_init=10, random_state=42)
cluster_labels = kmeans.fit_predict(features)

# Evaluate clustering (silhouette ranges from -1 to 1; higher is better separated)
silhouette = silhouette_score(features, cluster_labels)
print(f'Silhouette score: {silhouette:.3f}')

Troubleshooting

Common Issues

Out of memory errors:

# Reduce batch size
batch_size = 16  # Instead of 32
# NTXentLoss is configured with a batch size: keep it equal to the DataLoader's
criterion = NTXentLoss(temperature=0.5, batch_size=batch_size)

# Use gradient accumulation (saves memory, but adds no extra negatives to the contrastive loss)
accumulation_steps = 2

Slow convergence:

# Increase learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

# Adjust temperature; pass the loader's batch size so the loss matches the batches it receives
criterion = NTXentLoss(temperature=0.3, batch_size=train_loader.batch_size)

Poor feature quality:

# Stronger color augmentation: use this in place of ColorJitter(0.4, 0.4, 0.4, 0.1) in the pipeline
strong_color_jitter = transforms.RandomApply(
    [transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)],
    p=0.8,
)

# Train for more epochs
epochs = 200

Integration with Pipeline

With Preprocessing

from prismh.core.preprocess import ImagePreprocessor

# Preprocess first
preprocessor = ImagePreprocessor(input_dir='raw/', output_dir='processed/')
preprocessor.run_preprocessing()

# Then train SimCLR on clean images. Pass a transform: without one, SimCLRDataset
# returns raw PIL images, which the default DataLoader collate cannot batch.
dataset = SimCLRDataset('processed/clean', transform=get_simclr_transforms())
# ... training code

With Classification

# After SimCLR training, reuse the learned encoder as the classification backbone
from prismh.models.classify import train_classifier

# Only the encoder is handed over; the projection head is not used downstream
backbone = model.encoder
train_classifier(backbone, labeled_data)