SDE

Score-Based Generative Modeling with Stochastic Differential Equations (SDE)

This module implements a complete framework for score-based generative models using SDEs, as described in Song et al. (2021, “Score-Based Generative Modeling through Stochastic Differential Equations”). It provides components for forward and reverse diffusion processes, hyperparameter management, training, and image sampling. Four methods are supported: Variance Exploding (VE), Variance Preserving (VP), sub-Variance Preserving (sub-VP), and the probability flow ODE. Both unconditional generation and conditional generation with text prompts are supported.

Components

  • ForwardSDE: Forward diffusion process to add noise using SDE methods.

  • ReverseSDE: Reverse diffusion process to denoise using SDE methods.

  • SchedulerSDE: Noise schedule and SDE-specific parameter management.

  • TrainSDE: Training loop with mixed precision and scheduling.

  • SampleSDE: Image generation from trained SDE models.

References

  • Song, Yang, et al. “Score-based generative modeling through stochastic differential equations.” arXiv preprint arXiv:2011.13456 (2020).


class torchdiff.sde.ForwardSDE(*args: Any, **kwargs: Any)

Bases: Module

Forward diffusion process for continuous-time diffusion models.

This module implements the marginal forward noising process p(x_t | x_0) for several commonly used stochastic differential equation (SDE) formulations, including:

  • Variance Preserving (VP-SDE)

  • Variance Exploding (VE-SDE)

  • Sub-Variance Preserving (Sub-VP-SDE)

  • Probability Flow ODE (ODE)

Given clean data x₀, Gaussian noise ε ~ N(0, I), and continuous time t ∈ [0, 1], the forward process samples x_t and provides the true score ∇ₓ log p(x_t | x₀), which is commonly used for score matching objectives.

Supported forward marginals:

  1. VP-SDE:

    p(x_t | x_0) = N(α(t) x_0, σ²(t) I)

  2. VE-SDE:

    p(x_t | x_0) = N(x_0, σ²(t) I), where σ(t) = σ_min (σ_max / σ_min)^t

  3. Sub-VP-SDE:

    p(x_t | x_0) = N(α(t) x_0, σ²(t) I), where σ(t) = 1 - exp(-∫₀ᵗ β(s) ds), following Song et al. (2021)

  4. Probability Flow ODE:

    Shares the same marginals as VP-SDE but corresponds to deterministic dynamics during sampling.

The returned score is analytically computed as:

∇ₓ log p(x_t | x₀) = -(x_t - μ(t)) / σ²(t) = -ε / σ(t)

where μ(t) is the mean of the forward transition.
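The sampling rule and the score identity above can be sketched as follows. This is a minimal illustration that uses plain Python floats in place of tensors and the VE marginal as the example. The helper names `ve_params` and `perturb` are hypothetical and are not part of torchdiff.

```python
import math
import random

def ve_params(t, sigma_min=0.01, sigma_max=50.0):
    """Mean coefficient and std of the VE marginal N(x_0, sigma(t)^2 I)."""
    return 1.0, sigma_min * (sigma_max / sigma_min) ** t

def perturb(x0, noise, mean_coeff, std, eps=1e-8):
    """Sample x_t = mean_coeff * x0 + std * noise and return the true score."""
    xt = [mean_coeff * x + std * n for x, n in zip(x0, noise)]
    # Score of the Gaussian kernel: -(x_t - mu) / sigma^2 = -noise / sigma.
    # eps guards the division, mirroring the documented `eps` parameter.
    score = [-n / (std + eps) for n in noise]
    return xt, score

rng = random.Random(0)
x0 = [0.5, -1.0, 2.0]
noise = [rng.gauss(0.0, 1.0) for _ in x0]
mean_coeff, std = ve_params(t=0.5)
xt, score = perturb(x0, noise, mean_coeff, std)

# Recompute the score from x_t and the mean; it should match -noise / sigma.
alt = [-(x_t - mean_coeff * x) / std**2 for x_t, x in zip(xt, x0)]
```

Both expressions agree up to floating-point error and the small `eps` guard, which is why the noise itself can serve as a training target.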

Parameters:
  • scheduler (SchedulerSDE) – Scheduler providing β(t), α(t), and σ(t) for VP and Sub-VP processes.

  • method (str, default="vp") – Forward process type. Must be one of: {“vp”, “ve”, “sub-vp”, “ode”}.

  • pred_type (str) – Prediction parameterization. One of {“noise”, “x0”, “score”}.

  • sigma_min (float, default=0.01) – Minimum noise scale for the VE-SDE.

  • sigma_max (float, default=50.0) – Maximum noise scale for the VE-SDE.

  • eps (float, default=1e-8) – Small constant for numerical stability when computing the score.
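The three `pred_type` options are related through the forward relation x_t = m(t) x₀ + σ(t) ε. The sketch below shows these standard conversions with plain Python lists. The function names are hypothetical, and how torchdiff converts between parameterizations internally is an assumption not confirmed by this page.

```python
import math

def noise_to_score(noise, std):
    """score = -noise / sigma."""
    return [-n / std for n in noise]

def score_to_noise(score, std):
    """noise = -sigma * score."""
    return [-s * std for s in score]

def noise_to_x0(xt, noise, mean_coeff, std):
    """Invert x_t = mean_coeff * x0 + std * noise for x0."""
    return [(x - std * n) / mean_coeff for x, n in zip(xt, noise)]

mean_coeff, std = 0.8, 0.6           # example values with 0.8^2 + 0.6^2 = 1
x0 = [1.0, -2.0]
noise = [0.3, -0.7]
xt = [mean_coeff * x + std * n for x, n in zip(x0, noise)]

score = noise_to_score(noise, std)
recovered_noise = score_to_noise(score, std)
recovered_x0 = noise_to_x0(xt, noise, mean_coeff, std)
```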

Notes

  • Time t is assumed to be normalized to [0, 1].

  • All operations are vectorized and support arbitrary data dimensions.

  • Broadcasting is handled automatically to match the shape of x₀.

  • For the ODE method, noise is still used to compute the analytical score during training, even though sampling is deterministic.

References

  • Song et al., “Score-Based Generative Modeling through SDEs”, ICLR 2021

  • Ho et al., “Denoising Diffusion Probabilistic Models”, NeurIPS 2020

  • Kingma et al., “Variational Diffusion Models”, NeurIPS 2021

get_forward_params(t: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Get the mean coefficient and standard deviation of the forward process for the configured SDE method.

Returns:

  • mean_coeff – coefficient for clean data x_0

  • std – standard deviation of noise

forward(x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]

Sample from the transition kernel and compute the true score.

Parameters:
  • x0 – (batch, …, dims) clean data

  • t – (batch, ) continuous time in [0, 1]

  • noise – (batch, …, dims) standard Gaussian noise

Returns:

  • xt – (batch, …, dims) noised data

  • target – (batch, …, dims) true score/added noise

class torchdiff.sde.ReverseSDE(*args: Any, **kwargs: Any)

Bases: Module

Reverse-time diffusion process for continuous-time SDE diffusion models.

This module implements a single-step numerical solver for the reverse-time stochastic differential equation (SDE) or probability flow ordinary differential equation (ODE) corresponding to a trained score-based model.

Given a noisy sample x_t at time t and an estimate of the score ∇ₓ log p_t(x), the reverse process evolves the system backward in time (t → 0) using an Euler–Maruyama discretization.

Supported reverse dynamics:

  • Variance Preserving (VP-SDE)

  • Variance Exploding (VE-SDE)

  • Sub-Variance Preserving (Sub-VP-SDE)

  • Probability Flow ODE (ODE)

General reverse SDE form:

dx = [f(x, t) - g²(t) ∇ₓ log p_t(x)] dt + g(t) dW̄_t

where:
  • f(x, t) is the forward drift

  • g(t) is the diffusion coefficient

  • dW̄_t denotes reverse-time Brownian motion

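For the probability flow ODE, the diffusion term vanishes and the dynamics become deterministic while preserving the same marginals as the VP-SDE. In Song et al. (2021) the score term is also halved in this case: dx = [f(x, t) - ½ g²(t) ∇ₓ log p_t(x)] dt.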

Parameters:
  • scheduler (nn.Module) – Scheduler providing β(t) and related quantities for VP and Sub-VP dynamics. Typically an instance of SchedulerSDE.

  • method (str, default="vp") – Type of reverse-time dynamics. Must be one of: {“vp”, “ve”, “sub-vp”, “ode”}.

  • pred_type (str) – Prediction parameterization. One of {“noise”, “score”}.

  • sigma_min (float, default=0.01) – Minimum noise scale for the VE-SDE.

  • sigma_max (float, default=50.0) – Maximum noise scale for the VE-SDE.

Notes

  • Time t is assumed to be normalized to [0, 1].

  • Reverse integration proceeds with a negative time step dt < 0.

  • The score ∇ₓ log p_t(x) is typically predicted by a neural network.

  • For the final step or ODE-based sampling, stochastic noise can be disabled.

  • All tensor operations support broadcasting over arbitrary data shapes.

Numerical Integration

The update rule implemented is the Euler–Maruyama scheme:

x_{t+dt} = x_t + [f(x_t, t) - g²(t)·score(x_t, t)] dt + g(t) √|dt| ε

where ε ~ N(0, I). For ODE sampling, the stochastic term is omitted.
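As a rough illustration of this update for the VP case, the sketch below takes f(x, t) = -½ β(t) x and g(t) = √β(t) from the forward SDE documented for SchedulerSDE, with a linear β(t). It works on plain Python lists. `beta_linear` and `reverse_step_vp` are hypothetical names, and the sketch is not the library's implementation.

```python
import math
import random

def beta_linear(t, beta_min=0.1, beta_max=20.0):
    """Linear noise rate beta(t) = beta_min + t (beta_max - beta_min)."""
    return beta_min + t * (beta_max - beta_min)

def reverse_step_vp(xt, score, t, dt, last_step=False, rng=random):
    """One Euler-Maruyama step of the reverse VP-SDE (dt is negative)."""
    beta_t = beta_linear(t)
    g = math.sqrt(beta_t)
    x_prev = []
    for x, s in zip(xt, score):
        # Reverse drift: f(x, t) - g^2(t) * score, with f = -0.5 * beta * x.
        drift = -0.5 * beta_t * x - beta_t * s
        # Stochastic term g * sqrt(|dt|) * eps, skipped on the final step.
        noise = 0.0 if last_step else g * math.sqrt(abs(dt)) * rng.gauss(0.0, 1.0)
        x_prev.append(x + drift * dt + noise)
    return x_prev

# Deterministic check: with x = 1, score = -1, beta(1) = 20 and dt = -0.1,
# the drift is -10 + 20 = 10, so the step moves x from 1 to about 0.
x_prev = reverse_step_vp([1.0], [-1.0], t=1.0, dt=-0.1, last_step=True)
```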

References

  • Anderson, “Reverse-Time Diffusion Equation Models”, 1982

  • Song et al., “Score-Based Generative Modeling through SDEs”, ICLR 2021

  • Kingma et al., “Variational Diffusion Models”, NeurIPS 2021

get_reverse_coeffs(t: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Get drift and diffusion coefficients for the reverse SDE.

Returns:

  • drift_coeff – coefficient for drift term

  • g_squared – squared diffusion coefficient (for score term)

  • diffusion_coeff – coefficient for diffusion term
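One plausible reading of these three coefficients, based on the SDE forms in Song et al. (2021), is sketched below with scalar inputs. It assumes a linear β(t) and treats the drift as f(x, t) = drift_coeff · x. The function name `reverse_coeffs` is hypothetical, the "ode" method is left out, and whether torchdiff returns exactly these values is an assumption.

```python
import math

def reverse_coeffs(t, method="vp", beta_min=0.1, beta_max=20.0,
                   sigma_min=0.01, sigma_max=50.0):
    """Return (drift_coeff, g_squared, diffusion_coeff) at scalar time t."""
    beta_t = beta_min + t * (beta_max - beta_min)
    int_beta = beta_min * t + 0.5 * t * t * (beta_max - beta_min)
    if method == "vp":
        # dx = -0.5 beta x dt + sqrt(beta) dW
        drift_coeff, g_squared = -0.5 * beta_t, beta_t
    elif method == "sub-vp":
        # Same drift as VP, with a smaller diffusion coefficient.
        drift_coeff = -0.5 * beta_t
        g_squared = beta_t * (1.0 - math.exp(-2.0 * int_beta))
    elif method == "ve":
        # dx = sqrt(d sigma^2 / dt) dW with sigma(t) = sigma_min (sigma_max / sigma_min)^t
        sigma_t = sigma_min * (sigma_max / sigma_min) ** t
        drift_coeff = 0.0
        g_squared = 2.0 * sigma_t ** 2 * math.log(sigma_max / sigma_min)
    else:
        raise ValueError(f"unsupported method in this sketch: {method!r}")
    return drift_coeff, g_squared, math.sqrt(g_squared)

vp = reverse_coeffs(0.5, "vp")
sub_vp = reverse_coeffs(0.5, "sub-vp")
ve = reverse_coeffs(0.5, "ve")
```

For the VE case, g²(t) is the time derivative of σ²(t), which can be checked numerically with a finite difference.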

forward(xt: torch.Tensor, pred: torch.Tensor, t: torch.Tensor, dt: float, last_step: bool = False) → torch.Tensor

Single reverse Euler–Maruyama step.

Parameters:
  • xt – (batch, …, dims) current state

  • pred – (batch, …, dims) output (prediction of diffusion model)

  • t – (batch,) current time

  • dt – scalar time step (negative for reverse)

  • last_step – if True, skip noise for deterministic final step

Returns:

  • x_prev – (batch, …, dims) previous state

class torchdiff.sde.SchedulerSDE(*args: Any, **kwargs: Any)

Bases: Module

Continuous-time variance (noise) scheduler for diffusion models formulated as stochastic differential equations (SDEs).

This class defines the time-dependent noise rate β(t) and its derived quantities used in forward diffusion processes of the form:

dx = -½ β(t) x dt + √β(t) dW_t

where t ∈ [0, 1] is continuous time and W_t is standard Brownian motion.

Supported schedules:
  • Linear schedule:

    β(t) = β_min + t (β_max - β_min)

  • Cosine schedule:

    Defined implicitly via the cumulative signal power ᾱ(t) = cos²((t + s) / (1 + s) · π / 2), following Nichol & Dhariwal (2021).

The scheduler provides convenient access to commonly used quantities:
  • β(t) — instantaneous noise rate

  • ∫₀ᵗ β(s) ds — cumulative noise

  • α(t) — signal scaling factor

  • σ²(t) — noise variance

  • SNR(t) — signal-to-noise ratio

All methods operate on PyTorch tensors and support broadcasting.

Parameters:
  • schedule_type (str, default="linear") – Type of noise schedule. Must be one of {“linear”, “cosine”}.

  • beta_min (float, default=0.1) – Minimum noise rate for the linear schedule. Must satisfy 0 < beta_min < beta_max. Ignored for cosine schedule.

  • beta_max (float, default=20.0) – Maximum noise rate for the linear schedule. Ignored for cosine schedule.

  • cosine_s (float, default=0.008) – Small offset used in the cosine schedule to prevent singularities near t = 0. Matches the formulation from improved DDPMs.

Notes

  • Time t is assumed to be normalized to [0, 1].

  • α(t) and σ(t) satisfy:

    α²(t) + σ²(t) = 1

    for both schedules.

  • The cosine schedule defines β(t) implicitly through α²(t); the β(t) returned in this case is an approximation derived from finite differences.
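Because α²(t) = exp(-∫₀ᵗ β(s) ds), the noise rate can be recovered as β(t) = -d/dt log α²(t). The sketch below approximates that derivative with a central finite difference for the cosine schedule. The function names and the step size `h` are assumptions made for illustration, and the library's own finite-difference scheme may differ.

```python
import math

def alpha_squared_cosine(t, s=0.008):
    """Cumulative signal power cos^2((t + s) / (1 + s) * pi / 2)."""
    return math.cos((t + s) / (1.0 + s) * math.pi / 2.0) ** 2

def beta_cosine_fd(t, s=0.008, h=1e-4):
    """Approximate beta(t) = -d/dt log alpha^2(t) by a finite difference."""
    lo, hi = max(t - h, 0.0), min(t + h, 1.0)   # stay inside [0, 1]
    log_hi = math.log(alpha_squared_cosine(hi, s))
    log_lo = math.log(alpha_squared_cosine(lo, s))
    return -(log_hi - log_lo) / (hi - lo)

beta_mid = beta_cosine_fd(0.5)

# Closed form for comparison: beta(t) = pi * tan(u) / (1 + s),
# where u = (t + s) / (1 + s) * pi / 2.
u = (0.5 + 0.008) / 1.008 * math.pi / 2.0
beta_exact = math.pi * math.tan(u) / 1.008
```

The approximation degrades close to t = 1, where α²(t) approaches zero and β(t) grows without bound.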

References

  • Ho et al., “Denoising Diffusion Probabilistic Models”, NeurIPS 2020

  • Song et al., “Score-Based Generative Modeling through SDEs”, ICLR 2021

  • Nichol & Dhariwal, “Improved Denoising Diffusion Probabilistic Models”, ICML 2021

beta(t: torch.Tensor) → torch.Tensor

β(t) - noise schedule

integral_beta(t: torch.Tensor) → torch.Tensor

∫₀ᵗ β(s) ds

alpha(t: torch.Tensor) → torch.Tensor

α(t) = exp(-½∫₀ᵗ β(s) ds)

alpha_squared(t: torch.Tensor) → torch.Tensor

α²(t) = exp(-∫₀ᵗ β(s) ds)

variance(t: torch.Tensor) → torch.Tensor

σ²(t) = 1 - α²(t)

std(t: torch.Tensor) → torch.Tensor

σ(t) = √(1 - α²(t))

snr(t: torch.Tensor) → torch.Tensor

signal-to-noise ratio: SNR(t) = α²(t) / σ²(t)
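The formulas listed for these methods fit together as in the sketch below, written for the linear schedule with scalar inputs. `LinearSchedule` is a hypothetical stand-in for illustration and is not the SchedulerSDE class itself.

```python
import math

class LinearSchedule:
    """Scalar sketch of the linear schedule beta(t) = beta_min + t (beta_max - beta_min)."""

    def __init__(self, beta_min=0.1, beta_max=20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t):
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def integral_beta(self, t):
        # Closed-form integral of the linear beta from 0 to t.
        return self.beta_min * t + 0.5 * t * t * (self.beta_max - self.beta_min)

    def alpha(self, t):
        return math.exp(-0.5 * self.integral_beta(t))

    def alpha_squared(self, t):
        return math.exp(-self.integral_beta(t))

    def variance(self, t):
        return 1.0 - self.alpha_squared(t)

    def std(self, t):
        return math.sqrt(self.variance(t))

    def snr(self, t):
        # Diverges as t approaches 0, where the variance vanishes.
        return self.alpha_squared(t) / self.variance(t)

sched = LinearSchedule()
total = sched.alpha_squared(0.5) + sched.variance(0.5)   # equals 1 by construction
```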

class torchdiff.sde.TrainSDE(*args: Any, **kwargs: Any)

Bases: Module

Trainer for score-based generative models using Stochastic Differential Equations.

Manages the training process for SDE-based generative models, optimizing a noise predictor to learn the noise added by the forward SDE process, as described in Song et al. (2021). Supports conditional training with text prompts, mixed precision, learning rate scheduling, early stopping, and checkpointing.

Parameters:
  • score_net (nn.Module) – Model to predict score/noise.

  • fwd_sde (nn.Module) – Forward SDE diffusion module for adding noise.

  • rwd_sde (nn.Module) – Reverse SDE diffusion module for denoising.

  • train_loader (torch.utils.data.DataLoader) – DataLoader for training data.

  • optim (torch.optim.Optimizer) – Optimizer for training the noise predictor and conditional model (if applicable).

  • loss_fn (callable) – Loss function to compute the difference between predicted and actual noise.

  • val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data, default None.

  • max_epochs (int, optional) – Maximum number of training epochs (default: 1000).

  • device (torch.device, optional) – Device for computation (default: CUDA if available, else CPU).

  • cond_net (nn.Module, optional) – Model for conditional generation (e.g., text embeddings), default None.

  • metrics (object, optional) – Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).

  • tokenizer (BertTokenizer, optional) – Tokenizer for processing text prompts, default None (loads “bert-base-uncased”).

  • max_token_length (int, optional) – Maximum length for tokenized prompts (default: 77).

  • store_path (str, optional) – Path to save model checkpoints (default: “sde_train”).

  • patience (int, optional) – Number of epochs to wait for improvement before early stopping (default: 20).

  • warmup_steps (int, optional) – Number of steps for learning rate warmup (default: 1000).

  • val_freq (int, optional) – Frequency (in epochs) for validation (default: 10).

  • norm_range (tuple, optional) – Range for clamping generated images (default: (-1, 1)).

  • norm_output (bool, optional) – Whether to normalize generated images to [0, 1] for metrics (default: True).

  • use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).

  • grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).

  • log_freq (int, optional) – Number of epochs between loss printouts.

  • use_comp (bool, optional) – Whether the model is internally compiled using torch.compile (default: False).

  • time_eps (float, optional) – Lower bound for diffusion time sampling; times are drawn from (time_eps, 1.0) (default: 1e-5).

  • num_steps (int, optional) – Number of time steps for sampling during validation (default: 400).
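The role of `time_eps` can be illustrated with a small sketch. Training times are kept away from t = 0 because σ(t) tends to zero there and the score target -ε / σ(t) becomes unbounded. The uniform distribution used below is an assumption, since this page does not state how `sample_time` draws its values, and the `rng` argument is added only for reproducibility.

```python
import random

def sample_time(batch_size, eps=1e-5, rng=random):
    """Draw one diffusion time per batch element, bounded below by eps."""
    # rng.random() lies in [0, 1), so each value is at least eps.
    return [eps + (1.0 - eps) * rng.random() for _ in range(batch_size)]

times = sample_time(8, eps=1e-5, rng=random.Random(0))
```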

load_checkpoint(checkpoint_path: str) → Tuple[int, float]

Loads a training checkpoint to resume training.

Restores the state of the noise predictor, conditional model (if applicable), and optimizer from a saved checkpoint. Handles DDP model state dict loading.

Parameters:

checkpoint_path (str) – Path to the checkpoint file.

Returns:

  • epoch (int) – The epoch at which the checkpoint was saved.

  • loss (float) – The loss at the checkpoint.

static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) → torch.optim.lr_scheduler.LambdaLR

Creates a learning rate scheduler for warmup.

Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps, then maintains it.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.

  • warmup_steps (int) – Number of steps for the warmup phase.

Returns:

Learning rate scheduler for warmup.

Return type:

torch.optim.lr_scheduler.LambdaLR
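A linear warmup of this kind is usually expressed as a multiplicative factor on the base learning rate. The sketch below shows such a factor in plain Python. The name `warmup_factor` is hypothetical and the exact lambda used by torchdiff is an assumption.

```python
def warmup_factor(step, warmup_steps):
    """Multiplier on the base learning rate: ramps 0 -> 1, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0                      # no warmup requested
    return min(1.0, step / warmup_steps)

# With PyTorch installed, a function like this can be supplied as `lr_lambda`:
#   torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: warmup_factor(s, 1000))
factors = [warmup_factor(s, 1000) for s in (0, 250, 500, 1000, 5000)]
```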

forward() → Dict

Trains the SDE model to predict noise added by the forward diffusion process.

Executes the training loop, optimizing the noise predictor and conditional model (if applicable) using mixed precision, gradient clipping, and learning rate scheduling. Supports validation, early stopping, and checkpointing.

Returns:

  • losses – Dictionary of train and validation losses.

Notes

  • Training uses mixed precision via torch.cuda.amp or torch.amp for efficiency.

  • Checkpoints are saved when the validation (or training) loss improves, and on early stopping.

  • Early stopping is triggered if no improvement occurs for patience epochs.
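The checkpoint and early-stopping rules in these notes amount to a patience counter, sketched below. `EarlyStopper` is a hypothetical helper written for illustration. The trainer's actual bookkeeping, such as which loss it tracks, is not shown on this page.

```python
class EarlyStopper:
    """Track the best loss and count epochs without improvement."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def update(self, loss):
        """Return (improved, should_stop) for the latest epoch loss."""
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
            return True, False          # improved: a checkpoint would be saved here
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
history = [stopper.update(loss) for loss in (1.0, 0.9, 0.95, 0.97)]
```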

sample_time(batch_size: int, eps: float = 1e-05) → torch.Tensor

validate() → Tuple[float, float, float, float, float, float]

Validates the noise predictor and computes evaluation metrics.

Computes validation loss (MSE between predicted and ground truth noise) and generates samples using the reverse diffusion model by manually iterating over timesteps. Decodes samples to images and computes image-domain metrics (MSE, PSNR, SSIM, FID, LPIPS) if a metrics object is provided.

Returns:

  • val_loss (float) – Mean validation loss.

  • fid (float, or float(‘inf’) if not computed) – Mean FID score.

  • mse (float, or None if not computed) – Mean MSE

  • psnr (float, or None if not computed) – Mean PSNR

  • ssim (float, or None if not computed) – Mean SSIM

  • lpips_score (float, or None if not computed) – Mean LPIPS score
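Sample generation by manual iteration over timesteps can be sketched with a toy problem whose score is known exactly. The example below integrates the VE probability flow ODE from t = 1 to t = 0 for scalar data concentrated at 0, so that p_t = N(0, σ²(t)). All names are hypothetical and the code is not the validation loop used by TrainSDE.

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0
LOG_RATIO = math.log(SIGMA_MAX / SIGMA_MIN)

def sigma(t):
    """VE noise scale sigma(t) = sigma_min (sigma_max / sigma_min)^t."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def toy_score(x, t):
    """Exact score of N(0, sigma(t)^2), standing in for a trained network."""
    return -x / sigma(t) ** 2

def ode_step(x, score, t, dt):
    """Euler step of the VE probability flow ODE: dx = -0.5 g^2(t) score dt."""
    g_squared = 2.0 * sigma(t) ** 2 * LOG_RATIO
    return x - 0.5 * g_squared * score * dt

def sample(x_init, num_steps):
    dt = -1.0 / num_steps               # negative step: integrate from t=1 to t=0
    x = x_init
    for k in range(num_steps):
        t = 1.0 - k / num_steps
        x = ode_step(x, toy_score(x, t), t, dt)
    return x

# Starting one sigma_max away from the origin should end near sigma_min.
x_final = sample(x_init=SIGMA_MAX, num_steps=400)
```

For this toy problem the exact solution is x(0) = x(1) · σ_min / σ_max, so the gap between `x_final` and `SIGMA_MIN` reflects only the Euler discretization error, which shrinks as `num_steps` grows.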

class torchdiff.sde.SampleSDE(*args: Any, **kwargs: Any)

Bases: Module

Sampler for generating images using SDE-based generative models.

Generates images by iteratively denoising random noise using the reverse SDE process and a trained noise predictor, as described in Song et al. (2021). Supports both unconditional and conditional generation with text prompts.

Parameters:
  • rwd_sde (ReverseSDE) – Reverse SDE diffusion module for denoising.

  • score_net (nn.Module) – Model to predict noise added during the forward SDE process.

  • img_size (tuple) – Shape of generated images as (height, width).

  • cond_net (nn.Module, optional) – Model for conditional generation (e.g., TextEncoder), default None.

  • tokenizer (str or BertTokenizer, optional) – Tokenizer for processing text prompts, default “bert-base-uncased”.

  • max_token_length (int, optional) – Maximum length for tokenized prompts (default: 77).

  • batch_size (int, optional) – Number of images to generate per batch (default: 1).

  • in_channels (int, optional) – Number of input channels for generated images (default: 3).

  • device (str, optional) – Device for computation (default: CUDA).

  • norm_range (tuple, optional) – Range for clamping generated images (min, max), default (-1, 1).

tokenize(prompts: str | List) → Tuple[torch.Tensor, torch.Tensor]

Tokenizes text prompts for conditional generation.

Converts input prompts into tokenized tensors using the specified tokenizer.

Parameters:

prompts (str or list) – Text prompt(s) for conditional generation. Can be a single string or a list of strings.

Returns:

  • input_ids (torch.Tensor) – Tokenized input IDs, shape (batch_size, max_token_length).

  • attention_mask (torch.Tensor) – Attention mask, shape (batch_size, max_token_length).

forward(num_steps: int, conds: str | List | None = None, norm_output: bool = True, save_imgs: bool = True, save_path: str = 'sde_samples') → torch.Tensor

Generates images using the reverse SDE sampling process.

Iteratively denoises random noise to generate images using the reverse SDE process and noise predictor. Supports conditional generation with text prompts.

Parameters:
  • num_steps (int) – Number of reverse-time integration steps.

  • conds (str or list, optional) – Text prompt(s) for conditional generation, default None.

  • norm_output (bool, optional) – If True, normalizes output images to [0, 1] (default: True).

  • save_imgs (bool, optional) – If True, saves generated images to save_path (default: True).

  • save_path (str, optional) – Directory to save generated images (default: “sde_samples”).

Returns:

  • samps (torch.Tensor) – Generated images, shape (batch_size, in_channels, height, width). If norm_output is True, images are normalized to [0, 1]; otherwise, they are clamped to norm_range.
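The post-processing described here can be sketched on a flat list of pixel values. The order of operations, clamping to `norm_range` first and then rescaling to [0, 1], is an assumption, and `clamp_and_normalize` is a hypothetical name.

```python
def clamp_and_normalize(values, norm_range=(-1.0, 1.0), norm_output=True):
    """Clamp values to norm_range, then optionally rescale them to [0, 1]."""
    lo, hi = norm_range
    clamped = [min(max(v, lo), hi) for v in values]
    if not norm_output:
        return clamped
    return [(v - lo) / (hi - lo) for v in clamped]

pixels = [-2.0, -1.0, 0.0, 1.0, 3.0]
normalized = clamp_and_normalize(pixels)
clamped_only = clamp_and_normalize(pixels, norm_output=False)
```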

to(device: torch.device) → Self

Moves the module and its components to the specified device.

Updates the device attribute and moves the reverse diffusion, noise predictor, and conditional model (if present) to the specified device.

Parameters:

device (torch.device) – Target device for the module and its components.

Returns:

The sampler, moved to the specified device.

Return type:

SampleSDE