"""
**Utilities for text encoding, score prediction, and evaluation in diffusion models**
This module provides core components for building diffusion model pipelines, including
text encoding (used as a conditional model), U-Net-based score prediction (ScoreNet),
custom loss functions for training, and image quality evaluation. These utilities support
various diffusion model architectures, such as DDPM, DDIM, LDM, and SDE, and are designed
for standalone use in model training and sampling.
**Primary Components**
- **TextEncoder**: Encodes text prompts into embeddings using a pre-trained BERT model or a custom transformer.
- **DiffusionNetwork** (ScoreNet): Memory-efficient U-Net-like architecture for predicting noise or scores in diffusion models, supporting time and text conditioning.
- **Loss Functions**:
  - `mse_loss`: Standard mean squared error loss.
  - `snr_capped_loss`: SNR-weighted noise prediction loss with capped weighting, useful for VP/VE training.
  - `min_snr_loss`: Min-SNR weighted loss that scales the squared error by min(SNR, gamma) / SNR.
  - `ve_sigma_weighted_score_loss`: Sigma-weighted score matching loss for VE-SDEs.
- **Metrics**: Computes image quality metrics (MSE, PSNR, SSIM, FID, LPIPS) for evaluating generated images.
**Notes**
- The primary components are intended to be imported directly for use in diffusion model workflows.
- Additional supporting classes and functions in this module provide internal functionality for the primary components.
------------------------------------------------------------------------------------------
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_fid import fid_score
from transformers import BertModel
import os
import math
import shutil
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchvision.utils import save_image
from typing import Optional, Tuple, List
# Public API of this module: the names bound by ``from <module> import *``.
# Grouped in file order: text-encoder pieces, U-Net pieces, loss helpers, metrics.
__all__ = [
"TextEncoder",
"EncoderLayer",
"FeedForward",
"Attention",
"Embedding",
"DiffusionNetwork",
"ResBlock",
"CrossAttention",
"get_timestep_embedding",
"LossAdapter",
"mse_loss",
"snr_capped_loss",
"min_snr_loss",
"ve_sigma_weighted_score_loss",
"Metrics",
]
###==================================================================================================================###
[docs]
class TextEncoder(torch.nn.Module):
    """Transformer-based encoder for text prompts in conditional diffusion models.

    Encodes text prompts into embeddings using either a pre-trained BERT model or a
    custom transformer architecture. Used as the `conditional_model` in diffusion models
    (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs for noise prediction.

    Parameters
    ----------
    use_pretrained_model : bool, optional
        If True, uses a pre-trained BERT model; otherwise, builds a custom transformer
        (default: True).
    model_name : str, optional
        Name of the pre-trained model to load (default: "bert-base-uncased").
    vocabulary_size : int, optional
        Size of the vocabulary for the custom transformer's embedding layer
        (default: 30522).
    num_layers : int, optional
        Number of transformer encoder layers for the custom transformer (default: 6).
    input_dimension : int, optional
        Input embedding dimension for the custom transformer (default: 768).
    output_dimension : int, optional
        Output embedding dimension for both pre-trained and custom models
        (default: 768).
    num_heads : int, optional
        Number of attention heads in the custom transformer (default: 8).
    context_length : int, optional
        Maximum sequence length for text prompts (default: 77).
    dropout_rate : float, optional
        Dropout rate for attention and feedforward layers (default: 0.1).
    qkv_bias : bool, optional
        If True, includes bias in query, key, and value projections for the custom
        transformer (default: False).
    scaling_value : int, optional
        Scaling factor for the feedforward layer's hidden dimension in the custom
        transformer (default: 4).
    epsilon : float, optional
        Epsilon for layer normalization in the custom transformer (default: 1e-5).
    use_learned_pos : bool, optional
        If True, the custom transformer uses learnable positional embeddings instead
        of sinusoidal encodings (default: False).

    **Notes**

    - When `use_pretrained_model` is True, the BERT model's parameters are frozen
      (`requires_grad = False`), and a projection layer maps outputs to
      `output_dimension`.
    - The custom transformer uses `EncoderLayer` modules with multi-head attention and
      feedforward networks. Only the first layer maps `input_dimension` to
      `output_dimension`; every later layer operates at `output_dimension`, so
      differing input/output dimensions work with any number of layers.
    - The output shape is (batch_size, seq_len, output_dimension).
    """

    def __init__(
            self,
            use_pretrained_model: bool = True,
            model_name: str = "bert-base-uncased",
            vocabulary_size: int = 30522,
            num_layers: int = 6,
            input_dimension: int = 768,
            output_dimension: int = 768,
            num_heads: int = 8,
            context_length: int = 77,
            dropout_rate: float = 0.1,
            qkv_bias: bool = False,
            scaling_value: int = 4,
            epsilon: float = 1e-5,
            use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.use_pretrained_model = use_pretrained_model
        if self.use_pretrained_model:
            self.bert = BertModel.from_pretrained(model_name)
            # Freeze the backbone; only the projection below is trainable.
            for param in self.bert.parameters():
                param.requires_grad = False
            self.projection = nn.Linear(self.bert.config.hidden_size, output_dimension)
        else:
            self.embedding = Embedding(
                vocabulary_size=vocabulary_size,
                embedding_dimension=input_dimension,
                max_context_length=context_length,
                use_learned_pos=use_learned_pos
            )
            # The first layer changes the width from input_dimension to
            # output_dimension; the layers after it already receive
            # output_dimension-wide activations and must be built accordingly.
            self.layers = torch.nn.ModuleList([
                EncoderLayer(
                    input_dimension=input_dimension if layer_index == 0 else output_dimension,
                    output_dimension=output_dimension,
                    num_heads=num_heads,
                    dropout_rate=dropout_rate,
                    qkv_bias=qkv_bias,
                    scaling_value=scaling_value,
                    epsilon=epsilon
                )
                for layer_index in range(num_layers)
            ])

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encodes text prompts into embeddings.

        Processes input token IDs and an optional attention mask to produce embeddings
        using either a pre-trained BERT model or a custom transformer.

        Parameters
        ----------
        x : torch.Tensor
            Token IDs, shape (batch_size, seq_len).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore (default: None).

        Returns
        -------
        x (torch.Tensor) - Encoded embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - For pre-trained BERT, the `last_hidden_state` is projected to
          `output_dimension` and this layer is the only trainable layer in the model.
        - For the custom transformer, token embeddings are processed through
          `Embedding` and `EncoderLayer` modules; the mask is forwarded unchanged to
          every `EncoderLayer`.
        """
        if self.use_pretrained_model:
            x = self.bert(input_ids=x, attention_mask=attention_mask)
            x = x.last_hidden_state
            x = self.projection(x)
        else:
            x = self.embedding(x)
            for layer in self.layers:
                x = layer(x, attention_mask=attention_mask)
        return x
###==================================================================================================================###
[docs]
class EncoderLayer(torch.nn.Module):
    """Single transformer encoder layer with multi-head attention and feedforward network.

    Used in the custom transformer of `TextEncoder` to process embedded text prompts.

    Parameters
    ----------
    input_dimension : int
        Input embedding dimension.
    output_dimension : int
        Output embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate for attention and feedforward layers.
    qkv_bias : bool
        If True, includes bias in the attention projections.
    scaling_value : int
        Scaling factor for the feedforward layer's hidden dimension.
    epsilon : float, optional
        Epsilon for layer normalization (default: 1e-5).

    **Notes**

    - The layer follows the post-norm transformer encoder layout: attention,
      residual connection, normalization, feedforward, residual connection,
      normalization.
    - When `input_dimension != output_dimension`, both the attention output and the
      residual path are linearly projected to `output_dimension` before being summed.
    - The attention mechanism uses `batch_first=True` for compatibility with
      `TextEncoder`'s input format.
    """

    def __init__(
            self,
            input_dimension: int,
            output_dimension: int,
            num_heads: int,
            dropout_rate: float,
            qkv_bias: bool,
            scaling_value: int,
            epsilon: float = 1e-5
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=input_dimension,
            num_heads=num_heads,
            dropout=dropout_rate,
            bias=qkv_bias,
            batch_first=True
        )
        self.output_projection = nn.Linear(input_dimension, output_dimension) if input_dimension != output_dimension else nn.Identity()
        self.residual_projection = nn.Linear(input_dimension, output_dimension, bias=False) if input_dimension != output_dimension else nn.Identity()
        self.norm1 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = FeedForward(
            embedding_dimension=output_dimension,
            scaling_value=scaling_value,
            dropout_rate=dropout_rate
        )
        self.norm2 = nn.LayerNorm(normalized_shape=output_dimension, eps=epsilon)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Processes input embeddings through attention and feedforward layers.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, input_dimension).
        attention_mask : torch.Tensor, optional
            Attention mask, shape (batch_size, seq_len), where 0 indicates padding
            tokens to ignore and non-zero indicates valid tokens (default: None).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, output_dimension).

        **Notes**

        - `nn.MultiheadAttention` expects a `key_padding_mask` in which True marks the
          positions to IGNORE, the opposite of the 1-valid/0-padding convention used
          here, so the mask is converted with `attention_mask == 0` before the call.
          This also accepts integer masks as produced by tokenizers.
        - Residual connections and normalization are applied after attention and
          feedforward layers.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # True where the token is padding (mask value 0), as required by
            # nn.MultiheadAttention.
            key_padding_mask = attention_mask == 0
        attn_output, _ = self.attention(x, x, x, key_padding_mask=key_padding_mask)
        attn_output = self.output_projection(attn_output)
        x = self.norm1(self.residual_projection(x) + self.dropout1(attn_output))
        ff_output = self.feedforward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x
###==================================================================================================================###
[docs]
class FeedForward(torch.nn.Module):
    """Position-wise two-layer MLP used inside `EncoderLayer`.

    The block expands the embedding to a wider hidden representation, applies GELU
    and dropout, and contracts back to the original width.

    Parameters
    ----------
    embedding_dimension : int
        Width of both the input and the output.
    scaling_value : int
        Expansion factor; the hidden width is
        `embedding_dimension * scaling_value`.
    dropout_rate : float, optional
        Dropout probability applied after the activation (default: 0.1).
    """

    def __init__(self, embedding_dimension: int, scaling_value: int, dropout_rate: float = 0.1) -> None:
        super().__init__()
        hidden_dimension = embedding_dimension * scaling_value
        # Built in forward order so parameter initialisation order is unchanged.
        expand = torch.nn.Linear(embedding_dimension, hidden_dimension, bias=True)
        activation = torch.nn.GELU()
        regulariser = torch.nn.Dropout(dropout_rate)
        contract = torch.nn.Linear(hidden_dimension, embedding_dimension, bias=True)
        self.layers = torch.nn.Sequential(expand, activation, regulariser, contract)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP over the last dimension of `x`.

        Parameters
        ----------
        x : torch.Tensor
            Input embeddings, shape (batch_size, seq_len, embedding_dimension).

        Returns
        -------
        x (torch.Tensor) - Processed embeddings, shape (batch_size, seq_len, embedding_dimension).
        """
        return self.layers(x)
###==================================================================================================================###
[docs]
class Attention(nn.Module):
    """Attention module for NoisePredictor, supporting text conditioning or self-attention.

    Applies multi-head attention over the flattened spatial positions of a feature
    map. When text embeddings are supplied they are used as keys/values
    (cross-attention); otherwise the feature map attends to itself.

    Parameters
    ----------
    in_channels : int
        Number of input channels (embedding dimension for attention).
    y_embed_dim : int, optional
        Dimensionality of text embeddings (default: 768).
    num_heads : int, optional
        Number of attention heads (default: 4).
    num_groups : int, optional
        Number of groups for group normalization (default: 8).
    dropout_rate : float, optional
        Dropout rate for attention and output (default: 0.1).

    Attributes
    ----------
    in_channels : int
        Input channel dimension.
    y_embed_dim : int
        Text embedding dimension.
    num_heads : int
        Number of attention heads.
    dropout_rate : float
        Dropout rate.
    attention : torch.nn.MultiheadAttention
        Multi-head attention with `batch_first=True`.
    norm : torch.nn.GroupNorm
        Group normalization applied to the attention output.
    dropout : torch.nn.Dropout
        Dropout layer for output.
    y_projection : torch.nn.Linear
        Projection for text embeddings to match `in_channels`.

    Raises
    ------
    AssertionError
        If input channels do not match `in_channels`.
    ValueError
        If text embeddings (`y`) are neither 2D nor 3D after projection.
    """

    def __init__(
            self,
            in_channels: int,
            y_embed_dim: int = 768,
            num_heads: int = 4,
            num_groups: int = 8,
            dropout_rate: float = 0.1
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.y_embed_dim = y_embed_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.attention = nn.MultiheadAttention(embed_dim=in_channels, num_heads=num_heads, dropout=dropout_rate, batch_first=True)
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels)
        self.dropout = nn.Dropout(dropout_rate)
        self.y_projection = nn.Linear(y_embed_dim, in_channels)

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor] = None):
        """Applies attention to input features with optional text conditioning.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape (batch_size, in_channels, height, width). It does not
            need to be contiguous in memory.
        y : torch.Tensor, optional
            Text embeddings, shape (batch_size, seq_len, y_embed_dim) or
            (batch_size, y_embed_dim) (default: None).

        Returns
        -------
        torch.Tensor
            Output tensor, same shape as input `x`. It is the normalized attention
            output followed by dropout; no residual connection is added here.
        """
        batch_size, channels, h, w = x.shape
        assert channels == self.in_channels, f"Expected {self.in_channels} channels, got {channels}"
        # reshape (rather than view) so feature maps with non-mergeable strides,
        # e.g. the result of a spatial transpose, are accepted as well.
        x_reshaped = x.reshape(batch_size, channels, h * w).permute(0, 2, 1)
        if y is not None:
            # The projection fixes the last dimension to in_channels, so only the
            # rank of y still needs checking afterwards.
            y = self.y_projection(y)
            if y.dim() == 2:
                y = y.unsqueeze(1)  # treat a single vector as a length-1 sequence
            elif y.dim() != 3:
                raise ValueError(
                    f"Expected y to be 2D or 3D after projection, got {y.dim()}D with shape {y.shape}"
                )
            out, _ = self.attention(x_reshaped, y, y)
        else:
            out, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
        out = out.permute(0, 2, 1).reshape(batch_size, channels, h, w)
        out = self.norm(out)
        out = self.dropout(out)
        return out
###==================================================================================================================###
[docs]
class Embedding(nn.Module):
    """Token and positional embedding layer for transformer inputs.

    Used in `TextEncoder`'s transformer to embed token IDs and add positional encodings.

    Parameters
    ----------
    vocabulary_size : int
        Size of the vocabulary for token embeddings.
    embedding_dimension : int, optional
        Dimension of token and positional embeddings (default: 768).
    max_context_length : int, optional
        Maximum sequence length for precomputing positional encodings (default: 77).
    use_learned_pos : bool, optional
        If True, uses learnable positional embeddings instead of sinusoidal encodings
        (default: False).

    **Notes**

    - Supports both sinusoidal (fixed) and learned positional embeddings, selectable via
      `use_learned_pos`.
    - Sinusoidal encodings for positions `0 .. max_context_length - 1` are computed
      once at construction and stored in the `positional_encoding_cache` buffer. The
      buffer keeps that fixed shape for the lifetime of the module, so a saved
      `state_dict` can always be loaded into a freshly constructed instance with the
      same hyper-parameters.
    - Learned positional embeddings are initialized as a learnable parameter.
    - The output shape is (batch_size, seq_len, embedding_dimension).
    """

    def __init__(
            self,
            vocabulary_size: int,
            embedding_dimension: int = 768,
            max_context_length: int = 77,
            use_learned_pos: bool = False
    ) -> None:
        super().__init__()
        self.vocabulary_size = vocabulary_size
        self.embedding_dimension = embedding_dimension
        self.max_context_length = max_context_length
        self.use_learned_pos = use_learned_pos
        # Token embedding layer
        self.token_embedding = nn.Embedding(
            num_embeddings=vocabulary_size,
            embedding_dim=embedding_dimension
        )
        if use_learned_pos:
            # Learnable positional embeddings
            self.positional_embedding = nn.Parameter(
                torch.randn(1, max_context_length, embedding_dimension) / math.sqrt(embedding_dimension)
            )
        else:
            # Precompute the sinusoidal table up front. Growing the buffer lazily
            # inside forward() would make its shape depend on whether the module
            # has been called, which breaks load_state_dict on a fresh instance.
            self.register_buffer(
                "positional_encoding_cache",
                self._generate_positional_encoding(
                    max_context_length, self.token_embedding.weight.device
                )
            )

    def _generate_positional_encoding(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Generates sinusoidal positional encodings for transformer inputs.

        Parameters
        ----------
        seq_len : int
            Length of the sequence for which to generate positional encodings.
        device : torch.device
            Device on which to create the positional encodings.

        Returns
        -------
        torch.Tensor
            Float32 positional encodings, shape (1, seq_len, embedding_dimension), where
            even-indexed dimensions use sine and odd-indexed dimensions use cosine.

        **Notes**

        - Uses the formula: for position `pos` and dimension `i`,
          `PE(pos, 2i) = sin(pos / 10000^(2i/d))` and
          `PE(pos, 2i+1) = cos(pos / 10000^(2i/d))`, where `d` is `embedding_dimension`.
        - For an odd `embedding_dimension` there is one more sine slot than cosine
          slots, so the last frequency is used for the sine only.
        """
        position = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.embedding_dimension, 2, dtype=torch.float32, device=device) *
            (-math.log(10000.0) / self.embedding_dimension)
        )
        pos_enc = torch.zeros((1, seq_len, self.embedding_dimension), dtype=torch.float32, device=device)
        pos_enc[:, :, 0::2] = torch.sin(position * div_term)
        cos_div_term = div_term[:-1] if self.embedding_dimension % 2 else div_term
        pos_enc[:, :, 1::2] = torch.cos(position * cos_div_term)
        return pos_enc

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embeds token IDs and adds positional encodings.

        Parameters
        ----------
        token_ids : torch.Tensor
            Token IDs, shape (batch_size, seq_len).

        Returns
        -------
        torch.Tensor
            Embedded tokens with positional encodings, shape
            (batch_size, seq_len, embedding_dimension).

        Raises
        ------
        AssertionError
            If `token_ids` is not 2-dimensional.
        ValueError
            If learned positional embeddings are used and `seq_len` exceeds
            `max_context_length`.

        **Notes**

        - With sinusoidal encodings, sequences longer than the cached table are
          handled by generating the encodings on the fly for that call; the cached
          buffer itself is left untouched.
        """
        assert token_ids.dim() == 2, "Input token_ids should be of shape (batch_size, seq_len)"
        _, seq_len = token_ids.size()
        # Compute token embeddings
        token_embedded = self.token_embedding(token_ids)
        # Handle positional embeddings
        if self.use_learned_pos:
            if seq_len > self.max_context_length:
                raise ValueError(
                    f"Sequence length ({seq_len}) exceeds max_context_length ({self.max_context_length}) "
                    "for learned positional embeddings."
                )
            position_encoded = self.positional_embedding[:, :seq_len, :]
        elif seq_len <= self.positional_encoding_cache.size(1):
            position_encoded = self.positional_encoding_cache[:, :seq_len, :]
        else:
            # Longer than the precomputed table: build a one-off table that matches
            # the embeddings' device and dtype instead of resizing the buffer.
            position_encoded = self._generate_positional_encoding(
                seq_len, token_embedded.device
            ).to(token_embedded.dtype)
        return token_embedded + position_encoded
###==================================================================================================================###
[docs]
class DiffusionNetwork(nn.Module):
    """Memory-efficient U-Net architecture for diffusion models supporting time and conditional embeddings.

    The network consists of an input convolution, a stack of encoder stages
    (`ResBlock` followed by an optional strided-convolution downsample), a list of
    middle `ResBlock` modules, a stack of decoder stages (optional
    transposed-convolution upsample followed by a `ResBlock` that receives the
    matching encoder feature concatenated on the channel axis), and a
    GroupNorm/SiLU/Conv output head mapping back to `in_channels`.
    """

    def __init__(
            self,
            in_channels: int,
            down_channels: List[int],
            mid_channels: List[int],
            up_channels: List[int],
            down_sampling: List[bool],
            time_embed_dim: int,
            y_embed_dim: int,
            num_down_blocks: int,
            num_mid_blocks: int,
            num_up_blocks: int,
            dropout_rate: float = 0.1,
            down_sampling_factor: int = 2,
            y_to_all: bool = False,
            cont_time: bool = True,
            use_flash_attention: bool = True,
            grad_check: bool = False
    ) -> None:
        """Initialize the U-Net with configurable down, middle, and up blocks, time embeddings, and optional attention.

        Args:
            in_channels: Number of input channels.
            down_channels: List of channels for downsampling stages; stage `i` maps
                `down_channels[i]` to `down_channels[i + 1]`.
            mid_channels: List of channels for middle blocks; block `i` maps
                `mid_channels[i]` to `mid_channels[i + 1]`.
            up_channels: List of channels for upsampling stages; stage `i` maps
                `up_channels[i]` plus the skip channels to `up_channels[i + 1]`.
            down_sampling: Boolean flags indicating whether to downsample at each down block.
            time_embed_dim: Dimensionality of the time embedding.
            y_embed_dim: Dimensionality of the conditional embedding.
            num_down_blocks: Number of residual layers per down block.
            num_mid_blocks: Number of residual layers per middle block.
            num_up_blocks: Number of residual layers per up block.
            dropout_rate: Dropout probability.
            down_sampling_factor: Stride factor for downsampling/upsampling.
            y_to_all: If True, enables attention in every encoder/decoder stage
                instead of only the first one of each.
            cont_time: Whether to use continuous time embeddings.
            use_flash_attention: Whether to use flash attention for cross-attention layers.
            grad_check: Whether to use gradient checkpointing.

        Raises:
            AssertionError: If `down_sampling` does not have `len(down_channels) - 1`
                entries, or if there are more decoder stages than encoder stages.
        """
        super().__init__()
        self.cont_time = cont_time
        self.grad_check = grad_check
        assert len(down_channels) - 1 == len(down_sampling), \
            f"down_sampling length must be len(down_channels)-1, got {len(down_sampling)} vs {len(down_channels) - 1}"
        assert len(up_channels) - 1 <= len(down_channels) - 1, \
            "Cannot have more up blocks than down blocks"
        self.conv_in = nn.Conv2d(in_channels, down_channels[0], 3, padding=1)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        self.encoder = nn.ModuleList()
        for i in range(len(down_channels) - 1):
            self.encoder.append(nn.ModuleDict({
                'block': ResBlock(
                    in_channels=down_channels[i],
                    out_channels=down_channels[i + 1],
                    time_channels=time_embed_dim,
                    context_channels=y_embed_dim,
                    num_layers=num_down_blocks,
                    dropout=dropout_rate,
                    use_attention=(i == 0 or y_to_all),
                    use_flash=use_flash_attention
                ),
                'downsample': nn.Conv2d(down_channels[i + 1], down_channels[i + 1], 3,
                                        stride=down_sampling_factor, padding=1) if down_sampling[i] else nn.Identity()
            }))
        self.middle = nn.ModuleList()
        for i in range(len(mid_channels) - 1):
            self.middle.append(
                ResBlock(
                    in_channels=mid_channels[i],
                    out_channels=mid_channels[i + 1],
                    time_channels=time_embed_dim,
                    context_channels=y_embed_dim,
                    num_layers=num_mid_blocks,
                    dropout=dropout_rate,
                    use_attention=True,
                    use_flash=use_flash_attention
                )
            )
        num_decoder_stages = len(up_channels) - 1
        # The decoder mirrors the LAST `num_decoder_stages` encoder stages, walked
        # in reverse, for both the resampling flags and the skip channel counts.
        up_sampling_ops = list(reversed(down_sampling[-num_decoder_stages:]))
        encoder_output_channels = list(reversed(down_channels[1:]))[:num_decoder_stages]
        self.decoder = nn.ModuleList()
        for i in range(num_decoder_stages):
            self.decoder.append(nn.ModuleDict({
                'upsample': nn.ConvTranspose2d(up_channels[i], up_channels[i],
                                               down_sampling_factor, stride=down_sampling_factor) if up_sampling_ops[
                    i] else nn.Identity(),
                'block': ResBlock(
                    # The block sees the upsampled features concatenated with the skip.
                    in_channels=up_channels[i] + encoder_output_channels[i],
                    out_channels=up_channels[i + 1],
                    time_channels=time_embed_dim,
                    context_channels=y_embed_dim,
                    num_layers=num_up_blocks,
                    dropout=dropout_rate,
                    use_attention=(i == 0 or y_to_all),
                    use_flash=use_flash_attention
                )
            }))
        self.conv_out = nn.Sequential(
            nn.GroupNorm(8, up_channels[-1]),
            nn.SiLU(),
            nn.Conv2d(up_channels[-1], in_channels, 3, padding=1)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize weights of Conv2d and Linear layers with Kaiming initialization and zero biases.

        Modules of any other type are left with their default initialization.

        Args:
            m: Module to initialize.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, a=0.01, nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
            y: Optional[torch.Tensor] = None,
            clip_embeddings: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass through the U-Net with time and optional conditional embeddings.

        Args:
            x: Input tensor of shape [B, C, H, W].
            t: Tensor of timesteps [B] or [B, 1].
            y: Optional context embeddings [B, D] or [B, L, D].
            clip_embeddings: Optional CLIP embeddings [B, D]; they are added to the
                time embedding, so D must broadcast against `time_embed_dim`.

        Returns:
            Output tensor of shape [B, in_channels, H, W].
        """
        t_emb = get_timestep_embedding(t, self.time_mlp[0].in_features, self.cont_time)
        t_emb = self.time_mlp(t_emb)
        if clip_embeddings is not None:
            t_emb = t_emb + clip_embeddings
        h = self.conv_in(x)
        encoder_features = []
        for stage in self.encoder:
            h = self._apply_block(stage['block'], h, t_emb, y)
            # The skip is taken before downsampling, at the stage's full resolution.
            encoder_features.append(h)
            h = stage['downsample'](h)
        for block in self.middle:
            h = self._apply_block(block, h, t_emb, y)
        num_decoder_stages = len(self.decoder)
        skips_for_decoder = list(reversed(encoder_features[-num_decoder_stages:]))
        for stage, skip in zip(self.decoder, skips_for_decoder):
            h = stage['upsample'](h)
            if h.shape[2:] != skip.shape[2:]:
                # Resampling may not reproduce the encoder resolution exactly
                # (e.g. odd spatial sizes), so resize to the skip before concatenating.
                h = torch.nn.functional.interpolate(
                    h,
                    size=skip.shape[2:],
                    mode='bilinear' if h.dim() == 4 else 'trilinear',
                    align_corners=False
                )
            h = torch.cat([h, skip], dim=1)
            h = self._apply_block(stage['block'], h, t_emb, y)
        return self.conv_out(h)

    def _apply_block(self, block, x, t_emb, y):
        """Apply a residual block with optional gradient checkpointing.

        Checkpointing is used only when `grad_check` is enabled and the module is in
        training mode; otherwise the block is called directly.

        Args:
            block: The ResBlock module to apply.
            x: Input tensor.
            t_emb: Time embedding tensor.
            y: Optional conditional embedding.

        Returns:
            Output tensor after applying the block.
        """
        if self.grad_check and self.training:
            # Imported explicitly: a bare `import torch` does not guarantee that the
            # `torch.utils.checkpoint` submodule has been loaded.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(block, x, t_emb, y, use_reentrant=False)
        return block(x, t_emb, y)
[docs]
class ResBlock(nn.Module):
    """Stack of time-conditioned residual layers, optionally followed by cross-attention after the first layer."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            time_channels: int,
            context_channels: int,
            num_layers: int = 2,
            dropout: float = 0.1,
            use_attention: bool = False,
            use_flash: bool = True
    ):
        """Build the residual layers and, if requested, the cross-attention module.

        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Channel count produced by every residual layer.
            time_channels: Width of the time embedding.
            context_channels: Width of the conditional embedding; attention is
                disabled when this is not positive.
            num_layers: How many residual layers to stack.
            dropout: Dropout probability.
            use_attention: Request a cross-attention layer.
            use_flash: Forwarded to the cross-attention layer.
        """
        super().__init__()
        self.num_layers = num_layers
        self.use_attention = use_attention and context_channels > 0
        self.res_layers = nn.ModuleList()
        for index in range(num_layers):
            # Only the first layer sees the block's input width.
            width_in = out_channels if index else in_channels
            components = nn.ModuleDict({
                'norm1': nn.GroupNorm(8, width_in),
                'conv1': nn.Conv2d(width_in, out_channels, 3, padding=1),
                'time_emb': nn.Linear(time_channels, out_channels),
                'norm2': nn.GroupNorm(8, out_channels),
                'conv2': nn.Conv2d(out_channels, out_channels, 3, padding=1),
                'dropout': nn.Dropout(dropout),
                # A 1x1 convolution aligns the shortcut when the width changes.
                'skip': nn.Identity() if width_in == out_channels else nn.Conv2d(width_in, out_channels, 1)
            })
            self.res_layers.append(components)
        if self.use_attention:
            self.attention = CrossAttention(
                out_channels,
                context_channels,
                num_heads=4,
                dropout=dropout,
                use_flash=use_flash
            )

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor, context: Optional[torch.Tensor] = None):
        """Run the residual layers, adding cross-attention after the first one when enabled.

        Args:
            x: Feature map of shape [B, C, H, W].
            t_emb: Time embedding of shape [B, time_channels].
            context: Optional conditional embeddings; attention is skipped when None.

        Returns:
            Feature map produced by the last residual layer.
        """
        time_signal = F.silu(t_emb)
        attend = self.use_attention and context is not None
        features = x
        for index, layer in enumerate(self.res_layers):
            hidden = layer['conv1'](F.silu(layer['norm1'](features)))
            # Broadcast the per-channel time bias over the spatial axes.
            hidden = hidden + layer['time_emb'](time_signal)[:, :, None, None]
            hidden = layer['conv2'](layer['dropout'](F.silu(layer['norm2'](hidden))))
            features = hidden + layer['skip'](features)
            if index == 0 and attend:
                features = features + self.attention(features, context)
        return features
[docs]
class CrossAttention(nn.Module):
    """Multi-head cross-attention from a feature map (queries) to context embeddings (keys/values)."""

    def __init__(
            self,
            channels: int,
            context_dim: int,
            num_heads: int = 4,
            dropout: float = 0.0,
            use_flash: bool = True
    ):
        """Create the normalization, projections and dropout used by the attention.

        Args:
            channels: Channel count of the feature map; must be divisible by `num_heads`.
            context_dim: Width of the context embeddings.
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the projected attention output.
            use_flash: Use `F.scaled_dot_product_attention` when it exists.
        """
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5
        # Fall back to the explicit softmax path on builds without the fused kernel.
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.norm = nn.GroupNorm(8, channels)
        self.to_q = nn.Linear(channels, channels, bias=False)
        self.to_kv = nn.Linear(context_dim, channels * 2, bias=False)
        self.proj_out = nn.Linear(channels, channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Attend from every spatial position of `x` to the context tokens.

        Args:
            x: Feature map of shape [B, C, H, W].
            context: Context embeddings of shape [B, D] or [B, L, D]; a 2D input is
                treated as a single context token.

        Returns:
            Tensor of shape [B, C, H, W] holding the attention output.
        """
        batch, channels, height, width = x.shape
        positions = height * width

        def split_heads(tensor: torch.Tensor) -> torch.Tensor:
            # [B, N, C] -> [B, heads, N, head_dim]
            return tensor.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

        tokens = self.norm(x).view(batch, channels, positions).transpose(1, 2)  # [B, HW, C]
        if context.dim() == 2:
            context = context.unsqueeze(1)  # [B, 1, D]

        query = split_heads(self.to_q(tokens))
        key_part, value_part = self.to_kv(context).chunk(2, dim=-1)
        key = split_heads(key_part)
        value = split_heads(value_part)

        if self.use_flash:
            attended = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
        else:
            weights = F.softmax((query @ key.transpose(-2, -1)) * self.scale, dim=-1)
            attended = weights @ value

        merged = attended.transpose(1, 2).contiguous().view(batch, positions, channels)
        merged = self.dropout(self.proj_out(merged))
        return merged.transpose(1, 2).view(batch, channels, height, width)
[docs]
def get_timestep_embedding(timesteps: torch.Tensor, dim: int, continuous: bool = True, scale: float = 1000.0) -> torch.Tensor:
    """Compute sinusoidal timestep embeddings for continuous or discrete timesteps.

    Args:
        timesteps: Tensor of timesteps, shape [B], [B, 1], or a 0-dim scalar.
        dim: Dimensionality of the embedding vector.
        continuous: If True, timesteps are multiplied by `scale` before embedding
            (with the default scale this emulates discrete DDPM timesteps).
        scale: Multiplier applied to continuous timesteps (default 1000.0).

    Returns:
        Tensor of shape [B, dim]. The first `dim // 2` entries are sines and the next
        `dim // 2` are cosines; when `dim` is odd a trailing zero column pads the
        result to exactly `dim` so it matches layers sized with `dim`.
    """
    if timesteps.dim() == 0:
        timesteps = timesteps.unsqueeze(0)
    elif timesteps.dim() == 2:
        timesteps = timesteps.squeeze(-1)
    if continuous:
        timesteps = timesteps * scale
    half_dim = dim // 2
    # Frequencies are log-spaced from 1 down to 1/10000. The max() guards the
    # single-frequency case (dim of 2 or 3), which would otherwise divide by zero.
    log_step = math.log(10000.0) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=timesteps.device, dtype=torch.float32) * -log_step)
    angles = timesteps[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if dim % 2 == 1:
        # sin/cos pairs only fill an even width; zero-pad the remaining column.
        emb = F.pad(emb, (0, 1))
    return emb
###==================================================================================================================###
[docs]
class LossAdapter:
    """Wrap a loss callable so it tolerates extra arguments it does not accept.

    `torch.nn.Module` losses are always called with `(predictions, targets)` only.
    Any other callable is first tried with every argument; if that raises a
    `TypeError` whose message mentions "positional argument" or "takes", the loss
    is called again with just `(predictions, targets)`. Note that such a
    `TypeError` raised from inside the loss itself also triggers the second call.
    """

    def __init__(self, loss_fn):
        # The wrapped loss: either an nn.Module or a plain callable.
        self.loss_fn = loss_fn

    def __call__(self, predictions, targets, *args, **kwargs):
        loss_fn = self.loss_fn
        if isinstance(loss_fn, torch.nn.Module):
            return loss_fn(predictions, targets)
        try:
            return loss_fn(predictions, targets, *args, **kwargs)
        except TypeError as error:
            message = str(error)
            looks_like_arity_error = "positional argument" in message or "takes" in message
            if not looks_like_arity_error:
                raise
            return loss_fn(predictions, targets)
[docs]
def mse_loss(pred: torch.Tensor, target: torch.Tensor, *args) -> torch.Tensor:
    """
    Plain mean squared error (MSE) between a prediction and its target.

    Squares the element-wise difference of `pred` and `target` and averages
    the result over every element.

    Args:
        pred: Predicted tensor, shape [B, ...].
        target: Target tensor, same shape as `pred`.
        *args: Ignored; accepted only so the signature matches the other losses.

    Returns:
        Scalar tensor holding the mean squared error.
    """
    squared_error = torch.square(pred - target)
    return squared_error.mean()
[docs]
def snr_capped_loss(pred_noise: torch.Tensor, target_noise: torch.Tensor, variance: torch.Tensor, gamma: float = 5.0, *args) -> torch.Tensor:
    """
    Noise-prediction MSE weighted by a capped signal-to-noise ratio (SNR).

    The per-sample weight is ``min((1 - variance) / max(variance, 1e-8), gamma)``.
    It multiplies the squared noise error before averaging over all elements.

    Args:
        pred_noise: Predicted noise tensor, same shape as target_noise.
        target_noise: True noise tensor.
        variance: Variance (sigma^2) for the timestep t, shape broadcastable to pred_noise.
        gamma: Upper bound applied to the SNR weight (default 5.0).
        *args: Ignored; accepted only so the signature matches the other losses.

    Returns:
        Scalar tensor holding the SNR-weighted mean squared error.
    """
    # Floor the denominator so a zero variance cannot divide by zero.
    safe_variance = variance.clamp(min=1e-8)
    capped_snr = ((1 - variance) / safe_variance).clamp(max=gamma)

    # Append trailing singleton axes until the weight broadcasts over the noise.
    for _ in range(target_noise.dim() - capped_snr.dim()):
        capped_snr = capped_snr.unsqueeze(-1)

    squared_error = (pred_noise - target_noise) ** 2
    return (squared_error * capped_snr).mean()
[docs]
def min_snr_loss(pred: torch.Tensor, target: torch.Tensor, snr: torch.Tensor, gamma: float = 5.0, *args) -> torch.Tensor:
    """
    Min-SNR weighting strategy for stable diffusion training.

    Weights the squared error by ``min(snr, gamma) / snr``, following
    'Efficient Diffusion Training via Min-SNR Weighting Strategy'.

    Args:
        pred: Predicted tensor, shape [B, ...].
        target: Target tensor, same shape as `pred`.
        snr: Signal-to-noise ratio per sample, shape broadcastable to `pred`
            after trailing singleton axes are appended.
        gamma: Cap applied to the SNR before normalising (default 5.0).
        *args: Placeholder for optional unused arguments for API compatibility.

    Returns:
        Scalar tensor representing the Min-SNR-weighted mean squared error.
    """
    # min(snr, gamma) / snr == min(1, gamma / snr). Writing it this way with a
    # floored denominator avoids the 0/0 = NaN the direct form produces at
    # snr == 0 and yields the limiting weight of 1 there instead.
    weight = torch.clamp(gamma / snr.clamp(min=1e-8), max=1.0)
    while weight.dim() < pred.dim():
        weight = weight.unsqueeze(-1)
    loss = (pred - target) ** 2
    return (loss * weight).mean()
[docs]
def ve_sigma_weighted_score_loss(pred_score: torch.Tensor, target_score: torch.Tensor, sigma: torch.Tensor, *args) -> torch.Tensor:
    """
    Sigma-weighted score matching loss for Variance Exploding SDEs.

    Evaluates the objective
    E[ || sigma(t) * s_theta(x_t, t) + epsilon ||^2 ]
    where epsilon is recovered from the target score as
    ``-target_score * sigma``.

    Args:
        pred_score: Model-predicted score tensor (∇_x log p(x_t)), shape [B, ...].
        target_score: Target score, typically -epsilon / sigma(t).
        sigma: Standard deviation (σ(t)) at the matching timesteps, shape broadcastable to pred_score.
        *args: Ignored; accepted only so the signature matches the other losses.

    Returns:
        Scalar tensor holding the sigma-weighted score matching loss.
    """
    # Append trailing singleton axes so sigma broadcasts against the scores.
    for _ in range(pred_score.dim() - sigma.dim()):
        sigma = sigma.unsqueeze(-1)

    recovered_noise = -target_score * sigma
    residual = sigma * pred_score + recovered_noise
    return (residual ** 2).mean()
###==================================================================================================================###
[docs]
class Metrics:
    """Computes image quality metrics for evaluating diffusion models.

    Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural
    Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual
    Image Patch Similarity (LPIPS) for comparing generated and ground truth images.

    Instances are callable: ``metrics(x, x_hat)`` is equivalent to
    ``metrics.forward(x, x_hat)``.

    Parameters
    ----------
    device : str, optional
        Device for computation (e.g., 'cuda', 'cpu') (default: 'cuda').
    fid : bool, optional
        If True, compute FID score (default: True).
    metrics : bool, optional
        If True, compute MSE, PSNR, and SSIM (default: False).
    lpips_ : bool, optional
        If True, compute LPIPS using VGG backbone (default: False).
    """

    def __init__(
        self,
        device: str = "cuda",
        fid: bool = True,
        metrics: bool = False,
        lpips_: bool = False
    ) -> None:
        self.device = device
        self.fid = fid
        self.metrics = metrics
        self.lpips = lpips_
        # The LPIPS network is only built when requested.
        self.lpips_model = (
            LearnedPerceptualImagePatchSimilarity(
                net_type='vgg',
                normalize=True  # This handles [0,1] -> [-1,1] conversion
            ).to(device)
            if self.lpips
            else None
        )
        # Base paths for the FID scratch directories; `compute_fid` creates a
        # uniquely-suffixed directory next to each of these on every call.
        self.temp_dir_real = "temp_real"
        self.temp_dir_fake = "temp_fake"

    @staticmethod
    def _make_scratch_dir(base: str) -> str:
        """Create and return a unique directory named ``<base>_<random>``."""
        import tempfile  # local import: keeps this block self-contained

        base = os.path.normpath(base)
        parent = os.path.dirname(base) or "."
        os.makedirs(parent, exist_ok=True)
        return tempfile.mkdtemp(prefix=os.path.basename(base) + "_", dir=parent)

    def compute_fid(self, real_images: torch.Tensor, fake_images: torch.Tensor) -> float:
        """Computes the Fréchet Inception Distance (FID) between real and generated images.

        Saves images to per-call scratch directories and uses Inception V3 to
        compute FID, removing the scratch directories afterward.

        Parameters
        ----------
        real_images : torch.Tensor
            Real images, shape (batch_size, channels, height, width), in [-1, 1].
        fake_images : torch.Tensor
            Generated images, same shape, in [-1, 1].

        Returns
        -------
        fid (float) - FID score, or `float('inf')` if computation fails.

        Raises
        ------
        ValueError
            If `real_images` and `fake_images` have different shapes.

        **Notes**
        - Images are normalized to [0, 1] and saved as PNG files for FID computation.
        - Uses Inception V3 with 2048-dimensional features (`dims=2048`).
        - Scratch directories are unique per call, so concurrent evaluations in
          the same working directory do not overwrite or delete each other's
          files, and a pre-existing directory at `temp_dir_real` /
          `temp_dir_fake` is never written to or removed.
        """
        if real_images.shape != fake_images.shape:
            raise ValueError(f"Shape mismatch: real_images {real_images.shape}, fake_images {fake_images.shape}")
        real_images = ((real_images + 1) / 2).clamp(0, 1).cpu()
        fake_images = ((fake_images + 1) / 2).clamp(0, 1).cpu()

        real_dir = None
        fake_dir = None
        try:
            real_dir = self._make_scratch_dir(self.temp_dir_real)
            fake_dir = self._make_scratch_dir(self.temp_dir_fake)
            for i, (real, fake) in enumerate(zip(real_images, fake_images)):
                save_image(real, os.path.join(real_dir, f"{i}.png"))
                save_image(fake, os.path.join(fake_dir, f"{i}.png"))
            fid = fid_score.calculate_fid_given_paths(
                paths=[real_dir, fake_dir],
                batch_size=50,
                device=self.device,
                dims=2048
            )
        except Exception as e:
            # Deliberate best-effort: evaluation must not abort a training run.
            print(f"Error computing FID: {e}")
            fid = float('inf')
        finally:
            # Only remove directories this call actually created.
            for scratch in (real_dir, fake_dir):
                if scratch is not None:
                    shutil.rmtree(scratch, ignore_errors=True)
        return fid

    def compute_metrics(
        self,
        x: torch.Tensor,
        x_hat: torch.Tensor,
        data_range: float = 2.0
    ) -> Tuple[float, float, float]:
        """Computes MSE, PSNR, and SSIM for evaluating image quality.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width).
        x_hat : torch.Tensor
            Generated images, same shape as `x`.
        data_range : float, optional
            Difference between the largest and smallest possible pixel value
            (default: 2.0, i.e. images in [-1, 1]). Used for both the PSNR peak
            and the SSIM stabilising constants so the two stay consistent.

        Returns
        -------
        mse : float
            Mean squared error.
        psnr : float
            Peak signal-to-noise ratio, ``10 * log10(data_range**2 / mse)``
            (``inf`` when the images are identical).
        ssim : float
            Structural similarity index (mean over batch), computed with a
            3x3 average-pooling window.

        Raises
        ------
        ValueError
            If the shapes differ or `data_range` is not positive.
        """
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")
        if data_range <= 0:
            raise ValueError(f"data_range must be positive, got {data_range}")

        mse = F.mse_loss(x_hat, x)
        # PSNR peak is the full dynamic range, matching the SSIM constants below.
        psnr = 10 * torch.log10(data_range ** 2 / mse)

        c1, c2 = (0.01 * data_range) ** 2, (0.03 * data_range) ** 2
        eps = 1e-8
        mu_x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        mu_y = F.avg_pool2d(x_hat, kernel_size=3, stride=1, padding=1)
        mu_xy = mu_x * mu_y
        sigma_x_sq = F.avg_pool2d(x.pow(2), kernel_size=3, stride=1, padding=1) - mu_x.pow(2)
        sigma_y_sq = F.avg_pool2d(x_hat.pow(2), kernel_size=3, stride=1, padding=1) - mu_y.pow(2)
        sigma_xy = F.avg_pool2d(x * x_hat, kernel_size=3, stride=1, padding=1) - mu_xy
        ssim = ((2 * mu_xy + c1) * (2 * sigma_xy + c2)) / (
            (mu_x.pow(2) + mu_y.pow(2) + c1) * (sigma_x_sq + sigma_y_sq + c2) + eps
        )
        return mse.item(), psnr.item(), ssim.mean().item()

    def compute_lpips(self, x: torch.Tensor, x_hat: torch.Tensor) -> float:
        """Computes LPIPS using a pre-trained VGG network.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        lpips (float) - Mean LPIPS score over the batch.

        Raises
        ------
        RuntimeError
            If the instance was created without `lpips_=True`.
        ValueError
            If `x` and `x_hat` have different shapes.
        """
        if self.lpips_model is None:
            raise RuntimeError("LPIPS model not initialized; set lpips_=True in __init__")
        if x.shape != x_hat.shape:
            raise ValueError(f"Shape mismatch: x {x.shape}, x_hat {x_hat.shape}")
        # Map [-1, 1] -> [0, 1] and clamp, since the model is built with normalize=True.
        x = ((x + 1) / 2).clamp(0, 1).to(self.device)
        x_hat = ((x_hat + 1) / 2).clamp(0, 1).to(self.device)
        # Convert grayscale to RGB if needed by repeating the single channel.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        if x_hat.shape[1] == 1:
            x_hat = x_hat.repeat(1, 3, 1, 1)
        # Evaluation only: no autograd graph is needed for the score.
        with torch.no_grad():
            return self.lpips_model(x, x_hat).mean().item()

    def forward(
        self,
        x: torch.Tensor,
        x_hat: torch.Tensor
    ) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Computes specified metrics for ground truth and generated images.

        Parameters
        ----------
        x : torch.Tensor
            Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].
        x_hat : torch.Tensor
            Generated images, same shape as `x`.

        Returns
        -------
        fid : float, or `float('inf')` if not computed
            Mean FID score.
        mse : float, or None if not computed
            Mean MSE
        psnr : float, or None if not computed
            Mean PSNR
        ssim : float, or None if not computed
            Mean SSIM
        lpips_score : float, or None if not computed
            Mean LPIPS score
        """
        fid = float('inf')
        mse, psnr, ssim = None, None, None
        lpips_score = None
        if self.metrics:
            mse, psnr, ssim = self.compute_metrics(x, x_hat)
        if self.fid:
            fid = self.compute_fid(x, x_hat)
        if self.lpips:
            lpips_score = self.compute_lpips(x, x_hat)
        return fid, mse, psnr, ssim, lpips_score

    def __call__(
        self,
        x: torch.Tensor,
        x_hat: torch.Tensor
    ) -> Tuple[float, Optional[float], Optional[float], Optional[float], Optional[float]]:
        """Alias for :meth:`forward` so an instance can be called directly."""
        return self.forward(x, x_hat)