"""
**Score-Based Generative Modeling with Stochastic Differential Equations (SDE)**
This module implements a complete framework for score-based generative models using SDEs,
as described in Song et al. (2021, "Score-Based Generative Modeling through Stochastic
Differential Equations"). It provides components for forward and reverse diffusion
processes, hyperparameter management, training, and image sampling, supporting Variance
Exploding (VE), Variance Preserving (VP), sub-Variance Preserving (sub-VP), and ODE
methods for flexible noise schedules. Supports both unconditional and conditional
generation with text prompts.
**Components**
- **ForwardSDE**: Forward diffusion process to add noise using SDE methods.
- **ReverseSDE**: Reverse diffusion process to denoise using SDE methods.
- **SchedulerSDE**: Noise schedule and SDE-specific parameter management.
- **TrainSDE**: Training loop with mixed precision and scheduling.
- **SampleSDE**: Image generation from trained SDE models.
**References**
- Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." arXiv preprint arXiv:2011.13456 (2020).
---------------------------------------------------------------------------------
"""
import math
import os
import warnings
from typing import Optional, Tuple, Callable, List, Any, Union, Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision.utils import save_image
from tqdm import tqdm
from transformers import BertTokenizer
from typing_extensions import Self

from .utils import LossAdapter
# Names exported by a star-import of this module: the five SDE components.
__all__ = [
"ForwardSDE",
"ReverseSDE",
"SchedulerSDE",
"TrainSDE",
"SampleSDE",
]
###==================================================================================================================###
[docs]
class ForwardSDE(nn.Module):
    """
    Forward (noising) process for continuous-time diffusion models.

    Samples x_t from a Gaussian perturbation kernel p(x_t | x_0) and returns
    the regression target used for training. Four process types are supported:

        • "vp"     : x_t = α(t) x_0 + σ(t) ε, with α(t), σ(t) from the scheduler
        • "ve"     : x_t = x_0 + σ(t) ε, with σ(t) = σ_min (σ_max / σ_min)^t
        • "sub-vp" : x_t = x_0 + σ(t) ε, with σ(t) from the scheduler
        • "ode"    : identical marginals to "vp" (only sampling differs)

    where ε ~ N(0, I) is supplied by the caller and t is continuous time.

    For Gaussian kernels the conditional score is

        ∇ₓ log p(x_t | x₀) = -(x_t - μ(t)) / σ²(t) = -ε / σ(t)

    which is what is returned when ``pred_type="score"``.

    Parameters
    ----------
    scheduler : nn.Module
        Object exposing ``alpha(t)`` and ``std(t)`` (e.g. ``SchedulerSDE``).
        Only used by the "vp", "sub-vp" and "ode" methods.
    method : str, default="vp"
        Forward process type. One of {"vp", "ve", "sub-vp", "ode"}.
    pred_type : str, default="noise"
        Training target returned by ``forward``. One of {"noise", "score"}.
    sigma_min : float, default=0.01
        Minimum noise scale for the VE process.
    sigma_max : float, default=50.0
        Maximum noise scale for the VE process.
    eps : float, default=1e-8
        Added to σ(t) before dividing when computing the score target.

    Raises
    ------
    ValueError
        If ``method`` or ``pred_type`` is not a supported value, or if
        ``method="ve"`` and not ``0 < sigma_min <= sigma_max``.

    Notes
    -----
    • Per-sample coefficients of shape (batch,) are reshaped with trailing
      singleton axes so they broadcast against x₀ of any rank.
    • NOTE(review): the "sub-vp" kernel here keeps the mean at x₀ and uses the
      scheduler's σ(t); Song et al. (2021) define sub-VP with mean
      x₀·exp(-½∫β) and std 1 - exp(-∫β). Confirm the simplification is
      intentional before relying on published sub-VP results.

    References
    ----------
    - Song et al., "Score-Based Generative Modeling through SDEs", ICLR 2021
    - Ho et al., "Denoising Diffusion Probabilistic Models", NeurIPS 2020
    - Kingma et al., "Variational Diffusion Models", NeurIPS 2021
    """

    def __init__(
            self,
            scheduler: nn.Module,
            method: str = "vp",
            pred_type: str = 'noise',
            sigma_min: float = 0.01,
            sigma_max: float = 50.0,
            eps: float = 1e-8
    ):
        super().__init__()
        valid_methods = ["vp", "ve", "sub-vp", "ode"]
        if method not in valid_methods:
            raise ValueError(f"sde_method must be one of {valid_methods}, got {method}")
        valid_types = ["noise", "score"]
        if pred_type not in valid_types:
            raise ValueError(f"pred_type must be one of {valid_types}, got {pred_type}")
        # The VE schedule divides by sigma_min and raises the ratio to the
        # power t; a non-positive sigma_min would fail (or yield NaN) only at
        # call time, so reject it up front.
        if method == "ve" and not (0.0 < sigma_min <= sigma_max):
            raise ValueError(
                f"VE-SDE requires 0 < sigma_min <= sigma_max, "
                f"got sigma_min={sigma_min}, sigma_max={sigma_max}"
            )
        self.vs = scheduler
        self.method = method
        self.pred_type = pred_type
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.eps = eps

    def _broadcast_to_shape(self, tensor: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
        """Append singleton axes until ``tensor`` has as many dims as ``target_shape``."""
        while tensor.dim() < len(target_shape):
            tensor = tensor.unsqueeze(-1)
        return tensor

    def _ve_sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Geometric VE noise scale σ(t) = σ_min (σ_max / σ_min)^t."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def get_forward_params(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the kernel parameters for time ``t``.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion times.

        Returns
        -------
        mean_coeff : torch.Tensor
            Multiplier applied to the clean data x_0.
        std : torch.Tensor
            Standard deviation of the added noise.
        """
        mean_coeff = None
        std = None
        if self.method in ("vp", "ode"):
            # The probability-flow ODE shares the VP marginals:
            # p(x_t | x_0) = N(α(t) x_0, σ²(t) I)
            mean_coeff = self.vs.alpha(t)
            std = self.vs.std(t)
        elif self.method == "ve":
            # p(x_t | x_0) = N(x_0, σ²(t) I), σ grows from sigma_min to sigma_max
            mean_coeff = torch.ones_like(t)
            std = self._ve_sigma(t)
        elif self.method == "sub-vp":
            # mean is left at x_0; noise scale comes from the scheduler
            mean_coeff = torch.ones_like(t)
            std = self.vs.std(t)
        return mean_coeff, std

    def forward(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Noise ``x0`` to time ``t`` and build the training target.

        Parameters
        ----------
        x0 : torch.Tensor
            Clean data, shape (batch, ...).
        t : torch.Tensor
            Diffusion time per sample, shape (batch,).
        noise : torch.Tensor
            Standard Gaussian noise with the same shape as ``x0``.

        Returns
        -------
        xt : torch.Tensor
            Noised data, ``mean_coeff * x0 + std * noise``.
        target : torch.Tensor
            ``noise`` itself when ``pred_type="noise"``; the conditional score
            ``-noise / (std + eps)`` when ``pred_type="score"``.
        """
        mean_coeff, std = self.get_forward_params(t)
        # reshape (batch,) -> (batch, 1, ..., 1) so the products broadcast
        mean_coeff = self._broadcast_to_shape(mean_coeff, x0.shape)
        std = self._broadcast_to_shape(std, x0.shape)
        xt = mean_coeff * x0 + std * noise
        if self.pred_type == 'noise':
            target = noise
        else:
            # pred_type was validated in __init__, so this is the "score" case:
            # ∇_x log p(x_t | x_0) = -ε / σ(t); eps guards σ(t) → 0 near t = 0
            target = -noise / (std + self.eps)
        return xt, target
###==================================================================================================================###
[docs]
class ReverseSDE(nn.Module):
    """
    Reverse-time diffusion step for continuous-time SDE diffusion models.

    Implements one Euler–Maruyama step of the reverse-time SDE (or one Euler
    step of the probability flow ODE) given the current state x_t and the
    network output at time t.

    General reverse SDE form:

        dx = [f(x, t) - g²(t) ∇ₓ log p_t(x)] dt + g(t) dW̄_t

    where f(x, t) = drift_coeff(t)·x is the forward drift, g(t) the diffusion
    coefficient and dW̄_t reverse-time Brownian motion. For ``method="ode"``
    the score term is halved and the stochastic term is dropped.

    Coefficients used per method (see ``get_reverse_coeffs``):

        • "vp"     : f = -½β(t)x,  g² = β(t)
        • "ve"     : f = 0,        g² = 2 σ(t) dσ/dt
        • "sub-vp" : f = 0,        g² = β(t)
        • "ode"    : f = -½β(t)x,  g² = β(t), no noise

    Parameters
    ----------
    scheduler : nn.Module
        Object exposing ``beta(t)`` and ``std(t)`` (e.g. ``SchedulerSDE``).
    method : str, default="vp"
        Reverse dynamics. One of {"vp", "ve", "sub-vp", "ode"}.
    pred_type : str, default="noise"
        What the network output represents. One of {"noise", "score"}.
    sigma_min : float, default=0.01
        Minimum noise scale for the VE-SDE.
    sigma_max : float, default=50.0
        Maximum noise scale for the VE-SDE.
    eps : float, default=1e-8
        Added to σ(t) when converting a noise prediction into a score.

    Raises
    ------
    ValueError
        If ``method`` or ``pred_type`` is not a supported value, or if
        ``method="ve"`` and not ``0 < sigma_min <= sigma_max``.

    Notes
    -----
    • Reverse integration proceeds with a *negative* time step dt < 0.
    • Per-sample coefficients of shape (batch,) are reshaped with trailing
      singleton axes so they broadcast against x_t of any rank.
    • NOTE(review): the "sub-vp" coefficients (zero drift, g² = β) do not match
      the sub-VP SDE of Song et al. (2021); confirm they are intended to pair
      with the simplified forward kernel used in training.

    Numerical Integration
    ---------------------
        x_{t+dt} = x_t
                   + [f(x_t, t) - g²(t)·score(x_t, t)] dt
                   + g(t) √|dt| ε,      ε ~ N(0, I)

    References
    ----------
    - Anderson, "Reverse-Time Diffusion Equation Models", 1982
    - Song et al., "Score-Based Generative Modeling through SDEs", ICLR 2021
    - Kingma et al., "Variational Diffusion Models", NeurIPS 2021
    """

    def __init__(
            self,
            scheduler: nn.Module,
            method: str = "vp",
            pred_type: str = 'noise',
            sigma_min: float = 0.01,
            sigma_max: float = 50.0,
            eps: float = 1e-8
    ):
        super().__init__()
        valid_methods = ["vp", "ve", "sub-vp", "ode"]
        if method not in valid_methods:
            raise ValueError(f"sde_method must be one of {valid_methods}, got {method}")
        valid_types = ["noise", "score"]
        if pred_type not in valid_types:
            raise ValueError(f"pred_type must be one of {valid_types}, got {pred_type}")
        # sigma_max / sigma_min and its log are only defined for a positive
        # ratio; fail at construction instead of producing NaNs while sampling.
        if method == "ve" and not (0.0 < sigma_min <= sigma_max):
            raise ValueError(
                f"VE-SDE requires 0 < sigma_min <= sigma_max, "
                f"got sigma_min={sigma_min}, sigma_max={sigma_max}"
            )
        self.vs = scheduler
        self.method = method
        self.pred_type = pred_type
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.eps = eps

    def _broadcast_to_shape(self, tensor: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
        """Append singleton axes until ``tensor`` has as many dims as ``target_shape``."""
        while tensor.dim() < len(target_shape):
            tensor = tensor.unsqueeze(-1)
        return tensor

    def _ve_sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Geometric VE noise scale σ(t) = σ_min (σ_max / σ_min)^t."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def get_reverse_coeffs(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the reverse-SDE coefficients at time ``t``.

        Returns
        -------
        drift_coeff : torch.Tensor
            Multiplier of x in the forward drift f(x, t).
        g_squared : torch.Tensor
            Squared diffusion coefficient (weights the score term).
        diffusion_coeff : torch.Tensor
            Multiplier of the injected noise (zero for the ODE).
        """
        if self.method == "ve":
            # dx = [-2σ(t)σ'(t) ∇log p_t(x)]dt + √(2σ(t)σ'(t)) dw̄
            # with σ'(t) = σ(t)·log(σ_max/σ_min); the log is a Python scalar,
            # so no extra tensor is allocated on each call.
            sigma_t = self._ve_sigma(t)
            dsigma_dt = sigma_t * math.log(self.sigma_max / self.sigma_min)
            drift_coeff = torch.zeros_like(t)
            g_squared = 2 * sigma_t * dsigma_dt
            diffusion_coeff = torch.sqrt(g_squared)
            return drift_coeff, g_squared, diffusion_coeff

        beta_t = self.vs.beta(t)
        g_squared = beta_t
        if self.method == "vp":
            # dx = [-½β(t)x - β(t)∇log p_t(x)]dt + √β(t) dw̄
            drift_coeff = -0.5 * beta_t
            diffusion_coeff = torch.sqrt(beta_t)
        elif self.method == "sub-vp":
            # dx = [-β(t)∇log p_t(x)]dt + √β(t) dw̄
            drift_coeff = torch.zeros_like(t)
            diffusion_coeff = torch.sqrt(beta_t)
        else:
            # "ode": VP drift, deterministic (no diffusion term)
            drift_coeff = -0.5 * beta_t
            diffusion_coeff = torch.zeros_like(t)
        return drift_coeff, g_squared, diffusion_coeff

    def forward(
            self,
            xt: torch.Tensor,
            pred: torch.Tensor,
            t: torch.Tensor,
            dt: Union[float, torch.Tensor],
            last_step: bool = False
    ) -> torch.Tensor:
        """Take a single reverse Euler–Maruyama step.

        Parameters
        ----------
        xt : torch.Tensor
            Current state, shape (batch, ...).
        pred : torch.Tensor
            Network output at (xt, t), same shape as ``xt``; interpreted as
            noise or score according to ``pred_type``.
        t : torch.Tensor
            Current time per sample, shape (batch,).
        dt : float or torch.Tensor
            Time step, negative for reverse integration. A tensor may be
            0-dim or per-sample with shape (batch,); tensors are not
            sign-checked.
        last_step : bool, default=False
            If True, no noise is injected (deterministic final step).

        Returns
        -------
        torch.Tensor
            The state after the step, same shape as ``xt``.

        Raises
        ------
        ValueError
            If ``dt`` is a Python number that is not strictly negative.
        """
        if torch.is_tensor(dt):
            # a per-sample step (batch,) must line up with the batch axis,
            # not with the trailing data axis
            if dt.dim() > 0:
                dt = self._broadcast_to_shape(dt, xt.shape)
        else:
            if not dt < 0.0:
                raise ValueError("dt must be negative for reverse diffusion!")
            dt = torch.tensor(dt, device=xt.device, dtype=xt.dtype)

        drift_coeff, g_squared, diffusion_coeff = self.get_reverse_coeffs(t)
        drift_coeff = self._broadcast_to_shape(drift_coeff, xt.shape)
        g_squared = self._broadcast_to_shape(g_squared, xt.shape)
        diffusion_coeff = self._broadcast_to_shape(diffusion_coeff, xt.shape)

        if self.pred_type == "noise":
            # score = -ε / σ(t); σ depends on the process family
            std = self._ve_sigma(t) if self.method == "ve" else self.vs.std(t)
            std = self._broadcast_to_shape(std, xt.shape)
            score = -pred / (std + self.eps)
        else:
            score = pred

        # reverse drift: f(x,t) - g²(t)·score  (score weight halved for the ODE)
        score_weight = 0.5 if self.method == 'ode' else 1.0
        drift = drift_coeff * xt - score_weight * g_squared * score

        x_prev = xt + drift * dt
        if not (last_step or self.method == "ode"):
            # stochastic term g(t) √|dt| ε
            x_prev = x_prev + diffusion_coeff * torch.randn_like(xt) * torch.sqrt(torch.abs(dt))
        return x_prev
###==================================================================================================================###
[docs]
class SchedulerSDE(nn.Module):
    """
    Continuous-time variance (noise) scheduler for diffusion models formulated
    as stochastic differential equations (SDEs).

    Defines the noise rate β(t) and derived quantities for forward processes of
    the form:

        dx = -½ β(t) x dt + √β(t) dW_t

    Supported schedules:

        • Linear:  β(t) = β_min + t (β_max - β_min)
        • Cosine:  defined through the cumulative signal power
                   ᾱ(t) = cos²((t + s) / (1 + s) · π / 2),
                   following Nichol & Dhariwal (2021), normalised so that
                   α²(0) = 1.

    Quantities provided:

        • β(t)            — instantaneous noise rate
        • ∫₀ᵗ β(s) ds     — cumulative noise
        • α(t), α²(t)     — signal scaling factor and its square
        • σ²(t), σ(t)     — noise variance and standard deviation
        • SNR(t)          — signal-to-noise ratio

    Parameters
    ----------
    schedule_type : str, default="linear"
        Type of noise schedule. One of {"linear", "cosine"}.
    beta_min : float, default=0.1
        Minimum noise rate for the linear schedule. Must satisfy
        0 < beta_min < beta_max. Not used by the cosine schedule.
    beta_max : float, default=20.0
        Maximum noise rate for the linear schedule. Not used by the cosine
        schedule.
    cosine_s : float, default=0.008
        Offset s in the cosine schedule.

    Raises
    ------
    ValueError
        If ``schedule_type`` is unknown, or the linear schedule is requested
        without 0 < beta_min < beta_max.

    Notes
    -----
    • α²(t) + σ²(t) = 1 for both schedules.
    • For the cosine schedule β(t) is the analytic derivative
      -d/dt log ᾱ(t) = tan(u)·π/(1+s), clamped to [0, 1000] because tan
      diverges as t → 1.
    • σ²(t) is evaluated as -expm1(-∫β) for the linear schedule; the naive
      1 - exp(-∫β) loses most of its significant digits in float32 when t is
      close to 0.

    References
    ----------
    - Ho et al., "Denoising Diffusion Probabilistic Models", NeurIPS 2020
    - Song et al., "Score-Based Generative Modeling through SDEs", ICLR 2021
    - Nichol & Dhariwal, "Improved Denoising Diffusion Probabilistic Models", ICML 2021
    """

    def __init__(
            self,
            schedule_type: str = "linear",
            beta_min: float = 0.1,
            beta_max: float = 20.0,
            cosine_s: float = 0.008
    ):
        super().__init__()
        valid_schedules = ["linear", "cosine"]
        if schedule_type not in valid_schedules:
            raise ValueError(f"schedule_type must be one of {valid_schedules}, got {schedule_type}")
        if schedule_type == "linear" and not (0.0 < beta_min < beta_max):
            raise ValueError("For linear schedule, require 0 < beta_min < beta_max")
        self.schedule_type = schedule_type
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.cosine_s = cosine_s

    def _cosine_angle(self, t: torch.Tensor) -> torch.Tensor:
        """Map t to the cosine-schedule angle u = (t+s)/(1+s) · π/2."""
        return (t + self.cosine_s) / (1 + self.cosine_s) * torch.pi / 2

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """β(t) — instantaneous noise rate."""
        if self.schedule_type == "linear":
            return self.beta_min + t * (self.beta_max - self.beta_min)
        # cosine: β(t) = -d/dt log ᾱ(t) = tan(u) · π / (1+s)
        beta_t = torch.tan(self._cosine_angle(t)) * (torch.pi / (1 + self.cosine_s))
        return torch.clamp(beta_t, min=0.0, max=1000.0)

    def integral_beta(self, t: torch.Tensor) -> torch.Tensor:
        """∫₀ᵗ β(s) ds (for cosine: -log α²(t))."""
        if self.schedule_type == "linear":
            return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t ** 2
        return -torch.log(self.alpha_squared(t))

    def _cosine_alpha_bar(self, t: torch.Tensor) -> torch.Tensor:
        """ᾱ(t) = cos²((t+s)/(1+s) · π/2) for the cosine schedule."""
        return torch.cos(self._cosine_angle(t)) ** 2

    def alpha(self, t: torch.Tensor) -> torch.Tensor:
        """α(t) = exp(-½∫₀ᵗ β(s) ds)."""
        if self.schedule_type == "cosine":
            return torch.sqrt(self.alpha_squared(t))
        return torch.exp(-0.5 * self.integral_beta(t))

    def alpha_squared(self, t: torch.Tensor) -> torch.Tensor:
        """α²(t) = exp(-∫₀ᵗ β(s) ds); for cosine, ᾱ(t)/ᾱ(0)."""
        if self.schedule_type == "cosine":
            return self._cosine_alpha_bar(t) / self._cosine_alpha_bar(torch.zeros_like(t))
        return torch.exp(-self.integral_beta(t))

    def variance(self, t: torch.Tensor) -> torch.Tensor:
        """σ²(t) = 1 - α²(t), guaranteed non-negative."""
        if self.schedule_type == "linear":
            # -expm1(-x) == 1 - exp(-x) without the cancellation near x = 0
            return -torch.expm1(-self.integral_beta(t))
        # rounding in the ratio ᾱ(t)/ᾱ(0) must not push σ² below zero,
        # otherwise std() would return NaN
        return torch.clamp(1.0 - self.alpha_squared(t), min=0.0)

    def std(self, t: torch.Tensor) -> torch.Tensor:
        """σ(t) = √(1 - α²(t))."""
        return torch.sqrt(self.variance(t))

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """Signal-to-noise ratio SNR(t) = α²(t) / (σ²(t) + 1e-8)."""
        alpha_sq = self.alpha_squared(t)
        var = self.variance(t)
        return alpha_sq / (var + 1e-8)
###==================================================================================================================###
[docs]
class TrainSDE(nn.Module):
"""Trainer for score-based generative models using Stochastic Differential Equations.
Manages the training process for SDE-based generative models, optimizing a noise
predictor to learn the noise added by the forward SDE process, as described in Song
et al. (2021). Supports conditional training with text prompts, mixed precision,
learning rate scheduling, early stopping, and checkpointing.
Parameters
----------
score_net : nn.Module
Model to predict score/noise.
fwd_sde : nn.Module
Forward SDE diffusion module for adding noise.
rwd_sde: nn.Module
Reverse SDE diffusion module for denoising.
train_loader : torch.utils.data.DataLoader
DataLoader for training data.
optim : torch.optim.Optimizer
Optimizer for training the noise predictor and conditional model (if applicable).
loss_fn : callable
Loss function to compute the difference between predicted and actual noise.
val_loader : torch.utils.data.DataLoader, optional
DataLoader for validation data, default None.
max_epochs : int, optional
Maximum number of training epochs (default: 1000).
device : str or torch.device, optional
Device for computation (default: 'cuda'). Under DDP it is replaced by the local-rank device.
cond_net : nn.Module, optional
Model for conditional generation (e.g., text embeddings), default None.
metrics_ : object, optional
Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
tokenizer : BertTokenizer, optional
Tokenizer for processing text prompts, default None (loads "bert-base-uncased").
max_token_length : int, optional
Maximum length for tokenized prompts (default: 77).
store_path : str, optional
Path to save model checkpoints (default: "sde_train").
patience : int, optional
Number of epochs to wait for improvement before early stopping (default: 20).
warmup_steps : int, optional
Number of steps for learning rate warmup (default: 1000).
val_freq : int, optional
Frequency (in epochs) for validation (default: 10).
norm_range : tuple, optional
Range for clamping generated images (default: (-1, 1)).
norm_output : bool, optional
Whether to normalize generated images to [0, 1] for metrics (default: True).
use_ddp : bool, optional
Whether to use Distributed Data Parallel training (default: False).
grad_acc : int, optional
Number of gradient accumulation steps before optimizer update (default: 1).
log_freq : int, optional
Number of epochs before printing loss.
use_comp : bool, optional
whether the model is internally compiled using torch.compile (default: false)
time_eps: float, optional
lower bound for diffusion time sampling (time_eps, 1.0) (default: 1e-5)
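num_steps: int, optional
number of time steps for sampling during validation (default: 400)
use_amp: bool, optional
whether to train with automatic mixed precision; only takes effect on CUDA devices (default: False)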
"""
def __init__(
        self,
        score_net: torch.nn.Module,
        fwd_sde: torch.nn.Module,
        rwd_sde: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        optim: torch.optim.Optimizer,
        loss_fn: Callable,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        max_epochs: int = 1000,
        device: Union[str, torch.device] = 'cuda',
        cond_net: Optional[torch.nn.Module] = None,
        metrics_: Optional[Any] = None,
        tokenizer: Optional[BertTokenizer] = None,
        max_token_length: int = 77,
        store_path: Optional[str] = None,
        patience: int = 20,
        warmup_steps: int = 1000,
        val_freq: int = 10,
        norm_range: Tuple[float, float] = (-1.0, 1.0),
        norm_output: bool = True,
        use_ddp: bool = False,
        grad_acc: int = 1,
        log_freq: int = 1,
        use_comp: bool = False,
        time_eps: float = 1e-5,
        num_steps: int = 400,
        use_amp: bool = False,
        *args
) -> None:
    """Store the training configuration and move the models to the device.

    Sets up (in this order) the device / DDP rank information, the models,
    the wrapped loss, a ``ReduceLROnPlateau`` scheduler, the warmup
    scheduler, and the tokenizer. Extra positional arguments are accepted
    and ignored.

    Raises
    ------
    ValueError
        If ``grad_acc`` is smaller than 1, if DDP is requested without the
        required environment variables, or if a conditional model is given
        but the default "bert-base-uncased" tokenizer cannot be loaded.
    """
    super().__init__()
    # grad_acc is used as a divisor and as a modulus in the training loop
    if grad_acc < 1:
        raise ValueError(f"grad_acc must be >= 1, got {grad_acc}")
    self.use_ddp = use_ddp
    self.grad_acc = grad_acc
    self.use_amp = use_amp
    self.device = torch.device(device) if isinstance(device, str) else device
    # _setup_ddp may override self.device with the local-rank device
    if self.use_ddp:
        self._setup_ddp()
    else:
        self._setup_single_gpu()
    self.score_net = score_net.to(self.device)
    self.fwd_sde = fwd_sde.to(self.device)
    self.rwd_sde = rwd_sde.to(self.device)
    # explicit None check: container modules such as an empty nn.Sequential
    # define __len__ and would otherwise be treated as "no model"
    self.cond_net = cond_net.to(self.device) if cond_net is not None else None
    self.metrics_ = metrics_
    self.optim = optim
    # wrap the loss so it accepts the extra weighting argument even if the
    # user-supplied callable does not
    self.loss_fn = LossAdapter(loss_fn)
    self.store_path = store_path or "sde_train"
    self.train_loader = train_loader
    self.val_loader = val_loader
    self.max_epochs = max_epochs
    self.max_token_length = max_token_length
    self.patience = patience
    self.val_freq = val_freq
    self.norm_range = norm_range
    self.norm_output = norm_output
    self.log_freq = log_freq
    self.use_comp = use_comp
    self.time_eps = time_eps
    self.num_steps = num_steps
    self.global_step = 0
    self.warmup_steps = warmup_steps
    self.best_loss = float('inf')
    self.losses = {'train_losses': [], 'val_losses': []}
    self.scheduler = ReduceLROnPlateau(
        self.optim,
        patience=self.patience,
        factor=0.5
    )
    self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps)
    self._device_type = self.device.type if hasattr(self.device, 'type') else ('cuda' if 'cuda' in str(self.device) else 'cpu')
    if tokenizer is not None:
        self.tokenizer = tokenizer
    else:
        try:
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        except Exception as e:
            if self.cond_net is not None:
                raise ValueError(
                    f"Failed to load default tokenizer: {e}. Please provide a tokenizer."
                ) from e
            # Unconditional training never tokenizes prompts, so a missing
            # tokenizer (e.g. no network access) must not abort the run.
            warnings.warn(
                f"Failed to load default tokenizer: {e}. "
                "Continuing without a tokenizer because no conditional model was given."
            )
            self.tokenizer = None
def _setup_ddp(self) -> None:
"""Setup Distributed Data Parallel training configuration.
Initializes process group, determines rank information, and sets up
CUDA device for the current process.
"""
if "RANK" not in os.environ:
raise ValueError("DDP enabled but RANK environment variable not set")
if "LOCAL_RANK" not in os.environ:
raise ValueError("DDP enabled but LOCAL_RANK environment variable not set")
if "WORLD_SIZE" not in os.environ:
raise ValueError("DDP enabled but WORLD_SIZE environment variable not set")
if not torch.distributed.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
init_process_group(backend=backend)
# get rank info
self.ddp_rank = int(os.environ["RANK"])
self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
self.ddp_world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
self.device = torch.device(f"cuda:{self.ddp_local_rank}")
torch.cuda.set_device(self.device)
else:
self.device = torch.device("cpu")
self.master_process = self.ddp_rank == 0
if self.master_process:
print(f"DDP initialized with world_size={self.ddp_world_size}")
def _setup_single_gpu(self) -> None:
"""Setup single GPU or CPU training configuration."""
self.ddp_rank = 0
self.ddp_local_rank = 0
self.ddp_world_size = 1
self.master_process = True
[docs]
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Load a training checkpoint to resume training.

    Restores the score network, the conditional model (if one is configured
    and present in the checkpoint), the SDE scheduler and the optimizer.
    State-dict keys are adapted to the model as it currently is: the
    ``module.`` prefix is added when the model is wrapped in
    ``DistributedDataParallel`` and stripped when it is not.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        Epoch stored in the checkpoint (-1 if absent).
    loss : float
        Loss stored in the checkpoint (inf if absent).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    KeyError
        If the checkpoint lacks 'model_state_dict_score_net',
        'scheduler_model' or 'optim_state_dict'.
    """
    prefix = 'module.'

    def _align_keys(state_dict, model):
        """Add or strip the DDP ``module.`` prefix so the keys fit ``model``."""
        # Decide from the actual wrapping, not from self.use_ddp: models are
        # only wrapped once training starts, so a checkpoint is normally
        # loaded into still-unwrapped models.
        wants_prefix = isinstance(model, DDP)
        has_prefix = bool(state_dict) and all(k.startswith(prefix) for k in state_dict)
        if wants_prefix and not has_prefix:
            return {f'{prefix}{k}': v for k, v in state_dict.items()}
        if has_prefix and not wants_prefix:
            # a plain model may legitimately own a child called "module"
            if set(state_dict) == set(model.state_dict()):
                return state_dict
            # strip the leading prefix only; str.replace would also rewrite
            # keys that merely contain "module." further inside
            return {k[len(prefix):]: v for k, v in state_dict.items()}
        return state_dict

    try:
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from e

    if 'model_state_dict_score_net' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_score_net' key")
    self.score_net.load_state_dict(
        _align_keys(checkpoint['model_state_dict_score_net'], self.score_net)
    )

    if self.cond_net is not None:
        cond_state_dict = checkpoint.get('model_state_dict_cond')
        if cond_state_dict is not None:
            self.cond_net.load_state_dict(_align_keys(cond_state_dict, self.cond_net))
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_cond' or it is None, "
                "skipping conditional model loading"
            )

    if 'scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'scheduler_model' key")
    scheduler_state = checkpoint['scheduler_model']
    try:
        # treat the two SDE wrappers independently: a Module scheduler loads
        # the state, anything else is replaced by the stored object
        for sde in (self.fwd_sde, self.rwd_sde):
            if isinstance(sde.vs, nn.Module):
                sde.vs.load_state_dict(scheduler_state)
            else:
                sde.vs = scheduler_state
    except Exception as e:
        # best effort: training can continue with the in-memory scheduler
        warnings.warn(f"Scheduler loading failed: {e}. Continuing with current scheduler.")

    if 'optim_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optim_state_dict' key")
    try:
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))
    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")
    return epoch, loss
[docs]
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) -> torch.optim.lr_scheduler.LambdaLR:
"""Creates a learning rate scheduler for warmup.
Generates a scheduler that linearly increases the learning rate from 0 to the
optimizer's initial value over the specified warmup epochs, then maintains it.
Parameters
----------
optimizer : torch.optim.Optimizer
Optimizer to apply the scheduler to.
warmup_steps : int
Number of steps for the warmup phase.
Returns
-------
torch.optim.lr_scheduler.LambdaLR
Learning rate scheduler for warmup.
"""
def lr_lambda(step):
if step < warmup_steps:
return 0.1 + (0.9 * step / warmup_steps)
return 1.0
return LambdaLR(optimizer, lr_lambda)
def _wrap_models_for_ddp(self) -> None:
"""Wrap models with DistributedDataParallel for multi-GPU training."""
if self.use_ddp:
ddp_kwargs = dict(find_unused_parameters=False)
if self._device_type == 'cuda':
ddp_kwargs['device_ids'] = [self.ddp_local_rank]
self.score_net = DDP(self.score_net, **ddp_kwargs)
if self.cond_net is not None:
self.cond_net = DDP(self.cond_net, **ddp_kwargs)
[docs]
def forward(self) -> Dict:
    """Run the training loop.

    For every batch a diffusion time and Gaussian noise are sampled, the
    data is noised by the forward SDE and the score network is trained to
    match the target returned by the forward SDE. The loss receives a third
    argument: σ(t) for the VE method, otherwise the scheduler's σ²(t).

    Supports optional ``torch.compile``, DDP, CUDA mixed precision, gradient
    accumulation, gradient-norm clipping (max norm 1.0), learning-rate
    warmup, ``ReduceLROnPlateau``, periodic validation, checkpointing and
    early stopping.

    Returns
    -------
    losses : dict
        ``{'train_losses': [...], 'val_losses': [...]}`` accumulated over
        the epochs that were run.

    Notes
    -----
    - Mixed precision is only enabled when ``use_amp`` is set and the device
      is CUDA.
    - The "best" checkpoint and early stopping are driven by the mean
      *training* loss of the epoch; the plateau scheduler is stepped with the
      validation loss on validation epochs and the training loss otherwise.
    - Under DDP the stop decision is broadcast from rank 0 so that all ranks
      leave the loop in the same epoch.
    - The process group is destroyed before returning when DDP is used.
    """
    self.score_net.train()
    if self.cond_net is not None:
        self.cond_net.train()
    if self.use_comp:
        try:
            self.score_net = torch.compile(self.score_net)
            if self.cond_net is not None:
                self.cond_net = torch.compile(self.cond_net)
        except Exception as e:
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")
    self._wrap_models_for_ddp()
    use_amp = self.use_amp and self._device_type == 'cuda'
    scaler = torch.amp.GradScaler(self._device_type, enabled=use_amp)
    is_ve = self.fwd_sde.method == "ve"
    wait = 0
    for epoch in range(self.max_epochs):
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not self.master_process)
        # reshuffle the distributed sampler differently every epoch
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)
        train_losses_epoch = []
        for step, (x, y) in enumerate(pbar):
            x = x.to(self.device)
            if self.cond_net is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None
            with torch.autocast(device_type=self._device_type, enabled=use_amp):
                noise = torch.randn_like(x)
                t = self.sample_time(x.shape[0], self.time_eps)
                xt, target = self.fwd_sde(x, t, noise)
                pred = self.score_net(xt, t, y_encoded, clip_embeddings=None)
                if is_ve:
                    loss_weight = self.fwd_sde.sigma_min * (self.fwd_sde.sigma_max / self.fwd_sde.sigma_min) ** t
                else:
                    loss_weight = self.fwd_sde.vs.variance(t)
                # scale so that accumulated gradients average over micro-batches
                loss = self.loss_fn(pred, target, loss_weight) / self.grad_acc
            scaler.scale(loss).backward()
            if (step + 1) % self.grad_acc == 0:
                # unscale first so clipping sees the true gradient norms
                scaler.unscale_(self.optim)
                torch.nn.utils.clip_grad_norm_(self.score_net.parameters(), max_norm=1.0)
                if self.cond_net is not None:
                    torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0)
                scaler.step(self.optim)
                scaler.update()
                self.optim.zero_grad(set_to_none=True)
                if self.global_step < self.warmup_steps:
                    self.warmup_lr_scheduler.step()
                self.global_step += 1
            # one host sync per step; undo the accumulation scaling for logging
            step_loss = loss.item() * self.grad_acc
            pbar.set_postfix({'Loss': f'{step_loss:.4f}'})
            train_losses_epoch.append(step_loss)
        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
        self.losses['train_losses'].append(mean_train_loss)
        if self.use_ddp:
            # ReduceOp.AVG is NCCL-only; SUM / world_size also works with
            # the gloo backend that _setup_ddp selects on CPU
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            mean_train_loss = loss_tensor.item() / self.ddp_world_size
        if self.master_process and (epoch + 1) % self.log_freq == 0:
            current_lr = self.optim.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")
        if self.val_loader is not None and (epoch + 1) % self.val_freq == 0:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics
            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()
            self.scheduler.step(val_loss)
            self.losses['val_losses'].append((val_loss, fid, mse, psnr, ssim, lpips_score))
        else:
            if self.master_process:
                print()
            self.scheduler.step(mean_train_loss)

        # Track improvement on every rank; only rank 0 writes files.
        improved = mean_train_loss < self.best_loss
        if improved:
            self.best_loss = mean_train_loss
            wait = 0
        else:
            wait += 1
        stop_now = wait >= self.patience
        if self.use_ddp:
            # All ranks must leave the loop together, otherwise the ranks
            # that keep training block forever in their next collective.
            stop_flag = torch.tensor(int(stop_now), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_now = bool(stop_flag.item())
        if self.master_process:
            if improved:
                self._save_checkpoint(epoch + 1, self.best_loss, "best_")
            if stop_now:
                print("Early stopping triggered")
                self._save_checkpoint(epoch + 1, mean_train_loss, "early_stop_")
            elif (epoch + 1) % self.val_freq == 0:
                self._save_checkpoint(epoch + 1, mean_train_loss, "")
        if stop_now:
            break
    if self.use_ddp:
        destroy_process_group()
    return self.losses
[docs]
def sample_time(self, batch_size: int, eps: float = 1e-5) -> torch.Tensor:
    """Draw random diffusion times for a batch.

    Each entry is ``eps + (1 - eps) * u`` with ``u`` drawn from ``torch.rand``,
    so the times start at ``eps`` rather than 0.

    Parameters
    ----------
    batch_size : int
        Number of time values to draw.
    eps : float, optional
        Lower bound of the sampled times (default: 1e-5).

    Returns
    -------
    torch.Tensor
        1-D tensor of length ``batch_size`` created on ``self.device``.
    """
    uniform_draws = torch.rand(batch_size, device=self.device)
    span = 1 - eps
    return eps + span * uniform_draws
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor:
    """Turn raw conditioning values into the conditional network's output.

    Every item of ``y`` is converted with ``str()``, tokenized with
    ``self.tokenizer`` (padded/truncated to ``self.max_token_length``) and the
    resulting ids and attention mask are fed to ``self.cond_net``.

    Parameters
    ----------
    y : torch.Tensor or list
        Conditioning values for the batch. A tensor is first converted to a
        (nested) Python list on the CPU.

    Returns
    -------
    torch.Tensor
        Whatever ``self.cond_net(input_ids, attention_mask)`` returns.
    """
    if isinstance(y, torch.Tensor):
        raw_items = y.cpu().numpy().tolist()
    else:
        raw_items = y
    prompts = list(map(str, raw_items))

    tokens = self.tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt",
    ).to(self.device)

    ids = tokens["input_ids"]
    mask = tokens["attention_mask"]
    # NOTE(review): the raw attention mask is passed here, whereas
    # SampleSDE.forward passes the inverted mask (attention_mask == 0) to the
    # same kind of network — confirm which convention cond_net expects.
    return self.cond_net(ids, mask)
def _save_checkpoint(self, epoch: int, loss: float, pref: str = "") -> None:
    """Write a training checkpoint to ``self.store_path``.

    The file is named ``{pref}model_epoch_{epoch}.pth`` and holds the network
    and optimizer state dicts, the loss history, the state of
    ``self.fwd_sde.vs`` and the epoch bookkeeping. Saving is best-effort: any
    exception is caught and reported with ``print`` instead of propagating.

    Parameters
    ----------
    epoch : int
        Epoch number recorded in the checkpoint and in the file name.
    loss : float
        Loss value recorded in the checkpoint.
    pref : str, optional
        Prefix prepended to the checkpoint file name (default: "").
    """
    def _plain_state(model):
        # When use_ddp is set the real network is read from the wrapper's `.module`.
        if self.use_ddp:
            return model.module.state_dict()
        return model.state_dict()

    try:
        score_state = _plain_state(self.score_net)
        cond_state = _plain_state(self.cond_net) if self.cond_net is not None else None

        payload = {
            'epoch': epoch,
            'model_state_dict_score_net': score_state,
            'model_state_dict_cond': cond_state,
            'optim_state_dict': self.optim.state_dict(),
            'loss': loss,
            'losses': self.losses,
            'scheduler_model': self.fwd_sde.vs.state_dict(),
            'max_epochs': self.max_epochs,
        }

        target_dir = self.store_path
        target_file = os.path.join(target_dir, f"{pref}model_epoch_{epoch}.pth")
        os.makedirs(target_dir, exist_ok=True)
        torch.save(payload, target_file)
        print(f"Model saved at epoch {epoch} with loss: {loss:.4f}")
    except Exception as e:
        print(f"Failed to save model: {e}")
[docs]
def validate(self) -> Tuple[float, float, Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Run one pass over the validation loader and report loss and image metrics.

    For every validation batch the data is perturbed with the forward SDE at
    randomly sampled times, ``score_net`` is evaluated and the weighted loss is
    recorded. When both ``metrics_`` and ``rwd_sde`` are set, samples are also
    generated by stepping the reverse SDE from t = 1.0 down to ``time_eps`` in
    ``num_steps`` uniform steps, clamped to ``norm_range`` and compared with the
    batch through ``metrics_.forward``.

    The loss is reported at its natural scale. It is deliberately NOT divided
    by ``grad_acc``: that factor only exists to scale gradients during
    accumulation, and dividing here would under-report the validation loss
    relative to the logged training loss.

    Both networks are put in eval mode for the duration of the call and are
    switched back to train mode on exit, even when validation raises.

    Returns
    -------
    val_loss : float
        Mean validation loss (averaged across processes when DDP is used).
    fid : float
        Mean FID score, or ``float('inf')`` if not computed.
    mse : float or None
        Mean MSE, or None if not computed.
    psnr : float or None
        Mean PSNR, or None if not computed.
    ssim : float or None
        Mean SSIM, or None if not computed.
    lpips_score : float or None
        Mean LPIPS score, or None if not computed.
    """
    self.score_net.eval()
    if self.cond_net is not None:
        self.cond_net.eval()

    val_losses = []
    fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []

    try:
        with torch.no_grad():
            for x, y in self.val_loader:
                x = x.to(self.device)
                x_orig = x.clone()
                if self.cond_net is not None:
                    y_encoded = self._process_conditional_input(y)
                else:
                    y_encoded = None

                # Denoising loss on a randomly perturbed copy of the batch.
                noise = torch.randn_like(x)
                t = self.sample_time(x.shape[0], self.time_eps)
                xt, target = self.fwd_sde(x, t, noise)
                pred = self.score_net(xt, t, y_encoded, clip_embeddings=None)
                var = self.fwd_sde.vs.variance(t)
                if self.fwd_sde.method == "ve":
                    # VE weighting uses the geometric sigma schedule instead of the variance.
                    weight = self.fwd_sde.sigma_min * (self.fwd_sde.sigma_max / self.fwd_sde.sigma_min) ** t
                else:
                    weight = var
                # No division by grad_acc: nothing is back-propagated here.
                loss = self.loss_fn(pred, target, weight)
                val_losses.append(loss.item())

                if self.metrics_ is not None and self.rwd_sde is not None:
                    # Reverse diffusion sampling, starting from pure Gaussian noise.
                    xt = torch.randn(x.shape, device=self.device)
                    t_schedule = torch.linspace(1.0, self.time_eps, self.num_steps + 1, device=self.device)
                    dt = torch.tensor(-(1.0 - self.time_eps) / self.num_steps, device=self.device, dtype=xt.dtype)
                    for step in range(self.num_steps):
                        t_current = float(t_schedule[step])
                        t_batch = torch.full((xt.shape[0],), t_current, dtype=xt.dtype, device=self.device)
                        pred = self.score_net(xt, t_batch, y_encoded, None)
                        last_step = (step == self.num_steps - 1)
                        xt = self.rwd_sde(xt, pred, t_batch, dt, last_step=last_step)

                    x_hat = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
                    if self.norm_output:
                        # Map both images from norm_range onto [0, 1] before scoring.
                        low, high = self.norm_range[0], self.norm_range[1]
                        x_hat = (x_hat - low) / (high - low)
                        x_orig = (x_orig - low) / (high - low)

                    fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x_orig, x_hat)
                    if hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                        fid_scores.append(fid)
                    if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                        mse_scores.append(mse)
                        psnr_scores.append(psnr)
                        ssim_scores.append(ssim)
                    if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                        lpips_scores.append(lpips_score)

        val_loss = torch.tensor(val_losses).mean().item()
        if self.use_ddp:
            val_loss_tensor = torch.tensor(val_loss, device=self.device)
            dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
            val_loss = val_loss_tensor.item()

        fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf')
        mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None
        psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None
        ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None
        lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None
        return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
    finally:
        # Always hand the networks back in train mode.
        self.score_net.train()
        if self.cond_net is not None:
            self.cond_net.train()
###==================================================================================================================###
[docs]
class SampleSDE(nn.Module):
    """Sampler for generating images using SDE-based generative models.

    Starting from Gaussian noise, the sampler repeatedly evaluates ``score_net``
    and applies the reverse SDE step ``rwd_sde`` on a uniform time grid running
    from 1.0 down to ``time_eps``, as described in Song et al. (2021). Supports
    both unconditional and text-conditional generation.

    Parameters
    ----------
    rwd_sde : ReverseSDE
        Reverse SDE diffusion module used for each denoising step.
    score_net : nn.Module
        Trained network evaluated at every step as ``score_net(x_t, t, y)``.
    img_size : tuple
        Shape of generated images as (height, width).
    cond_net : nn.Module, optional
        Model for conditional generation (e.g., TextEncoder), default None.
    tokenizer : str or tokenizer object, optional
        Name passed to ``BertTokenizer.from_pretrained`` or an already
        constructed tokenizer, default "bert-base-uncased".
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per call (default: 1).
    in_channels : int, optional
        Number of channels of the generated images (default: 3).
    device : str or torch.device, optional
        Device for computation (default: 'cuda').
    norm_range : tuple, optional
        Range for clamping generated images (min, max), default (-1, 1).
    time_eps : float, optional
        Final time of the reverse process (default: 1e-5).

    Raises
    ------
    ValueError
        If ``img_size``, ``batch_size`` or ``norm_range`` is invalid.
    """

    def __init__(
        self,
        rwd_sde: torch.nn.Module,
        score_net: torch.nn.Module,
        img_size: Tuple[int, int],
        cond_net: Optional[torch.nn.Module] = None,
        tokenizer: Union[str, Any] = "bert-base-uncased",
        max_token_length: int = 77,
        batch_size: int = 1,
        in_channels: int = 3,
        device: Union[str, torch.device] = 'cuda',
        norm_range: Tuple[float, float] = (-1.0, 1.0),
        time_eps: float = 1e-5
    ) -> None:
        super().__init__()

        # Validate cheap arguments first so bad input fails before any model
        # is moved or any tokenizer is loaded.
        if not isinstance(img_size, (tuple, list)) or len(img_size) != 2 or not all(isinstance(s, int) and s > 0 for s in img_size):
            raise ValueError("img_size must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(norm_range, (tuple, list)) or len(norm_range) != 2 or norm_range[0] >= norm_range[1]:
            raise ValueError("norm_range must be a tuple (min, max) with min < max")

        self.device = torch.device(device) if isinstance(device, str) else device
        self.rwd_sde = rwd_sde.to(self.device)
        self.score_net = score_net.to(self.device)
        # `is not None` rather than truthiness: container modules define __len__.
        self.cond_net = cond_net.to(self.device) if cond_net is not None else None
        # A string names a pretrained tokenizer; anything else is used as-is.
        if isinstance(tokenizer, str):
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        else:
            self.tokenizer = tokenizer
        self.max_token_length = max_token_length
        self.in_channels = in_channels
        self.img_size = img_size
        self.batch_size = batch_size
        self.norm_range = norm_range
        self.time_eps = time_eps

    def tokenize(self, prompts: Union[str, List]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize text prompts for conditional generation.

        Parameters
        ----------
        prompts : str or list
            A single prompt or a list of prompt strings.

        Returns
        -------
        input_ids : torch.Tensor
            The tokenizer's ``input_ids``, moved to ``self.device``.
        attention_mask : torch.Tensor
            The tokenizer's ``attention_mask``, moved to ``self.device``.

        Raises
        ------
        TypeError
            If ``prompts`` is neither a string nor a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")
        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(
        self,
        num_steps: int,
        conds: Optional[Union[str, List]] = None,
        norm_output: bool = True,
        save_imgs: bool = True,
        save_path: str = "sde_samples"
    ) -> torch.Tensor:
        """Generate images with the reverse SDE sampling process.

        Parameters
        ----------
        num_steps : int
            Number of reverse steps taken between t = 1.0 and ``time_eps``.
            Must be at least 1.
        conds : str or list, optional
            Text prompt(s) for conditional generation, default None. A single
            prompt is repeated for every image in the batch.
        norm_output : bool, optional
            If True, rescales the clamped images from ``norm_range`` to [0, 1]
            (default: True).
        save_imgs : bool, optional
            If True, saves each image as ``img_<k>.png`` under ``save_path``
            (default: True).
        save_path : str, optional
            Directory to save generated images (default: "sde_samples").

        Returns
        -------
        torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width),
            clamped to ``norm_range`` and, when ``norm_output`` is True,
            rescaled to [0, 1].

        Raises
        ------
        ValueError
            If ``num_steps`` is not positive, or if ``conds`` is given without
            a conditional model (or omitted with one).
        """
        if num_steps < 1:
            raise ValueError("num_steps must be a positive integer")
        if conds is not None and self.cond_net is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conds is None and self.cond_net is not None:
            raise ValueError("Conditions must be provided for conditional model")

        init_samps = torch.randn(
            self.batch_size, self.in_channels, self.img_size[0], self.img_size[1],
            device=self.device,
        )

        self.score_net.eval()
        self.rwd_sde.eval()
        if self.cond_net is not None:
            self.cond_net.eval()

        t_schedule = torch.linspace(1.0, self.time_eps, num_steps + 1, device=self.device)
        dt = -(1.0 - self.time_eps) / num_steps  # negative: time runs backwards

        iterator = tqdm(
            range(num_steps),
            total=num_steps,
            desc="Sampling",
            dynamic_ncols=True,
            leave=True,
        )

        # Everything below is inference only, including the text encoder.
        with torch.no_grad():
            if self.cond_net is not None:
                prompts = [conds] if isinstance(conds, str) else conds
                # One prompt conditions the whole batch.
                if isinstance(prompts, list) and len(prompts) == 1 and self.batch_size > 1:
                    prompts = prompts * self.batch_size
                input_ids, attention_masks = self.tokenize(prompts)
                # NOTE(review): the inverted mask is passed here, while the
                # training helper _process_conditional_input passes the raw
                # attention mask — confirm which convention cond_net expects.
                key_padding_mask = (attention_masks == 0)
                y = self.cond_net(input_ids, key_padding_mask)
            else:
                y = None

            xt = init_samps
            for step in iterator:
                t_current = float(t_schedule[step])
                t_batch = torch.full((self.batch_size,), t_current, dtype=xt.dtype, device=self.device)
                pred = self.score_net(xt, t_batch, y)
                last_step = (step == num_steps - 1)
                xt = self.rwd_sde(xt, pred, t_batch, dt, last_step=last_step)

            samps = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1])
            if norm_output:
                samps = (samps - self.norm_range[0]) / (self.norm_range[1] - self.norm_range[0])

        if save_imgs:
            os.makedirs(save_path, exist_ok=True)
            for idx, img in enumerate(samps, start=1):
                save_image(img, os.path.join(save_path, f"img_{idx}.png"))
        return samps

    def to(self, device: Union[str, torch.device]) -> Self:
        """Move the sampler and its sub-modules to ``device``.

        Parameters
        ----------
        device : str or torch.device
            Target device for the module and its components.

        Returns
        -------
        SampleSDE
            This sampler, moved to the specified device.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        self.score_net.to(self.device)
        self.rwd_sde.to(self.device)
        if self.cond_net is not None:
            self.cond_net.to(self.device)
        return super().to(self.device)