Utilities

Utilities for text encoding, score prediction, and evaluation in diffusion models

This module provides core components for building diffusion model pipelines, including text encoding (used as a conditional model), U-Net-based score prediction (ScoreNet), custom loss functions for training, and image quality evaluation. These utilities support various diffusion model architectures, such as DDPM, DDIM, LDM, and SDE, and are designed for standalone use in model training and sampling.

Primary Components

  • TextEncoder: Encodes text prompts into embeddings using a pre-trained BERT model or a custom transformer.

  • ScoreNet (documented below as DiffusionNetwork): Memory-efficient U-Net-like architecture for predicting noise or scores in diffusion models, supporting time and text conditioning.

  • Loss Functions:
    • mse_loss: Standard mean squared error loss.

    • snr_capped_loss: SNR-weighted noise prediction loss with capped weighting, useful for VP/VE training.

    • ve_sigma_weighted_score_loss: Sigma-weighted score matching loss for VE-SDEs.

  • Metrics: Computes image quality metrics (MSE, PSNR, SSIM, FID, LPIPS) for evaluating generated images.

Notes

  • The primary components are intended to be imported directly for use in diffusion model workflows.

  • Additional supporting classes and functions in this module provide internal functionality for the primary components.


class torchdiff.utils.TextEncoder(*args: Any, **kwargs: Any)

Bases: Module

Transformer-based encoder for text prompts in conditional diffusion models.

Encodes text prompts into embeddings using either a pre-trained BERT model or a custom transformer architecture. Used as the conditional_model in diffusion models (e.g., DDPM, DDIM, SDE, LDM) to provide conditional inputs for noise prediction.

Parameters:
  • use_pretrained_model (bool, optional) – If True, uses a pre-trained BERT model; otherwise, builds a custom transformer (default: True).

  • model_name (str, optional) – Name of the pre-trained model to load (default: “bert-base-uncased”).

  • vocabulary_size (int, optional) – Size of the vocabulary for the custom transformer’s embedding layer (default: 30522).

  • num_layers (int, optional) – Number of transformer encoder layers for the custom transformer (default: 6).

  • input_dimension (int, optional) – Input embedding dimension for the custom transformer (default: 768).

  • output_dimension (int, optional) – Output embedding dimension for both pre-trained and custom models (default: 768).

  • num_heads (int, optional) – Number of attention heads in the custom transformer (default: 8).

  • context_length (int, optional) – Maximum sequence length for text prompts (default: 77).

  • dropout_rate (float, optional) – Dropout rate for attention and feedforward layers (default: 0.1).

  • qkv_bias (bool, optional) – If True, includes bias in query, key, and value projections for the custom transformer (default: False).

  • scaling_value (int, optional) – Scaling factor for the feedforward layer’s hidden dimension in the custom transformer (default: 4).

  • epsilon (float, optional) – Epsilon for layer normalization in the custom transformer (default: 1e-5).

  • use_learned_pos (bool, optional) – If True, the custom transformer uses learnable positional embeddings instead of sinusoidal encodings (default: False).

Notes

  • When use_pretrained_model is True, the BERT model’s parameters are frozen (requires_grad = False), and a projection layer maps outputs to output_dimension.

  • The custom transformer uses EncoderLayer modules with multi-head attention and feedforward networks, supporting variable input/output dimensions.

  • The output shape is (batch_size, context_length, output_dimension).

forward(x: torch.Tensor, attention_mask: torch.Tensor | None = None) → torch.Tensor

Encodes text prompts into embeddings.

Processes input token IDs and an optional attention mask to produce embeddings using either a pre-trained BERT model or a custom transformer.

Parameters:
  • x (torch.Tensor) – Token IDs, shape (batch_size, seq_len).

  • attention_mask (torch.Tensor, optional) – Attention mask, shape (batch_size, seq_len), where 0 indicates padding tokens to ignore (default: None).

Returns:

  • x (torch.Tensor) – Encoded embeddings, shape (batch_size, seq_len, output_dimension).

Notes

  • For pre-trained BERT, the last_hidden_state is projected to output_dimension, and this projection is the only trainable layer in the model.

  • For the custom transformer, token embeddings are processed through Embedding and EncoderLayer modules.

  • The attention mask should be 0 for padding tokens and 1 for valid tokens when using the custom transformer, or follow BERT’s convention for pre-trained models.
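As an illustration of the mask convention described above, the following sketch pads token ID lists to a fixed length and builds a matching 0/1 attention mask. The helper name pad_and_mask, the pad ID of 0, and the sample token IDs are assumptions made for this example and are not part of torchdiff; in practice a tokenizer would normally produce both lists.

```python
from typing import List, Tuple


def pad_and_mask(
    batch: List[List[int]], context_length: int = 77, pad_id: int = 0
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad or truncate each sequence to context_length and build a 0/1 mask."""
    token_ids, attention_mask = [], []
    for seq in batch:
        seq = seq[:context_length]                # truncate overly long prompts
        n_pad = context_length - len(seq)
        token_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)  # 1 = valid, 0 = padding
    return token_ids, attention_mask


ids, mask = pad_and_mask([[101, 2023, 102], [101, 102]], context_length=5)
```

Both lists can then be converted to integer tensors of shape (batch_size, seq_len) before being passed to forward.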

class torchdiff.utils.EncoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single transformer encoder layer with multi-head attention and feedforward network.

Used in the custom transformer of TextEncoder to process embedded text prompts.

Parameters:
  • input_dimension (int) – Input embedding dimension.

  • output_dimension (int) – Output embedding dimension.

  • num_heads (int) – Number of attention heads.

  • dropout_rate (float) – Dropout rate for attention and feedforward layers.

  • qkv_bias (bool) – If True, includes bias in query, key, and value projections.

  • scaling_value (int) – Scaling factor for the feedforward layer’s hidden dimension.

  • epsilon (float, optional) – Epsilon for layer normalization (default: 1e-5).

Notes

  • The layer follows the standard transformer encoder architecture: attention, residual connection, normalization, feedforward, residual connection, normalization.

  • The attention mechanism uses batch_first=True for compatibility with TextEncoder’s input format.

forward(x: torch.Tensor, attention_mask: torch.Tensor | None = None) → torch.Tensor

Processes input embeddings through attention and feedforward layers.

Parameters:
  • x (torch.Tensor) – Input embeddings, shape (batch_size, seq_len, input_dimension).

  • attention_mask (torch.Tensor, optional) – Attention mask, shape (batch_size, seq_len), where 0 indicates padding tokens to ignore (default: None).

Returns:

  • x (torch.Tensor) – Processed embeddings, shape (batch_size, seq_len, output_dimension).

Notes

  • The attention mask is passed as key_padding_mask to nn.MultiheadAttention, where 0 indicates padding tokens.

  • Residual connections and normalization are applied after attention and feedforward layers.
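PyTorch’s nn.MultiheadAttention treats True entries of a boolean key_padding_mask as positions to ignore, so a mask that uses 1 for valid tokens and 0 for padding has to be inverted somewhere before the call. The sketch below shows one way to do that conversion. It is an assumption about how such a mask could be adapted, not a copy of EncoderLayer’s internals, and the tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)

x = torch.randn(2, 5, 16)                         # (batch_size, seq_len, embed_dim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # 1 = valid token, 0 = padding

# nn.MultiheadAttention ignores keys where key_padding_mask is True,
# so padding positions (mask == 0) must map to True.
key_padding_mask = attention_mask == 0

out, weights = attn(x, x, x, key_padding_mask=key_padding_mask)

# Total attention weight that the first sample places on its two padded keys.
padded_weight = weights[0, :, 3:].abs().sum().item()
```

The output keeps the input shape, and the padded keys of the first sample receive no attention weight.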

class torchdiff.utils.FeedForward(*args: Any, **kwargs: Any)

Bases: Module

Feedforward network for transformer encoder layers.

Used in EncoderLayer to process attention outputs with a two-layer MLP and GELU activation.

Parameters:
  • embedding_dimension (int) – Input and output embedding dimension.

  • scaling_value (int) – Scaling factor for the hidden layer’s dimension (hidden_dim = embedding_dimension * scaling_value).

  • dropout_rate (float, optional) – Dropout rate after the hidden layer (default: 0.1).

Notes

  • The hidden layer dimension is embedding_dimension * scaling_value, following standard transformer feedforward designs.

  • GELU activation is used for non-linearity.
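A minimal sketch of a feedforward block matching this description is given below. The class name FeedForwardSketch and the exact layer ordering (linear, GELU, dropout, linear) are assumptions for illustration; the torchdiff implementation may differ in detail.

```python
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Two-layer MLP with GELU; hidden width = embedding_dimension * scaling_value."""

    def __init__(self, embedding_dimension: int, scaling_value: int,
                 dropout_rate: float = 0.1):
        super().__init__()
        hidden_dim = embedding_dimension * scaling_value
        self.net = nn.Sequential(
            nn.Linear(embedding_dimension, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),             # dropout after the hidden layer
            nn.Linear(hidden_dim, embedding_dimension),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


ff = FeedForwardSketch(embedding_dimension=768, scaling_value=4)
y = ff(torch.randn(2, 77, 768))                   # shape is preserved
```

With embedding_dimension=768 and scaling_value=4, the hidden layer has 3072 units, while the input and output both keep the last dimension at 768.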

forward(x: torch.Tensor) → torch.Tensor

Processes input embeddings through the feedforward network.

Parameters:
  • x (torch.Tensor) – Input embeddings, shape (batch_size, seq_len, embedding_dimension).

Returns:

  • x (torch.Tensor) – Processed embeddings, shape (batch_size, seq_len, embedding_dimension).

class torchdiff.utils.Attention(*args: Any, **kwargs: Any)

Bases: Module

Attention module for NoisePredictor, supporting text conditioning or self-attention.

Applies multi-head attention to enhance features, with optional text embeddings for conditional generation.

Parameters:
  • in_channels (int) – Number of input channels (embedding dimension for attention).

  • y_embed_dim (int, optional) – Dimensionality of text embeddings (default: 768).

  • num_heads (int, optional) – Number of attention heads (default: 4).

  • num_groups (int, optional) – Number of groups for group normalization (default: 8).

  • dropout_rate (float, optional) – Dropout rate for attention and output (default: 0.1).

Attributes:
  • in_channels (int) – Input channel dimension.

  • y_embed_dim (int) – Text embedding dimension.

  • num_heads (int) – Number of attention heads.

  • dropout_rate (float) – Dropout rate.

  • attention (torch.nn.MultiheadAttention) – Multi-head attention with batch_first=True.

  • norm (torch.nn.GroupNorm) – Group normalization before attention.

  • dropout (torch.nn.Dropout) – Dropout layer for output.

  • y_projection (torch.nn.Linear) – Projection for text embeddings to match in_channels.

Raises:
  • AssertionError – If input channels do not match in_channels.

  • ValueError – If text embeddings (y) have incorrect dimensions after projection.

forward(x: torch.Tensor, y: torch.Tensor | None = None)

Applies attention to input features with optional text conditioning.

Parameters:
  • x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).

  • y (torch.Tensor, optional) – Text embeddings, shape (batch_size, seq_len, y_embed_dim) or (batch_size, y_embed_dim) (default: None).

Returns:

Output tensor, same shape as input x.

Return type:

torch.Tensor

class torchdiff.utils.Embedding(*args: Any, **kwargs: Any)

Bases: Module

Token and positional embedding layer for transformer inputs.

Used in TextEncoder’s transformer to embed token IDs and add positional encodings.

Parameters:
  • vocabulary_size (int) – Size of the vocabulary for token embeddings.

  • embedding_dimension (int, optional) – Dimension of token and positional embeddings (default: 768).

  • max_context_length (int, optional) – Maximum sequence length for precomputing positional encodings (default: 77).

  • use_learned_pos (bool, optional) – If True, uses learnable positional embeddings instead of sinusoidal encodings (default: False).

Notes

  • Supports both sinusoidal (fixed) and learned positional embeddings, selectable via use_learned_pos.

  • Sinusoidal encodings follow the transformer architecture, computed on-the-fly for memory efficiency and cached for sequences up to max_context_length.

  • Learned positional embeddings are initialized as a learnable parameter for flexibility.

  • Optimized for device-agnostic operation, ensuring seamless CPU/GPU transitions.

  • The output shape is (batch_size, seq_len, embedding_dimension).

forward(token_ids: torch.Tensor) → torch.Tensor

Embeds token IDs and adds positional encodings.

Parameters:
  • token_ids (torch.Tensor) – Token IDs, shape (batch_size, seq_len).

Returns:

  • torch.Tensor – Embedded tokens with positional encodings, shape (batch_size, seq_len, embedding_dimension).

Notes

  • With sinusoidal encodings, sequences longer than max_context_length are handled automatically by generating positional encodings on-the-fly.

  • With learned positional embeddings, sequences longer than max_context_length raise an error unless truncated.

  • Encodings are generated on the input’s device to ensure device compatibility.

class torchdiff.utils.DiffusionNetwork(*args: Any, **kwargs: Any)

Bases: Module

Memory-efficient U-Net architecture for diffusion models, supporting time and conditional embeddings.

Initialize the ScoreNet U-Net with configurable down, middle, and up blocks, time embeddings, and optional attention.

Parameters:
  • in_channels – Number of input channels.

  • down_channels – List of channels for downsampling stages.

  • mid_channels – List of channels for middle blocks.

  • up_channels – List of channels for upsampling stages.

  • down_sampling – Boolean flags indicating whether to downsample at each down block.

  • time_embed_dim – Dimensionality of the time embedding.

  • y_embed_dim – Dimensionality of the conditional embedding.

  • num_down_blocks – Number of residual layers per down block.

  • num_mid_blocks – Number of residual layers per middle block.

  • num_up_blocks – Number of residual layers per up block.

  • dropout_rate – Dropout probability.

  • down_sampling_factor – Stride factor for downsampling/upsampling.

  • y_to_all – If True, applies conditional embeddings to all attention layers.

  • cont_time – Whether to use continuous time embeddings.

  • use_flash_attention – Whether to use flash attention for cross-attention layers.

  • grad_check – Whether to use gradient checkpointing.

forward(x: torch.Tensor, t: torch.Tensor, y: torch.Tensor | None = None, clip_embeddings: torch.Tensor | None = None) → torch.Tensor

Forward pass through the U-Net with time and optional conditional embeddings.

Parameters:
  • x – Input tensor of shape [B, C, H, W].

  • t – Tensor of timesteps [B] or [B, 1].

  • y – Optional context embeddings [B, D] or [B, L, D].

  • clip_embeddings – Optional CLIP embeddings [B, D].

Returns:

Output tensor of shape [B, in_channels, H, W].

class torchdiff.utils.ResBlock(*args: Any, **kwargs: Any)

Bases: Module

Efficient residual block with optional cross-attention for U-Net.

Initialize a ResBlock with optional attention and multiple residual layers.

Parameters:
  • in_channels – Number of input channels.

  • out_channels – Number of output channels.

  • time_channels – Dimensionality of time embedding.

  • context_channels – Dimensionality of conditional embedding.

  • num_layers – Number of residual layers in the block.

  • dropout – Dropout probability.

  • use_attention – Whether to include a cross-attention layer.

  • use_flash – Whether to use flash attention if available.

forward(x: torch.Tensor, t_emb: torch.Tensor, context: torch.Tensor | None = None)

Forward pass through the residual block.

Parameters:
  • x – Input tensor of shape [B, C, H, W].

  • t_emb – Time embedding tensor of shape [B, time_channels].

  • context – Optional conditional embeddings for cross-attention.

Returns:

Output tensor after residual layers (and optional attention).

class torchdiff.utils.CrossAttention(*args: Any, **kwargs: Any)

Bases: Module

Cross-attention module with optional flash attention.

Initialize cross-attention with query, key, value projections and optional flash attention.

Parameters:
  • channels – Number of input channels.

  • context_dim – Dimensionality of the context embeddings.

  • num_heads – Number of attention heads.

  • dropout – Dropout probability for attention output.

  • use_flash – Whether to use flash attention if available.

forward(x: torch.Tensor, context: torch.Tensor) → torch.Tensor

Compute cross-attention output.

Parameters:
  • x – Input feature map tensor [B, C, H, W].

  • context – Context embeddings [B, D] or [B, L, D].

Returns:

Tensor of shape [B, C, H, W] after applying attention.
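The shape handling described here can be sketched with PyTorch’s built-in nn.MultiheadAttention. The class name CrossAttentionSketch, the use of kdim/vdim instead of separate projection layers, and the residual connection are assumptions for illustration; flash attention and dropout are left out.

```python
import torch
import torch.nn as nn


class CrossAttentionSketch(nn.Module):
    """Image features (queries) attend to context embeddings (keys/values)."""

    def __init__(self, channels: int, context_dim: int, num_heads: int = 4):
        super().__init__()
        # kdim/vdim let keys and values live in a different space than queries.
        self.attn = nn.MultiheadAttention(
            channels, num_heads, kdim=context_dim, vdim=context_dim, batch_first=True
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        if context.dim() == 2:                    # [B, D] -> [B, 1, D]
            context = context.unsqueeze(1)
        q = x.flatten(2).transpose(1, 2)          # [B, C, H, W] -> [B, H*W, C]
        out, _ = self.attn(q, context, context)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return x + out                            # residual connection (an assumption)


layer = CrossAttentionSketch(channels=32, context_dim=768)
x = torch.randn(2, 32, 8, 8)
out_seq = layer(x, torch.randn(2, 77, 768))       # context as [B, L, D]
out_vec = layer(x, torch.randn(2, 768))           # context as [B, D]
```

Both context layouts produce an output with the same [B, C, H, W] shape as the input feature map.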

torchdiff.utils.get_timestep_embedding(timesteps: torch.Tensor, dim: int, continuous: bool = True, scale=1000.0) → torch.Tensor

Compute sinusoidal timestep embeddings for continuous or discrete timesteps.

Parameters:
  • timesteps – Tensor of timesteps [B] or scalar.

  • dim – Dimensionality of the embedding vector.

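  • continuous – If True, multiplies timesteps by scale to emulate discrete DDPM timesteps (default: True).

  • scale – Multiplier applied to timesteps when continuous is True (default: 1000.0).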

Returns:

Tensor of shape [B, dim] containing sinusoidal embeddings.
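For intuition, here is a dependency-free sketch of a sinusoidal embedding for a single timestep. It uses the frequency ladder common in DDPM-style code (base 10000, sine half followed by cosine half); whether get_timestep_embedding uses exactly this layout is an assumption, and the function name sinusoidal_embedding is hypothetical.

```python
import math
from typing import List


def sinusoidal_embedding(t: float, dim: int, continuous: bool = True,
                         scale: float = 1000.0) -> List[float]:
    """Return a length-dim sinusoidal embedding for one timestep (dim must be even)."""
    if continuous:
        t = t * scale                             # map t in [0, 1] onto a DDPM-like range
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1/10000.
    freqs = [math.exp(-math.log(10000.0) * i / half) for i in range(half)]
    return [math.sin(t * f) for f in freqs] + [math.cos(t * f) for f in freqs]


emb = sinusoidal_embedding(0.0, dim=8)
```

Under these assumptions, a continuous timestep of 0.5 and a discrete timestep of 500 map to the same embedding, which is the point of the scaling.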

class torchdiff.utils.LossAdapter(loss_fn)

Bases: object

Adapter to make any loss function compatible with extra arguments.
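The source does not describe how LossAdapter works internally. The sketch below shows one hypothetical way such an adapter could behave: it inspects the wrapped function and drops any positional arguments the function cannot accept. The names LossAdapterSketch and plain_mse are made up for this example.

```python
import inspect


class LossAdapterSketch:
    """Hypothetical adapter: forwards only as many positional args as loss_fn accepts."""

    def __init__(self, loss_fn):
        self.loss_fn = loss_fn
        params = list(inspect.signature(loss_fn).parameters.values())
        # If loss_fn already takes *args, everything can be forwarded unchanged.
        self.accepts_varargs = any(p.kind is p.VAR_POSITIONAL for p in params)
        self.n_positional = sum(
            p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) for p in params
        )

    def __call__(self, pred, target, *extra):
        args = (pred, target, *extra)
        if not self.accepts_varargs:
            args = args[: self.n_positional]      # drop arguments loss_fn cannot take
        return self.loss_fn(*args)


def plain_mse(pred, target):
    """Mean squared error over two equal-length lists of floats."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


loss = LossAdapterSketch(plain_mse)
value = loss([1.0, 2.0], [1.0, 4.0], "unused-variance")  # extra argument is ignored
```

A training loop written this way can call every loss with the same argument list, regardless of which extra quantities a particular loss needs.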

torchdiff.utils.mse_loss(pred: torch.Tensor, target: torch.Tensor, *args) → torch.Tensor

Standard mean squared error (MSE) loss.

Computes the element-wise squared difference between pred and target and returns the mean across all elements.

Parameters:
  • pred – Predicted tensor, shape [B, …].

  • target – Target tensor, same shape as pred.

  • *args – Placeholder for optional unused arguments for API compatibility.

Returns:

Scalar tensor representing mean squared error.

torchdiff.utils.snr_capped_loss(pred_noise: torch.Tensor, target_noise: torch.Tensor, variance: torch.Tensor, gamma: float = 5.0, *args) → torch.Tensor

Signal-to-noise-ratio (SNR) capped noise prediction loss for diffusion models.

This implements a weighted MSE where the weight is the SNR of the timestep, capped at a maximum value gamma. Typically used in VP/VE noise prediction.

Parameters:
  • pred_noise – Predicted noise tensor, same shape as target_noise.

  • target_noise – True noise tensor.

  • variance – Variance (sigma^2) corresponding to the timestep t, shape broadcastable to pred_noise.

  • gamma – Maximum SNR weight (default 5.0).

  • *args – Placeholder for optional unused arguments for API compatibility.

Returns:

Scalar tensor representing the SNR-weighted mean squared error.

torchdiff.utils.min_snr_loss(pred: torch.Tensor, target: torch.Tensor, snr: torch.Tensor, gamma: float = 5.0, *args) → torch.Tensor

Min-SNR weighting strategy for stable diffusion training. Normalizes the loss by min(snr, gamma) / snr.
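To make the two SNR-based weightings concrete, the sketch below evaluates both on a few scalar SNR values: the capped weight min(snr, gamma) described for snr_capped_loss, and the normalized weight min(snr, gamma) / snr described for min_snr_loss. The function names are hypothetical, and how each loss derives the SNR from its inputs is not shown here.

```python
def capped_snr_weight(snr: float, gamma: float = 5.0) -> float:
    """Weight for an SNR-capped loss: the SNR itself, clipped at gamma."""
    return min(snr, gamma)


def min_snr_weight(snr: float, gamma: float = 5.0) -> float:
    """Min-SNR weight: min(snr, gamma) / snr, equal to 1 until snr exceeds gamma."""
    return min(snr, gamma) / snr


snrs = [0.5, 5.0, 10.0, 50.0]
capped = [capped_snr_weight(s) for s in snrs]
relative = [min_snr_weight(s) for s in snrs]
```

In both cases, low-noise timesteps (high SNR) stop dominating the loss once the SNR passes gamma.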

torchdiff.utils.ve_sigma_weighted_score_loss(pred_score: torch.Tensor, target_score: torch.Tensor, sigma: torch.Tensor, *args) → torch.Tensor

VE-SDE sigma-weighted score matching loss.

Implements the recommended loss for Variance Exploding SDEs:

E[ || sigma(t) * s_theta(x_t, t) + epsilon ||^2 ]

where epsilon is the true noise used to perturb x_0.

Parameters:
  • pred_score – Model-predicted score tensor (∇_x log p(x_t)), shape [B, …].

  • target_score – Target score, typically -epsilon / sigma(t).

  • sigma – Standard deviation (σ(t)) at the corresponding timesteps, shape broadcastable to pred_score.

  • *args – Placeholder for optional unused arguments for API compatibility.

Returns:

Scalar tensor representing the sigma-weighted score matching loss.
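When the target score is -epsilon / sigma(t), the expression sigma * s_theta + epsilon equals sigma * (s_theta - target_score), so the loss can be computed from pred_score, target_score, and sigma alone. The following numeric check illustrates that identity on made-up values; it is not the library’s implementation.

```python
import math

sigma = 2.0
epsilon = [0.5, -1.0, 0.25]                       # noise used to perturb x_0
pred_score = [-0.2, 0.6, -0.1]                    # hypothetical model output

target_score = [-e / sigma for e in epsilon]      # target score: -epsilon / sigma

# Form 1: || sigma * s_theta + epsilon ||^2
form1 = sum((sigma * s + e) ** 2 for s, e in zip(pred_score, epsilon))
# Form 2: sigma^2 * || s_theta - target_score ||^2
form2 = sigma ** 2 * sum((s - t) ** 2 for s, t in zip(pred_score, target_score))

same = math.isclose(form1, form2)
```

The sigma^2 factor in the second form is what cancels the 1 / sigma^2 growth of the raw score error at small noise levels.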

class torchdiff.utils.Metrics(device: str = 'cuda', fid: bool = True, metrics: bool = False, lpips_: bool = False)

Bases: object

Computes image quality metrics for evaluating diffusion models.

Supports Mean Squared Error (MSE), Peak Signal-to-Noise Ratio (PSNR), Structural Similarity Index (SSIM), Fréchet Inception Distance (FID), and Learned Perceptual Image Patch Similarity (LPIPS) for comparing generated and ground truth images.

Parameters:
  • device (str, optional) – Device for computation (e.g., ‘cuda’, ‘cpu’) (default: ‘cuda’).

  • fid (bool, optional) – If True, compute FID score (default: True).

  • metrics (bool, optional) – If True, compute MSE, PSNR, and SSIM (default: False).

  • lpips_ (bool, optional) – If True, compute LPIPS using VGG backbone (default: False).

compute_fid(real_images: torch.Tensor, fake_images: torch.Tensor) → float

Computes the Fréchet Inception Distance (FID) between real and generated images.

Saves images to temporary directories and uses Inception V3 to compute FID, cleaning up directories afterward.

Parameters:
  • real_images (torch.Tensor) – Real images, shape (batch_size, channels, height, width), in [-1, 1].

  • fake_images (torch.Tensor) – Generated images, same shape, in [-1, 1].

Returns:

  • fid (float) – FID score, or float(‘inf’) if computation fails.

Notes

  • Images are normalized to [0, 1] and saved as PNG files for FID computation.

  • Uses Inception V3 with 2048-dimensional features (dims=2048).
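FID is the Fréchet distance between two Gaussians fitted to feature vectors of real and generated images. As a simplified illustration only, the sketch below computes that distance for one-dimensional samples, where the covariance term reduces to a scalar expression. The function name is hypothetical, and actual FID works on 2048-dimensional Inception features with full covariance matrices.

```python
import math
from statistics import fmean, pvariance
from typing import Sequence


def frechet_distance_1d(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Fréchet distance between 1-D Gaussians fitted to a and b."""
    mu_a, mu_b = fmean(a), fmean(b)
    var_a, var_b = pvariance(a), pvariance(b)
    # General form: |mu_a - mu_b|^2 + Tr(S_a + S_b - 2 (S_a S_b)^(1/2)),
    # which reduces to the scalar expression below in one dimension.
    return (mu_a - mu_b) ** 2 + var_a + var_b - 2.0 * math.sqrt(var_a * var_b)


real = [0.0, 2.0, 4.0]
fake = [1.0, 3.0, 5.0]                            # same spread, mean shifted by 1
same = frechet_distance_1d(real, real)
shifted = frechet_distance_1d(real, fake)
```

Identical distributions give a distance of zero, and a pure mean shift contributes the squared shift.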

compute_metrics(x: torch.Tensor, x_hat: torch.Tensor) → Tuple[float, float, float]

Computes MSE, PSNR, and SSIM for evaluating image quality.

Parameters:
  • x (torch.Tensor) – Ground truth images, shape (batch_size, channels, height, width).

  • x_hat (torch.Tensor) – Generated images, same shape as x.

Returns:

  • mse (float) – Mean squared error.

  • psnr (float) – Peak signal-to-noise ratio.

  • ssim (float) – Structural similarity index (mean over batch).
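PSNR is derived from MSE together with the dynamic range of the images. The sketch below shows that relationship for scalar values. The helper name psnr_from_mse is hypothetical, and the data range that compute_metrics assumes internally is not stated in this reference, so both common choices are shown.

```python
import math


def psnr_from_mse(mse: float, data_range: float = 1.0) -> float:
    """PSNR in dB: 10 * log10(data_range^2 / mse); infinite for identical images."""
    if mse == 0.0:
        return float("inf")
    return 10.0 * math.log10(data_range ** 2 / mse)


psnr_unit = psnr_from_mse(0.01)                   # images scaled to [0, 1]
psnr_sym = psnr_from_mse(0.01, data_range=2.0)    # images scaled to [-1, 1]
```

For the same MSE, the reported PSNR depends on the assumed range, so values are only comparable when images share the same scaling.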

compute_lpips(x: torch.Tensor, x_hat: torch.Tensor) → float

Computes LPIPS using a pre-trained VGG network.

Parameters:
  • x (torch.Tensor) – Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].

  • x_hat (torch.Tensor) – Generated images, same shape as x.

Returns:

  • lpips (float) – Mean LPIPS score over the batch.

forward(x: torch.Tensor, x_hat: torch.Tensor) → Tuple[float, float, float, float, float]

Computes specified metrics for ground truth and generated images.

Parameters:
  • x (torch.Tensor) – Ground truth images, shape (batch_size, channels, height, width), in [-1, 1].

  • x_hat (torch.Tensor) – Generated images, same shape as x.

Returns:

  • fid (float, or float(‘inf’) if not computed) – Mean FID score.

  • mse (float, or None if not computed) – Mean MSE.

  • psnr (float, or None if not computed) – Mean PSNR.

  • ssim (float, or None if not computed) – Mean SSIM.

  • lpips_score (float, or None if not computed) – Mean LPIPS score.