Diffusion Models: The New Frontier of Generative AI
Diffusion models have emerged as the new state-of-the-art in generative AI, powering systems like DALL-E 2 and Stable Diffusion. In this comprehensive guide, we'll explore how these models work, from the foundational Denoising Diffusion Probabilistic Models (DDPM) to the revolutionary latent diffusion architectures behind today's most impressive AI art generators.
1. The Diffusion Process
Diffusion models are inspired by non-equilibrium thermodynamics: noise is gradually added to data, and a neural network learns to reverse this process:
Forward Diffusion (Fixed Process)
Gradually adds Gaussian noise to data over T steps according to a variance schedule βₜ:

q(xₜ | xₜ₋₁) = N(xₜ; √(1−βₜ) xₜ₋₁, βₜI)
Reverse Diffusion (Learned Process)
A neural network learns to gradually denoise:
pθ(xₜ₋₁ | xₜ) = N(xₜ₋₁; μθ(xₜ, t), Σθ(xₜ, t))
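A useful property of the forward process is its closed form: with ᾱₜ = ∏ₛ (1−βₛ), a clean sample can be noised to any timestep in a single step. A minimal sketch (the schedule values here are illustrative):

```python
import torch

# Linear variance schedule (illustrative values)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) in closed form."""
    ab = alpha_bars[t]
    return torch.sqrt(ab) * x0 + torch.sqrt(1.0 - ab) * noise

x0 = torch.randn(3)            # a toy "image"
noise = torch.randn_like(x0)
xT = q_sample(x0, T - 1, noise)
# At t = T-1, alpha_bar is near zero, so x_T is almost pure noise.
```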
2. Denoising Diffusion Probabilistic Models (DDPM)
The foundational 2020 paper by Ho et al. introduced several key innovations:
| Component | Description | Impact |
|---|---|---|
| Noise Schedule | Carefully designed βt values | Balances speed and quality |
| Reparameterization | Predict noise instead of mean | Simplifies learning |
| Loss Function | Simple MSE on noise | Stable training |
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DDPM(nn.Module):
    def __init__(self, network, timesteps, beta_start, beta_end):
        super().__init__()
        self.network = network
        self.timesteps = timesteps
        # Define beta schedule; buffers move with the module on .to(device)
        self.register_buffer("betas", torch.linspace(beta_start, beta_end, timesteps))
        # Pre-calculate diffusion parameters
        self.register_buffer("alphas", 1.0 - self.betas)
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))

    def forward(self, x, t, noise):
        # Closed-form forward diffusion: x_t = sqrt(a_bar)*x_0 + sqrt(1-a_bar)*eps
        alpha_bar = self.alpha_bars[t].view(-1, 1, 1, 1)
        noisy_x = torch.sqrt(alpha_bar) * x + torch.sqrt(1 - alpha_bar) * noise
        predicted_noise = self.network(noisy_x, t)
        # Simple DDPM objective: MSE between true and predicted noise
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def sample(self, shape, device):
        # Start from pure noise
        x = torch.randn(shape, device=device)
        for t in reversed(range(self.timesteps)):
            t_tensor = torch.full((shape[0],), t, device=device, dtype=torch.long)
            pred_noise = self.network(x, t_tensor)
            alpha = self.alphas[t]
            alpha_bar = self.alpha_bars[t]
            # Mean of the reverse step p(x_{t-1} | x_t)
            x = (x - (1 - alpha) / torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha)
            if t > 0:
                # Inject noise at every step except the last
                x += torch.sqrt(self.betas[t]) * torch.randn_like(x)
        return x
```
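Training then reduces to: pick a random timestep per sample, noise the batch in closed form, and regress the predicted noise. A hypothetical training loop, with a toy stand-in network and assumed hyperparameters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-in for the denoising network: ignores t, maps image -> image
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, kernel_size=3, padding=1)
    def forward(self, x, t):
        return self.conv(x)

# Minimal schedule inlined so the sketch is self-contained
T = 100
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

net = TinyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

x0 = torch.randn(8, 1, 16, 16)                 # a toy batch
for step in range(10):
    t = torch.randint(0, T, (x0.size(0),))     # random timestep per sample
    noise = torch.randn_like(x0)
    ab = alpha_bars[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * noise
    loss = F.mse_loss(net(x_t, t), noise)      # the "simple" DDPM loss
    opt.zero_grad()
    loss.backward()
    opt.step()
```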
3. Latent Diffusion Models (Stable Diffusion)
Stable Diffusion improved efficiency by operating in latent space:
Key Components
| Component | Purpose | Details |
|---|---|---|
| VAE Encoder/Decoder | Image ⇄ Latent Space | Compresses images to smaller latent representations |
| UNet | Denoising Model | Predicts noise in latent space with cross-attention |
| Text Encoder (CLIP) | Conditioning | Maps text prompts to embedding space |
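Putting the table together, generation works end to end as: encode the prompt, denoise a random latent under that conditioning, then decode the latent to pixels. A high-level sketch with stub components (all names and shapes here are illustrative, not the real Stable Diffusion API):

```python
import torch

latent_shape = (1, 4, 8, 8)

def text_encoder(prompt):              # CLIP in Stable Diffusion
    return torch.randn(1, 77, 768)     # toy token embeddings

def unet(z, t, cond):                  # predicts noise in latent space
    return torch.zeros_like(z)

def vae_decode(z):                     # latent -> image
    return torch.rand(1, 3, 64, 64)

def generate(prompt, steps=50):
    cond = text_encoder(prompt)
    z = torch.randn(latent_shape)      # start from latent-space noise
    for t in reversed(range(steps)):
        eps = unet(z, t, cond)
        z = z - 0.1 * eps              # simplified update in place of a real sampler
    return vae_decode(z)

img = generate("a photo of an astronaut")
```

The point is the data flow: all the expensive iteration happens on the small latent tensor, and the VAE decoder runs only once at the end.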
4. Implementing Stable Diffusion Components
VAE (Variational Autoencoder)
```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Toy VAE; the 128*8*8 flatten assumes 32x32 RGB inputs."""
    def __init__(self):
        super().__init__()
        # Encoder: two stride-2 convs, 32x32 -> 8x8
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        # Latent space: mean and log-variance heads
        self.fc_mu = nn.Linear(128 * 8 * 8, 256)
        self.fc_var = nn.Linear(128 * 8 * 8, 256)
        # Decoder: mirror of the encoder
        self.decoder = nn.Sequential(
            nn.Linear(256, 128 * 8 * 8),
            nn.Unflatten(1, (128, 8, 8)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        h = h.view(h.size(0), -1)
        return self.fc_mu(h), self.fc_var(h)  # mean, log-variance

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps keeps sampling differentiable
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
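Training a VAE combines a reconstruction term with a KL term that pulls q(z|x) toward N(0, I). A self-contained sketch of the loss, with toy tensors standing in for real model outputs:

```python
import torch
import torch.nn.functional as F

def vae_loss(recon, x, mu, logvar, beta=1.0):
    # Reconstruction term (BCE is a common choice with a Sigmoid decoder)
    recon_loss = F.binary_cross_entropy(recon, x, reduction="sum")
    # KL( N(mu, sigma^2) || N(0, I) ) in closed form
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl

# Toy values in place of real encoder/decoder outputs
x = torch.rand(4, 3, 32, 32)
recon = torch.rand(4, 3, 32, 32)
mu = torch.zeros(4, 256)
logvar = torch.zeros(4, 256)
loss = vae_loss(recon, x, mu, logvar)
# With mu = 0 and logvar = 0, the KL term is exactly zero.
```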
Conditional UNet
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetBlock(nn.Module):
    def __init__(self, in_c, out_c, time_emb_dim, text_emb_dim):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_c)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        # Cross-attention for text conditioning; batch_first keeps shapes (B, seq, C)
        self.attn = nn.MultiheadAttention(embed_dim=out_c, num_heads=4, batch_first=True)
        self.norm = nn.LayerNorm(out_c)
        self.proj = nn.Linear(text_emb_dim, out_c)

    def forward(self, x, t, text_emb):
        h = self.conv1(x)
        # Add time embedding, broadcast over the spatial dimensions
        t = F.silu(self.time_mlp(t))
        h = h + t.view(-1, h.shape[1], 1, 1)
        # Cross-attention with text: image tokens query the text embedding
        batch, channels, height, width = h.shape
        h_flat = self.norm(h.view(batch, channels, -1).permute(0, 2, 1))
        text_proj = self.proj(text_emb).unsqueeze(1)  # single text token as key/value
        attn_out, _ = self.attn(h_flat, text_proj, text_proj)
        h = h + attn_out.permute(0, 2, 1).view(batch, channels, height, width)
        h = self.conv2(h)
        return h
```
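The cross-attention step is easiest to see in isolation: flatten the feature map into a token sequence, use it as the query, and use projected text embeddings as key and value. A shape check with toy dimensions (the 77-token text length mirrors CLIP, but all values here are assumptions):

```python
import torch
import torch.nn as nn

batch, channels, height, width = 2, 64, 8, 8
text_emb_dim = 512

# Flatten the spatial grid into a token sequence: (batch, H*W, channels)
h = torch.randn(batch, channels, height, width)
h_flat = h.view(batch, channels, -1).permute(0, 2, 1)

# Project text embeddings to the channel dimension; they act as key/value
proj = nn.Linear(text_emb_dim, channels)
text_tokens = proj(torch.randn(batch, 77, text_emb_dim))

attn = nn.MultiheadAttention(embed_dim=channels, num_heads=4, batch_first=True)
out, weights = attn(h_flat, text_tokens, text_tokens)  # query = image tokens

# Back to a feature map, same shape as the input
out_spatial = out.permute(0, 2, 1).view(batch, channels, height, width)
```

Each of the 64 spatial positions attends over all 77 text tokens, which is how the prompt steers every region of the latent.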
5. Practical Applications
Creative Tools
- AI art generation
- Photorealistic synthesis
- Style transfer
Design & Media
- Concept art creation
- Texture generation
- Advertising content
Scientific Applications
- Molecular design
- Medical imaging
- Data augmentation
Conclusion
Diffusion models represent a fundamental shift in generative AI, offering unprecedented quality and control compared to previous approaches. From the theoretical foundations of DDPM to the practical breakthroughs of Stable Diffusion, these models have opened new creative possibilities while presenting fascinating technical challenges.
In our next post, we'll explore how deep learning is revolutionizing reinforcement learning through techniques like Deep Q-Networks and Policy Gradient methods.