SimCLR Module¶
The SimCLR module implements self-supervised contrastive learning for extracting meaningful visual representations from mosquito breeding spot images without requiring extensive labeled data.
Overview¶
SimCLR (a Simple framework for Contrastive Learning of visual Representations) learns visual features by maximizing agreement between differently augmented views of the same image. This approach is particularly valuable for our use case, where labeled data is limited or unreliable.
Key Components¶
- SimCLRModel: Main model architecture with ResNet backbone
- SimCLRDataset: Dataset class for contrastive learning
- NTXentLoss: Normalized Temperature-scaled Cross Entropy loss
- Training Pipeline: Complete training workflow with checkpointing
Quick Start¶
from prismh.models.simclr import SimCLRModel, NTXentLoss, train_simclr
import torch

# Initialize model
model = SimCLRModel(base_model='resnet50', pretrained=True)

# Set up training components (train_loader is a DataLoader yielding two augmented views per image)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = NTXentLoss(temperature=0.5, batch_size=32)

# Train on your data
train_simclr(
    model=model,
    train_loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    epochs=100
)
Command Line Usage¶
# Train SimCLR model
python -m prismh.models.simclr \
--data_dir /path/to/clean/images \
--output_dir simclr_results \
--epochs 100 \
--batch_size 32
# Resume training from checkpoint
python -m prismh.models.simclr \
--data_dir /path/to/clean/images \
--output_dir simclr_results \
--resume_from simclr_results/checkpoints/checkpoint_epoch_50.pt
API Reference¶
SimCLRModel ¶
Bases: Module
SimCLR model with encoder and projection head
Source code in src/prismh/models/simclr.py
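A minimal usage sketch, assuming (as in the feature-extraction example further down this page) that the forward pass returns the encoder features together with the projection-head output:

import torch
from prismh.models.simclr import SimCLRModel

model = SimCLRModel(base_model='resnet50', pretrained=True)
images = torch.randn(4, 3, 224, 224)   # dummy batch of RGB images
features, projections = model(images)  # encoder features and projected embeddings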
SimCLRDataset ¶
Bases: Dataset
Dataset for SimCLR that returns two augmented views of each image
Source code in src/prismh/models/simclr.py
__init__ ¶
Parameters:
Name | Type | Description | Default |
---|---|---|---|
root_dir | string or Path | Directory with all the images. | required |
transform | callable | Optional transform to be applied on a sample. | None |
Source code in src/prismh/models/simclr.py
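Each item yields two independently augmented views of the same image, matching the (images1, images2) pairs consumed by the training loops later on this page. A usage sketch, assuming the get_simclr_transforms() helper discussed in the Data Augmentation Pipeline section:

from torch.utils.data import DataLoader
from prismh.models.simclr import SimCLRDataset

dataset = SimCLRDataset(root_dir='data/clean_images', transform=get_simclr_transforms())
view1, view2 = dataset[0]  # two augmented views of the first image
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)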
SimCLRProjectionHead ¶
Bases: Module
Projection head for SimCLR
Source code in src/prismh/models/simclr.py
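The source is not reproduced here; in the original SimCLR paper the projection head is a small two-layer MLP, so a representative sketch (layer sizes are assumptions, not necessarily this module's exact configuration) looks like:

import torch.nn as nn

class ProjectionHeadSketch(nn.Module):
    """Illustrative two-layer MLP projection head (hypothetical dimensions)."""
    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)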
NTXentLoss ¶
Bases: Module
Normalized Temperature-scaled Cross Entropy Loss from SimCLR paper
Source code in src/prismh/models/simclr.py
forward ¶
Calculate NT-Xent loss.
Args: z_i, z_j: Normalized projection vectors from the two augmented views.
Source code in src/prismh/models/simclr.py
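For reference, the loss treats the two views of each image as the only positive pair among the 2N samples in a batch. An illustrative computation (a sketch of the standard NT-Xent formulation, not this module's exact implementation):

import torch
import torch.nn.functional as F

def nt_xent_sketch(z_i, z_j, temperature=0.5):
    """Standard NT-Xent on two batches of normalized projections, each of shape (N, d)."""
    n = z_i.size(0)
    z = torch.cat([z_i, z_j], dim=0)  # (2N, d)
    # Pairwise cosine similarities, scaled by temperature
    sim = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=2) / temperature
    # Exclude self-similarity from the softmax denominator
    sim.masked_fill_(torch.eye(2 * n, dtype=torch.bool, device=z.device), float('-inf'))
    # The positive for sample i is its other view at index i + N (and vice versa)
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)

With the documented class this corresponds to criterion = NTXentLoss(temperature=0.5, batch_size=32) followed by criterion(z_i, z_j).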
train_simclr ¶
train_simclr(model, train_loader, optimizer, criterion, device, epochs=100, checkpoint_dir='checkpoints', resume_from=None, early_stopping_patience=10, validation_loader=None, writer=None)
Train the SimCLR model with proper analytics and early stopping
Source code in src/prismh/models/simclr.py
Architecture Details¶
Model Architecture¶
graph TD
A[Input Images] --> B[Data Augmentation]
B --> C[Two Augmented Views]
C --> D[ResNet Encoder]
D --> E[Feature Representations]
E --> F[Projection Head]
F --> G[Normalized Projections]
G --> H[Contrastive Loss]
Backbone Options¶
Backbone | Output Dim | Parameters | Use Case |
---|---|---|---|
ResNet-18 | 512 | 11.7M | Fast experimentation |
ResNet-50 | 2048 | 25.6M | Recommended for production |
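The backbone is selected through the base_model argument shown in the Quick Start (the ResNet-18 string is assumed to follow the same torchvision-style naming):

model_fast = SimCLRModel(base_model='resnet18')   # fast experimentation
model_prod = SimCLRModel(base_model='resnet50')   # recommended for production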
Data Augmentation Pipeline¶
from torchvision import transforms

simclr_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
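The training examples below pass this pipeline through a get_simclr_transforms() helper. If your copy of the module does not export one, a minimal stand-in wrapping the pipeline above would be:

def get_simclr_transforms():
    """Return the contrastive augmentation pipeline defined above (assumed helper name)."""
    return simclr_transforms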
Training Configuration¶
Hyperparameters¶
Parameter | Default | Description |
---|---|---|
batch_size | 32 | Batch size for training |
learning_rate | 0.001 | Initial learning rate |
temperature | 0.5 | Temperature for contrastive loss |
epochs | 100 | Number of training epochs |
weight_decay | 1e-4 | L2 regularization |
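Of these, weight_decay is the only value not shown elsewhere on this page; it is applied through the optimizer. A sketch of overriding the defaults in code:

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = NTXentLoss(temperature=0.5, batch_size=32)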
Training Features¶
- Automatic Mixed Precision: Faster training with reduced memory
- Early Stopping: Prevents overfitting
- Learning Rate Scheduling: Adaptive learning rate adjustment (see the sketch after this list)
- Checkpointing: Resume training from interruptions
- TensorBoard Logging: Training metrics visualization
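As an example of the scheduling feature, a cosine-annealing schedule of the kind commonly paired with SimCLR can be set up as follows (a sketch; the built-in pipeline may configure this differently):

from torch.optim.lr_scheduler import CosineAnnealingLR

scheduler = CosineAnnealingLR(optimizer, T_max=100)  # anneal over 100 epochs

for epoch in range(100):
    # ... run one training epoch ...
    scheduler.step()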
Usage Examples¶
Basic Training¶
from prismh.models.simclr import SimCLRModel, SimCLRDataset, NTXentLoss, train_simclr
from torch.utils.data import DataLoader
import torch
# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create dataset and dataloader
dataset = SimCLRDataset(
root_dir='data/clean_images',
transform=get_simclr_transforms()
)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# Initialize model
model = SimCLRModel(base_model='resnet50').to(device)
# Setup training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = NTXentLoss(temperature=0.5, batch_size=32)
# Train model
train_simclr(
model=model,
train_loader=train_loader,
optimizer=optimizer,
criterion=criterion,
device=device,
epochs=100
)
Advanced Training with Validation¶
from torch.utils.tensorboard import SummaryWriter
# Setup validation
val_dataset = SimCLRDataset('data/val_images', transform=get_simclr_transforms())
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# Setup logging
writer = SummaryWriter('runs/simclr_experiment')
# Train with validation
train_simclr(
model=model,
train_loader=train_loader,
optimizer=optimizer,
criterion=criterion,
device=device,
epochs=100,
validation_loader=val_loader,
writer=writer,
early_stopping_patience=10
)
Fine-tuning Pretrained Model¶
# Load pretrained SimCLR model
model = SimCLRModel(base_model='resnet50', pretrained=True)
# Load domain-specific checkpoint
checkpoint = torch.load('pretrained_simclr.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# Fine-tune on your data
train_simclr(
model=model,
train_loader=train_loader,
optimizer=optimizer,
criterion=criterion,
device=device,
epochs=50, # Fewer epochs for fine-tuning
checkpoint_dir='finetuned_checkpoints'
)
Feature Extraction¶
After training, use the encoder for downstream tasks:
# Load trained model
model = SimCLRModel(base_model='resnet50')
checkpoint = torch.load('best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
# Extract features
model.eval()
with torch.no_grad():
    features, _ = model(images)  # images: a batch tensor of shape (N, 3, 224, 224)
# Use features for classification, clustering, etc.
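Later sections call an extract_features() helper to pull embeddings for an entire DataLoader; it is not documented on this page, so here is a minimal sketch under that assumption:

import torch

def extract_features(model, dataloader, device='cpu'):
    """Collect encoder features for every batch in a DataLoader (assumed helper)."""
    model.eval()
    outputs = []
    with torch.no_grad():
        for batch in dataloader:
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            features, _ = model(images.to(device))
            outputs.append(features.cpu())
    return torch.cat(outputs).numpy()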
Performance Optimization¶
Multi-GPU Training¶
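The CLI does not expose a multi-GPU flag on this page. For a quick single-node setup, wrapping the model in torch.nn.DataParallel is the simplest option (a sketch; torch.nn.parallel.DistributedDataParallel scales better for serious runs):

import torch

model = SimCLRModel(base_model='resnet50')
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)  # replicate the model across visible GPUs
model = model.to('cuda')
# Contrastive learning benefits from large batches, so scale batch_size
# with the number of GPUs where memory allows.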
Memory Optimization¶
# Gradient accumulation for a larger effective batch size
accumulation_steps = 4

for i, (images1, images2) in enumerate(train_loader):
    # Forward pass (compute_loss stands in for the SimCLR forward + NT-Xent computation)
    loss = compute_loss(images1, images2) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
Mixed Precision Training¶
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for images1, images2 in train_loader:
    optimizer.zero_grad()
    with autocast():
        loss = compute_loss(images1, images2)  # placeholder for forward + NT-Xent loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
Monitoring and Visualization¶
TensorBoard Integration¶
# Log training metrics
writer.add_scalar('Loss/Train', loss.item(), epoch)
writer.add_scalar('Learning_Rate', optimizer.param_groups[0]['lr'], epoch)
# Log embeddings
writer.add_embedding(features, metadata=labels, tag='SimCLR_Features')
t-SNE Visualization¶
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Extract features
features = extract_features(model, dataloader)
# Reduce dimensionality
tsne = TSNE(n_components=2, random_state=42)
features_2d = tsne.fit_transform(features)
# Plot
plt.scatter(features_2d[:, 0], features_2d[:, 1])
plt.title('SimCLR Features t-SNE')
plt.show()
Model Evaluation¶
Downstream Task Performance¶
# Evaluate on classification task
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# Extract features
features = extract_features(model, dataloader)
# Train a linear probe on frozen SimCLR features
classifier = LogisticRegression(max_iter=1000)
classifier.fit(features, labels)

# Evaluate on a held-out split (test_features/test_labels)
predictions = classifier.predict(test_features)
accuracy = accuracy_score(test_labels, predictions)
print(f'Classification accuracy: {accuracy:.3f}')
Clustering Quality¶
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Cluster features
kmeans = KMeans(n_clusters=5)
cluster_labels = kmeans.fit_predict(features)
# Evaluate clustering
silhouette = silhouette_score(features, cluster_labels)
print(f'Silhouette score: {silhouette:.3f}')
Troubleshooting¶
Common Issues¶
Out of memory errors:
# Reduce batch size
batch_size = 16 # Instead of 32
# Use gradient accumulation
accumulation_steps = 2
Slow convergence:
# Increase learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
# Adjust temperature
criterion = NTXentLoss(temperature=0.3, batch_size=32)
Poor feature quality:
# Increase augmentation strength
transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8)
# Train for more epochs
epochs = 200
Integration with Pipeline¶
With Preprocessing¶
from prismh.core.preprocess import ImagePreprocessor
# Preprocess first
preprocessor = ImagePreprocessor(input_dir='raw/', output_dir='processed/')
preprocessor.run_preprocessing()
# Then train SimCLR on clean images
dataset = SimCLRDataset('processed/clean', transform=get_simclr_transforms())
# ... training code
With Classification¶
# After SimCLR training, use for classification
from prismh.models.classify import train_classifier
# Use SimCLR encoder as backbone
encoder = model.encoder
train_classifier(encoder, labeled_data)
Related Documentation¶
- Feature Extraction - Extract embeddings using trained SimCLR
- Classification - Downstream classification tasks
- Clustering - Unsupervised analysis
- Configuration - Training parameter tuning