Self-Supervised & Contrastive Learning
Self-Supervised Learning: Unlocking the Potential of Unlabeled Data
With the vast majority of the world's data being unlabeled, self-supervised learning has emerged as a powerful paradigm for learning meaningful representations without manual annotations. In this comprehensive guide, we'll explore contrastive learning methods like SimCLR, MoCo, and BYOL that are closing the gap with supervised learning on many tasks.
1. The Self-Supervised Learning Paradigm
Self-supervised learning creates supervisory signals from the data itself through:
| Approach | Method | Example |
|---|---|---|
| Pretext Tasks | Define artificial tasks | Image rotation prediction |
| Contrastive Learning | Compare similar/dissimilar pairs | SimCLR, MoCo |
| Generative Methods | Reconstruct input | Autoencoders, BERT |
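As a concrete example of the pretext-task row above, rotation prediction turns unlabeled images into a four-way classification problem: rotate each image by a multiple of 90° and train a network to predict which rotation was applied. A minimal sketch (the tiny flatten-and-linear backbone and 32×32 images are illustrative assumptions, not a recommended architecture):

```python
import torch
import torch.nn as nn

def make_rotation_batch(images):
    """Rotate each image by 0/90/180/270 degrees; the rotation index is the label."""
    views, labels = [], []
    for k in range(4):
        views.append(torch.rot90(images, k, dims=(2, 3)))
        labels.append(torch.full((images.size(0),), k, dtype=torch.long))
    return torch.cat(views), torch.cat(labels)

# Toy backbone + 4-way rotation head, for illustration only
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
head = nn.Linear(64, 4)

x = torch.randn(8, 3, 32, 32)            # a batch of unlabeled images
views, labels = make_rotation_batch(x)   # (32, 3, 32, 32) and (32,)
loss = nn.functional.cross_entropy(head(backbone(views)), labels)
```

No human labels were needed: the supervisory signal (the rotation index) was manufactured from the data itself.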
2. Contrastive Learning Framework
Contrastive methods learn by pulling positive pairs together and pushing negatives apart in embedding space:
Key Components
- Augmentations: Create positive pairs (two views of same image)
- Encoder: Maps inputs to embeddings (typically CNN)
- Projection Head: Small MLP for contrastive loss
- Loss Function: NT-Xent (normalized temperature-scaled cross entropy)
The NT-Xent loss for a positive pair (i,j):
ℓ(i,j) = −log( exp(sim(z_i, z_j)/τ) / Σ_{k≠i} exp(sim(z_i, z_k)/τ) )

where sim(u, v) = uᵀv / (‖u‖‖v‖) is cosine similarity and τ is a temperature hyperparameter.
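To make the formula concrete, here is a direct computation of ℓ(i, j) for one positive pair among three toy embeddings (the embedding values and τ = 0.5 are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn.functional as F

tau = 0.5
# Toy unit embeddings: rows 0 and 1 are a positive pair, row 2 is a negative
z = F.normalize(torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]), dim=1)
sim = z @ z.T / tau                     # cosine similarities, temperature-scaled

i, j = 0, 1
num = torch.exp(sim[i, j])                             # exp(sim(z_i, z_j)/tau)
den = torch.exp(sim[i]).sum() - torch.exp(sim[i, i])   # sum over all k != i
loss_ij = -torch.log(num / den)
```

Because the denominator always contains the numerator plus at least one negative term, the loss is strictly positive and shrinks as the positive pair's similarity dominates the negatives'.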
3. SimCLR: A Simple Framework
SimCLR established several best practices:
| Component | Implementation | Impact |
|---|---|---|
| Augmentations | Random crop + color jitter + blur | Creates meaningful positives |
| Projection Head | 2-layer MLP with ReLU | Improves representation quality |
| Large Batch Size | 4096+ with LARS optimizer | Provides many negatives |
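In practice the augmentation pipeline comes from torchvision (RandomResizedCrop, ColorJitter, GaussianBlur). As a dependency-free sketch of the idea, here are simplified tensor-level stand-ins (random flip, brightness jitter, and additive noise standing in for blur) that still yield two distinct views of one image:

```python
import torch

def augment(img):
    """Simplified stand-ins for SimCLR's augmentations, operating on tensors."""
    if torch.rand(1).item() < 0.5:
        img = torch.flip(img, dims=(2,))         # random horizontal flip
    img = img * (0.8 + 0.4 * torch.rand(1))      # random brightness jitter
    return img + 0.01 * torch.randn_like(img)    # mild noise, standing in for blur

def two_views(img):
    # Two independent augmentations of the same image form a positive pair
    return augment(img), augment(img)

x = torch.rand(1, 3, 32, 32)
v1, v2 = two_views(x)
```

The key property is that both views keep the image's identity while differing in appearance, so the encoder is forced to learn invariances rather than memorize pixels.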
```python
import torch
import torch.nn as nn

class SimCLR(nn.Module):
    def __init__(self, encoder, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        # 2-layer MLP projection head (discarded after pre-training)
        self.projector = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, projection_dim)
        )

    def forward(self, x1, x2):
        # Get representations
        h1 = self.encoder(x1)
        h2 = self.encoder(x2)
        # Project to latent space
        z1 = self.projector(h1)
        z2 = self.projector(h2)
        return h1, h2, z1, z2
```
```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z1, z2, temperature=0.5):
    batch_size = z1.size(0)
    # Concatenate all embeddings and normalize so dot products are cosine similarities
    all_z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)
    # Compute similarity matrix
    sim_matrix = torch.matmul(all_z, all_z.T) / temperature
    # Mask self-similarities so a sample is never its own candidate
    mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z1.device)
    sim_matrix = sim_matrix.masked_fill(mask, float('-inf'))
    # The positive for index i is its other view: i + batch_size (mod 2N)
    labels = torch.arange(batch_size, device=z1.device)
    labels = torch.cat([labels + batch_size, labels])
    # Cross-entropy over the 2N-1 remaining candidates per anchor
    loss = F.cross_entropy(sim_matrix, labels)
    return loss
```
4. Momentum Contrast (MoCo)
MoCo addresses the batch size limitation with:
Key Innovations
- Momentum Encoder: Slowly updated version of main encoder
- Dynamic Queue: Maintains large set of negatives
- Shuffling BN: Prevents information leakage
```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoCo(nn.Module):
    def __init__(self, encoder, dim=128, K=65536, m=0.999, T=0.2):
        super().__init__()
        self.K = K  # Queue size
        self.m = m  # Momentum
        self.T = T  # Temperature
        # Encoders
        self.encoder_q = encoder                 # Query encoder
        self.encoder_k = copy.deepcopy(encoder)  # Key encoder
        # Projection heads
        self.projector_q = nn.Sequential(
            nn.Linear(encoder.output_dim, encoder.output_dim),
            nn.ReLU(),
            nn.Linear(encoder.output_dim, dim)
        )
        self.projector_k = copy.deepcopy(self.projector_q)
        # Key encoder/projector are updated only by momentum, never by gradients
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        for param_k in self.projector_k.parameters():
            param_k.requires_grad = False
        # Create the queue of normalized negative keys
        self.register_buffer("queue", F.normalize(torch.randn(dim, K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
        for param_q, param_k in zip(self.projector_q.parameters(), self.projector_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # keeps the pointer arithmetic simple
        # Replace the oldest keys at ptr
        self.queue[:, ptr:ptr + batch_size] = keys.T
        self.queue_ptr[0] = (ptr + batch_size) % self.K
```
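The class above omits the forward pass. Following the paper's pseudocode, each query's positive logit (its own key) goes in column 0 and the queue supplies the negative logits, so the cross-entropy target is always index 0. A standalone sketch (the `moco_logits` helper name and the shapes are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def moco_logits(q, k, queue, T=0.2):
    """q: (N, dim) queries, k: (N, dim) keys (computed without grad),
    queue: (dim, K) stored negatives. Returns logits and targets for cross-entropy."""
    q = F.normalize(q, dim=1)
    k = F.normalize(k, dim=1)
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)  # (N, 1) positive logits
    l_neg = torch.einsum("nc,ck->nk", q, queue)           # (N, K) negative logits
    logits = torch.cat([l_pos, l_neg], dim=1) / T
    labels = torch.zeros(q.size(0), dtype=torch.long)     # positive is column 0
    return logits, labels

q = torch.randn(8, 128)
k = torch.randn(8, 128)
queue = F.normalize(torch.randn(128, 1024), dim=0)
logits, labels = moco_logits(q, k, queue)
loss = F.cross_entropy(logits, labels)
```

After each step, the training loop would call `_momentum_update_key_encoder()` and enqueue the new keys with `_dequeue_and_enqueue(k)`.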
5. Bootstrap Your Own Latent (BYOL)
BYOL achieves state-of-the-art without negative samples:
Key Features
- Two networks: online (with predictor) and target (momentum)
- Predicts target network's representation
- No contrastive loss - uses a normalized MSE between the online prediction and the target projection
- Surprisingly avoids collapsed solutions
```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class BYOL(nn.Module):
    def __init__(self, encoder, projection_dim=256, hidden_dim=4096, m=0.996):
        super().__init__()
        self.m = m
        # Online network
        self.online_encoder = encoder
        self.online_projector = nn.Sequential(
            nn.Linear(encoder.output_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
        self.online_predictor = nn.Sequential(
            nn.Linear(projection_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, projection_dim)
        )
        # Target network: a copy of the online network, never trained by gradients
        self.target_encoder = copy.deepcopy(encoder)
        self.target_projector = copy.deepcopy(self.online_projector)
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def update_target(self):
        # Momentum (EMA) update; call this after each optimizer step
        for online_param, target_param in zip(
            self.online_encoder.parameters(), self.target_encoder.parameters()
        ):
            target_param.data = self.m * target_param.data + (1 - self.m) * online_param.data
        for online_param, target_param in zip(
            self.online_projector.parameters(), self.target_projector.parameters()
        ):
            target_param.data = self.m * target_param.data + (1 - self.m) * online_param.data

    def forward(self, x1, x2):
        # Online network forward (both augmented views)
        q1 = self.online_predictor(self.online_projector(self.online_encoder(x1)))
        q2 = self.online_predictor(self.online_projector(self.online_encoder(x2)))
        # Target network forward (stop-gradient); views are swapped so each
        # prediction regresses the target representation of the *other* view
        with torch.no_grad():
            t1 = self.target_projector(self.target_encoder(x2))
            t2 = self.target_projector(self.target_encoder(x1))
        # Normalize, then symmetric regression loss
        q1, q2 = F.normalize(q1, dim=1), F.normalize(q2, dim=1)
        t1, t2 = F.normalize(t1, dim=1), F.normalize(t2, dim=1)
        loss = 2 - 2 * (q1 * t1).sum(dim=1).mean()
        loss += 2 - 2 * (q2 * t2).sum(dim=1).mean()
        return loss
```
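The final lines compute a mean squared error in disguise: for unit vectors, ‖q − t‖² = ‖q‖² + ‖t‖² − 2⟨q, t⟩ = 2 − 2⟨q, t⟩, so minimizing 2 − 2⟨q, t⟩ is exactly minimizing the MSE between normalized prediction and target. A quick numerical check with random unit vectors:

```python
import torch
import torch.nn.functional as F

q = F.normalize(torch.randn(8, 256), dim=1)
t = F.normalize(torch.randn(8, 256), dim=1)

mse = ((q - t) ** 2).sum(dim=1).mean()     # mean squared error per pair
byol = 2 - 2 * (q * t).sum(dim=1).mean()   # the form used in forward()
```

Both expressions agree to floating-point precision, which is why BYOL needs no explicit negatives in its loss.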
6. Applications of Self-Supervised Learning
Computer Vision
- Pre-training for object detection
- Medical image analysis
- Few-shot learning
Natural Language Processing
- BERT-style pre-training
- Cross-modal retrieval
- Unsupervised translation
Other Domains
- Audio representation learning
- Graph neural networks
- Reinforcement learning
Conclusion
Self-supervised learning has emerged as a powerful paradigm for learning from unlabeled data, with contrastive methods like SimCLR, MoCo, and BYOL achieving remarkable results across domains. As these techniques continue to mature, they promise to reduce our reliance on costly labeled datasets while enabling more flexible and generalizable representations.
In our next post, we'll explore model deployment techniques to bring your trained models into production.