Source code for torchdiff.ddpm

"""
**Denoising Diffusion Probabilistic Models (DDPM) implementation**

This module provides a complete implementation of DDPM, as described in Ho et al.
(2020, "Denoising Diffusion Probabilistic Models"). It includes components for forward
and reverse diffusion processes, hyperparameter management, training, and image
sampling. Supports both unconditional and conditional generation with text prompts.

**Components**

- **ForwardDDPM**: Forward diffusion process to add noise.
- **ReverseDDPM**: Reverse diffusion process to denoise.
- **SchedulerDDPM**: Noise schedule management.
- **TrainDDPM**: Training loop with mixed precision and scheduling.
- **SampleDDPM**: Image generation from trained models.

**References**

- Ho, J., Jain, A., & Abbeel, P. (2020). Denoising Diffusion Probabilistic Models.

- Salimans, Tim, et al. "Pixelcnn++: Improving the pixelcnn with discretized logistic mixture likelihood and other modifications."
arXiv preprint arXiv:1701.05517 (2017).

-------------------------------------------------------------------------------
"""



import torch
import torch.nn as nn
from typing import Optional, Tuple, Callable, List, Any, Union, Dict
from typing_extensions import Self
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm
from transformers import BertTokenizer
import warnings
from torchvision.utils import save_image
from .utils import LossAdapter
import os

# Names exported by a star-import of this module.
__all__ = ["ForwardDDPM", "ReverseDDPM", "SchedulerDDPM", "TrainDDPM", "SampleDDPM"]


###==================================================================================================================###


class ForwardDDPM(nn.Module):
    """
    Forward diffusion process for DDPM.

    Implements sampling from the forward noising distribution:
        q(x_t | x_0) = N(√ᾱ_t x_0, (1 - ᾱ_t) I)

    Also computes the training target that matches the chosen prediction
    parameterization (noise, x0 or v).
    """

    _VALID_PRED_TYPES = ("noise", "x0", "v")

    def __init__(self, scheduler: nn.Module, pred_type: str = "noise") -> None:
        """
        Initialize the forward diffusion process.

        Args:
            scheduler: Noise scheduler. Must expose the indexable tensors
                ``sqrt_alphas_cumprod`` and ``sqrt_one_minus_alphas_cumprod``
                and a ``get_index(values, x_shape)`` method.
            pred_type: Prediction parameterization. One of {"noise", "x0", "v"}.

        Raises:
            ValueError: If ``pred_type`` is not a supported parameterization.
        """
        super().__init__()
        if pred_type not in self._VALID_PRED_TYPES:
            raise ValueError(
                f"pred_type must be one of {list(self._VALID_PRED_TYPES)}, got {pred_type}"
            )
        self.vs = scheduler
        self.pred_type = pred_type

    def forward(
            self,
            x0: torch.Tensor,
            t: torch.Tensor,
            noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample a noised version of the input and compute the training target.

        Args:
            x0: Clean input data of shape (batch, ...).
            t: Discrete timesteps of shape (batch,), used to index the
                scheduler tensors.
            noise: Noise tensor broadcastable against x0.

        Returns:
            xt: Noised data, √ᾱ_t * x_0 + √(1 - ᾱ_t) * ε.
            target: Training target for the selected prediction type
                (the noise itself, x0, or v).

        Raises:
            ValueError: If ``self.pred_type`` was changed to an unsupported
                value after construction.
        """
        sqrt_alpha_bar = self.vs.get_index(self.vs.sqrt_alphas_cumprod[t], x0.shape)
        sqrt_one_minus_alpha_bar = self.vs.get_index(
            self.vs.sqrt_one_minus_alphas_cumprod[t], x0.shape
        )

        # x_t ~ q(x_t | x_0):  x_t = √ᾱ_t * x_0 + √(1 - ᾱ_t) * ε
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise

        if self.pred_type == "noise":
            target = noise
        elif self.pred_type == "x0":
            target = x0
        elif self.pred_type == "v":
            # v-prediction: v = √ᾱ_t * ε - √(1 - ᾱ_t) * x_0
            target = sqrt_alpha_bar * noise - sqrt_one_minus_alpha_bar * x0
        else:
            # Only reachable if the attribute was mutated after __init__;
            # fail loudly instead of hitting an unbound local below.
            raise ValueError(
                f"pred_type must be one of {list(self._VALID_PRED_TYPES)}, got {self.pred_type}"
            )
        return xt, target
###==================================================================================================================###
class ReverseDDPM(nn.Module):
    """
    Reverse diffusion process for DDPM.

    Implements a single reverse denoising step:
        p_θ(x_{t-1} | x_t) = N(μ_θ(x_t, t), Σ_t)

    Supports different prediction parameterizations (noise, x0, v) and
    multiple variance types (fixed or learned).
    """

    def __init__(
            self,
            scheduler: nn.Module,
            pred_type: str = "noise",
            var_type: str = "fixed_small",
            clip_out: bool = True
    ) -> None:
        """
        Initialize the reverse diffusion process.

        Args:
            scheduler: Noise scheduler providing the diffusion coefficient
                tensors and a ``get_index`` broadcasting helper.
            pred_type: Model prediction parameterization.
                One of {"noise", "x0", "v"}.
            var_type: Variance type used in the reverse process.
                One of {"fixed_small", "fixed_large", "learned"}.
            clip_out: Whether to clamp the predicted x0 to [-1, 1].

        Raises:
            ValueError: If ``pred_type`` or ``var_type`` is not supported.
        """
        super().__init__()
        valid_pred_types = ["noise", "x0", "v"]
        valid_var_types = ["fixed_small", "fixed_large", "learned"]
        if pred_type not in valid_pred_types:
            raise ValueError(f"pred_type must be one of {valid_pred_types}")
        if var_type not in valid_var_types:
            raise ValueError(f"var_type must be one of {valid_var_types}")
        self.vs = scheduler
        self.pred_type = pred_type
        self.var_type = var_type
        self.clip_out = clip_out

    def predict_x0(self, xt: torch.Tensor, t: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """
        Convert the model output into a prediction of the original data x0.

        Args:
            xt: Current noised data x_t.
            t: Discrete timesteps of shape (batch,).
            pred: Model output corresponding to the selected prediction type.

        Returns:
            Predicted clean data x0, clamped to [-1, 1] when ``clip_out`` is set.
        """
        sqrt_alpha_bar = self.vs.get_index(self.vs.sqrt_alphas_cumprod[t], xt.shape)
        sqrt_one_minus_alpha_bar = self.vs.get_index(
            self.vs.sqrt_one_minus_alphas_cumprod[t], xt.shape
        )

        if self.pred_type == "noise":
            # x_0 = (x_t - √(1 - ᾱ_t) * ε_θ) / √ᾱ_t
            x0 = (xt - sqrt_one_minus_alpha_bar * pred) / sqrt_alpha_bar
        elif self.pred_type == "x0":
            # the model predicts x_0 directly
            x0 = pred
        elif self.pred_type == "v":
            # x_0 = √ᾱ_t * x_t - √(1 - ᾱ_t) * v_θ
            x0 = sqrt_alpha_bar * xt - sqrt_one_minus_alpha_bar * pred

        if self.clip_out:
            x0 = torch.clamp(x0, -1.0, 1.0)
        return x0

    def get_variance(self, t: torch.Tensor, pred_var: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the variance used in the reverse diffusion step.

        Args:
            t: Discrete timesteps of shape (batch,).
            pred_var: Optional model-predicted interpolation value in [-1, 1]
                (required when var_type="learned"). May be of shape (batch,)
                or carry extra trailing dimensions, e.g. the shape of x_t.

        Returns:
            Variance tensor for the reverse transition. Shape (batch,) for
            the fixed variance types; for "learned" it has the shape of
            ``pred_var``.

        Raises:
            ValueError: If var_type is "learned" and ``pred_var`` is None.
        """
        if self.var_type == "fixed_small":
            # posterior variance: β_t * (1 - ᾱ_{t-1}) / (1 - ᾱ_t)
            return self.vs.posterior_variance[t]
        if self.var_type == "fixed_large":
            # β_t
            return self.vs.betas[t]

        # "learned": interpolate in log-space between fixed_small and fixed_large
        if pred_var is None:
            raise ValueError("predicted_variance must be provided when variance_type='learned'")
        min_log = self.vs.posterior_log_variance[t]
        max_log = torch.log(self.vs.betas[t])
        if pred_var.dim() > min_log.dim():
            # The bounds are per-sample (batch,); give them trailing singleton
            # dims so they broadcast against a per-element prediction instead
            # of being aligned with its last axis.
            min_log = self.vs.get_index(min_log, pred_var.shape)
            max_log = self.vs.get_index(max_log, pred_var.shape)
        frac = (pred_var + 1) / 2  # map from [-1, 1] to [0, 1]
        return torch.exp(frac * max_log + (1 - frac) * min_log)

    def forward(
            self,
            xt: torch.Tensor,
            pred: torch.Tensor,
            t: torch.Tensor,
            pred_var: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Perform a single reverse diffusion step from x_t to x_{t-1}.

        Args:
            xt: Current state x_t of shape (batch, ...).
            pred: Model prediction at timestep t.
            t: Discrete timesteps of shape (batch,).
            pred_var: Optional predicted variance for learned variance models.

        Returns:
            x_prev: Sampled previous state x_{t-1}.
            pred_x0: Predicted clean data x0.
        """
        pred_x0 = self.predict_x0(xt, t, pred)

        # posterior mean: μ_θ(x_t, t) = coef1 * x_0 + coef2 * x_t
        coef1 = self.vs.get_index(self.vs.posterior_mean_coef1[t], xt.shape)
        coef2 = self.vs.get_index(self.vs.posterior_mean_coef2[t], xt.shape)
        posterior_mean = coef1 * pred_x0 + coef2 * xt

        if self.var_type == "fixed_small":
            # precomputed square root for the most common case
            sqrt_var = self.vs.get_index(self.vs.sqrt_posterior_variance[t], xt.shape)
        elif self.var_type == "fixed_large":
            sqrt_var = self.vs.get_index(self.vs.sqrt_betas[t], xt.shape)
        else:
            variance = self.get_variance(t, pred_var)
            if variance.dim() < xt.dim():
                # per-sample variance: add singleton dims for broadcasting;
                # a per-element variance already lines up with x_t
                variance = self.vs.get_index(variance, xt.shape)
            sqrt_var = torch.sqrt(variance)

        # no noise is added on the final step (t == 0)
        noise = torch.randn_like(xt)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (len(xt.shape) - 1)))

        # sample x_{t-1} ~ p_θ(x_{t-1} | x_t)
        x_prev = posterior_mean + nonzero_mask * sqrt_var * noise
        return x_prev, pred_x0
###==================================================================================================================###
class SchedulerDDPM(nn.Module):
    """
    Noise scheduler for DDPM-style diffusion models.

    This class defines the discrete diffusion timeline and precomputes all
    noise schedule coefficients required for forward diffusion and reverse
    sampling, including betas, alphas, cumulative products, and posterior
    coefficients.

    Supported schedules include linear, cosine, quadratic, and sigmoid.

    The scheduler acts as the single source of truth for the diffusion
    horizon T and all time-dependent constants.
    """

    def __init__(
            self,
            schedule_type: str = "linear",
            time_steps: int = 1000,
            beta_min: float = 0.0001,
            beta_max: float = 0.02,
            cosine_s: float = 0.008,
            clip_min: float = 0.0001,
            clip_max: float = 0.9999
    ):
        """
        Initialize the DDPM noise scheduler.

        Args:
            schedule_type: Type of beta schedule to use.
                One of {"linear", "cosine", "quadratic", "sigmoid"}.
            time_steps: Number of discrete diffusion steps (T); must be >= 1.
            beta_min: Minimum beta value (linear, quadratic, sigmoid schedules).
            beta_max: Maximum beta value (linear, quadratic, sigmoid schedules).
            cosine_s: Small offset used in the cosine schedule.
            clip_min: Minimum value for clipping betas (cosine schedule).
            clip_max: Maximum value for clipping betas (cosine schedule).

        Raises:
            ValueError: If ``schedule_type`` is unknown or ``time_steps`` < 1.
        """
        super().__init__()
        valid_schedules = ["linear", "cosine", "quadratic", "sigmoid"]
        if schedule_type not in valid_schedules:
            raise ValueError(f"schedule_type must be one of {valid_schedules}, got {schedule_type}")
        if time_steps < 1:
            # An empty schedule would only surface later as an obscure
            # indexing/sampling error, so reject it here.
            raise ValueError(f"time_steps must be a positive integer, got {time_steps}")

        self.schedule_type = schedule_type
        self.time_steps = time_steps
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.cosine_s = cosine_s
        self.clip_min = clip_min
        self.clip_max = clip_max
        self._setup_schedule()

    def _compute_betas(self) -> torch.Tensor:
        """Return the beta schedule of length ``time_steps`` for ``schedule_type``."""
        if self.schedule_type == "linear":
            return torch.linspace(self.beta_min, self.beta_max, self.time_steps)
        if self.schedule_type == "cosine":
            steps = self.time_steps + 1
            t = torch.linspace(0, self.time_steps, steps)
            alphas_cumprod = torch.cos(
                ((t / self.time_steps) + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5
            ) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            return torch.clip(betas, self.clip_min, self.clip_max)
        if self.schedule_type == "quadratic":
            return torch.linspace(self.beta_min ** 0.5, self.beta_max ** 0.5, self.time_steps) ** 2
        if self.schedule_type == "sigmoid":
            betas = torch.linspace(-6, 6, self.time_steps)
            return torch.sigmoid(betas) * (self.beta_max - self.beta_min) + self.beta_min
        raise ValueError(f"Unsupported schedule_type: {self.schedule_type}")

    def _setup_schedule(self):
        """
        Precompute the noise schedule and all derived diffusion coefficients.

        This method computes:
        - betas and alphas
        - cumulative products of alphas
        - coefficients for q(x_t | x_0)
        - coefficients for the reverse posterior q(x_{t-1} | x_t, x_0)

        All tensors are registered as buffers, so they follow the module
        across devices and are included in its state dict.
        """
        betas = self._compute_betas()

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # ᾱ_{t-1}, with ᾱ_{-1} defined as 1
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

        # coefficients for q(x_t | x_0)
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

        # coefficients for q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        # the variance is 0 at t = 0, so clamp before log / sqrt
        posterior_log_variance = torch.log(torch.clamp(posterior_variance, min=1e-20))
        posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        posterior_mean_coef2 = (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)

        # square roots precomputed for the reverse step
        sqrt_posterior_variance = torch.sqrt(torch.clamp(posterior_variance, min=1e-20))
        sqrt_betas = torch.sqrt(betas)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod)
        self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        self.register_buffer('posterior_log_variance', posterior_log_variance)
        self.register_buffer('posterior_mean_coef1', posterior_mean_coef1)
        self.register_buffer('posterior_mean_coef2', posterior_mean_coef2)
        self.register_buffer('sqrt_posterior_variance', sqrt_posterior_variance)
        self.register_buffer('sqrt_betas', sqrt_betas)

    def get_index(self, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
        """
        Reshape a per-sample tensor for broadcasting over data tensors.

        Args:
            t: Tensor whose first dimension is the batch size, holding
                timestep-indexed values.
            x_shape: Shape of the target tensor to broadcast over.

        Returns:
            Tensor reshaped to (batch, 1, ..., 1) with ``len(x_shape)`` dims.
        """
        batch_size = t.shape[0]
        return t.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
###==================================================================================================================###
class TrainDDPM(nn.Module):
    """Trainer for Denoising Diffusion Probabilistic Models (DDPM) with Multi-GPU Support.

    Manages the training process for DDPM, optimizing a noise predictor model to learn
    the noise added by the forward diffusion process. Supports conditional training with
    text prompts, mixed precision training, learning rate scheduling, early stopping,
    checkpointing, and distributed data parallel (DDP) training across multiple GPUs.

    Parameters
    ----------
    diff_net : nn.Module
        Model to predict noise/x0/v added during the forward diffusion process.
    fwd_ddpm : nn.Module
        Forward DDPM diffusion module for adding noise.
    rwd_ddpm: nn.Module
        Reverse DDPM diffusion module for denoising.
    train_loader : torch.utils.data.DataLoader
        DataLoader for training data. Should be wrapped with DistributedSampler for DDP.
    optim : torch.optim.Optimizer
        Optimizer for training the noise predictor and conditional model (if applicable).
    loss_fn : callable
        Loss function to compute the difference between the prediction and its target.
    val_loader : torch.utils.data.DataLoader, optional
        DataLoader for validation data, default None.
    max_epochs : int, optional
        Maximum number of training epochs (default: 100).
    device : str
        Device for computation (default: 'cuda'). Overridden by the local-rank
        CUDA device when ``use_ddp`` is True.
    cond_net : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    metrics_ : object, optional
        Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
    tokenizer : BertTokenizer, optional
        Tokenizer for processing text prompts. When None and ``cond_net`` is given,
        "bert-base-uncased" is loaded; when None and ``cond_net`` is None, no
        tokenizer is loaded and ``self.tokenizer`` stays None.
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    store_path : str, optional
        Path to save model checkpoints (default: "ddpm_train").
    patience : int, optional
        Number of epochs to wait for improvement before early stopping (default: 20).
        Also used as the patience of the ReduceLROnPlateau scheduler.
    warmup_steps : int, optional
        Number of optimizer steps for learning rate warmup (default: 1000).
    val_freq : int, optional
        Frequency (in epochs) for validation (default: 10).
    norm_range : tuple, optional
        Range for clamping generated images (default: (-1, 1)).
    norm_output : bool, optional
        Whether to normalize generated images to [0, 1] for metrics (default: True).
    use_ddp : bool, optional
        Whether to use Distributed Data Parallel training (default: False).
    grad_acc : int, optional
        Number of gradient accumulation steps before optimizer update (default: 1).
        Must be >= 1.
    log_freq : int, optional
        Number of epochs between loss printouts (default: 1).
    use_comp : bool, optional
        Whether to compile the models with torch.compile (default: False).
    use_amp : bool, optional
        Whether to use automatic mixed precision (AMP) for training (default: False).
        Enable only on GPUs with good fp16 support (e.g., Ampere or newer).
    *args
        Extra positional arguments are accepted and ignored.

    Raises
    ------
    ValueError
        If ``grad_acc`` < 1, if the default tokenizer is needed but cannot be
        loaded, or (DDP) if a required environment variable is missing.
    RuntimeError
        If ``use_ddp`` is True but CUDA is not available.
    """

    def __init__(
            self,
            diff_net: torch.nn.Module,
            fwd_ddpm: torch.nn.Module,
            rwd_ddpm: torch.nn.Module,
            train_loader: torch.utils.data.DataLoader,
            optim: torch.optim.Optimizer,
            loss_fn: Callable,
            val_loader: Optional[torch.utils.data.DataLoader] = None,
            max_epochs: int = 100,
            device: str = 'cuda',
            cond_net: Optional[torch.nn.Module] = None,
            metrics_: Optional[Any] = None,
            tokenizer: Optional[BertTokenizer] = None,
            max_token_length: int = 77,
            store_path: Optional[str] = None,
            patience: int = 20,
            warmup_steps: int = 1000,
            val_freq: int = 10,
            norm_range: Tuple[float, float] = (-1.0, 1.0),
            norm_output: bool = True,
            use_ddp: bool = False,
            grad_acc: int = 1,
            log_freq: int = 1,
            use_comp: bool = False,
            use_amp: bool = False,
            *args
    ) -> None:
        super().__init__()
        if grad_acc < 1:
            # the training loop divides by and takes a modulo of grad_acc
            raise ValueError(f"grad_acc must be >= 1, got {grad_acc}")

        self.use_ddp = use_ddp
        self.grad_acc = grad_acc
        self.use_amp = use_amp
        if isinstance(device, str):
            self.device = torch.device(device)
        else:
            self.device = device

        # _setup_ddp replaces self.device with the local-rank CUDA device
        if self.use_ddp:
            self._setup_ddp()
        else:
            self._setup_single_gpu()

        self.diff_net = diff_net.to(self.device)
        self.fwd_ddpm = fwd_ddpm.to(self.device)
        self.rwd_ddpm = rwd_ddpm.to(self.device)
        # identity check: container modules define __len__, so truthiness
        # would treat an empty container as "no conditional network"
        self.cond_net = cond_net.to(self.device) if cond_net is not None else None
        self.metrics_ = metrics_
        self.optim = optim
        self.loss_fn = LossAdapter(loss_fn)
        self.store_path = store_path or "ddpm_train"
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.max_epochs = max_epochs
        self.max_token_length = max_token_length
        self.patience = patience
        self.val_freq = val_freq
        self.norm_range = norm_range
        self.norm_output = norm_output
        self.log_freq = log_freq
        self.use_comp = use_comp
        self.global_step = 0
        self.warmup_steps = warmup_steps
        self.best_loss = float('inf')
        self.losses = {'train_losses': [], 'val_losses': []}
        self.scheduler = ReduceLROnPlateau(
            self.optim,
            patience=self.patience,
            factor=0.5
        )
        self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps)
        self._device_type = self.device.type if hasattr(self.device, 'type') else (
            'cuda' if 'cuda' in str(self.device) else 'cpu')

        if tokenizer is not None:
            self.tokenizer = tokenizer
        elif self.cond_net is None:
            # Unconditional training never tokenizes prompts, so do not
            # fetch the pretrained vocabulary (may need network access).
            self.tokenizer = None
        else:
            try:
                self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            except Exception as e:
                raise ValueError(
                    f"Failed to load default tokenizer: {e}. Please provide a tokenizer."
                ) from e

    def _setup_ddp(self) -> None:
        """Setup Distributed Data Parallel training configuration.

        Checks the RANK / LOCAL_RANK / WORLD_SIZE environment variables,
        initializes the NCCL process group if none is initialized yet, and
        binds this process to the CUDA device of its local rank.

        Raises
        ------
        ValueError
            If one of the required environment variables is not set.
        RuntimeError
            If CUDA is not available.
        """
        for var_name in ("RANK", "LOCAL_RANK", "WORLD_SIZE"):
            if var_name not in os.environ:
                raise ValueError(f"DDP enabled but {var_name} environment variable not set")
        if not torch.cuda.is_available():
            raise RuntimeError("DDP requires CUDA but CUDA is not available")
        if not torch.distributed.is_initialized():
            init_process_group(backend="nccl")

        self.ddp_rank = int(os.environ["RANK"])  # global rank across all nodes
        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])  # local rank on current node
        self.ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
        self.device = torch.device(f"cuda:{self.ddp_local_rank}")
        torch.cuda.set_device(self.device)
        self.master_process = self.ddp_rank == 0

        if self.master_process:
            print(f"DDP initialized with world_size={self.ddp_world_size}")

    def _setup_single_gpu(self) -> None:
        """Setup single GPU or CPU training configuration."""
        self.ddp_rank = 0
        self.ddp_local_rank = 0
        self.ddp_world_size = 1
        self.master_process = True
def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]:
    """Loads a training checkpoint to resume training.

    Restores the state of the noise predictor, conditional model (if applicable),
    noise scheduler and optimizer from a saved checkpoint. State-dict keys are
    matched against the model actually held by the trainer, so checkpoints
    written from DDP-wrapped (``module.``) or compiled (``_orig_mod.``) models
    load whether or not the current model is wrapped.

    Parameters
    ----------
    checkpoint_path : str
        Path to the checkpoint file.

    Returns
    -------
    epoch : int
        The epoch at which the checkpoint was saved (-1 if not recorded).
    loss : float
        The loss at the checkpoint (inf if not recorded).

    Raises
    ------
    FileNotFoundError
        If ``checkpoint_path`` does not exist.
    KeyError
        If the checkpoint lacks 'model_state_dict_diff_net', 'scheduler_model'
        or 'optim_state_dict'.
    """
    wrapper_prefixes = ("module.", "_orig_mod.")

    def _unwrap(model):
        # reach the plain module behind a DDP and/or torch.compile wrapper
        if isinstance(model, DDP):
            model = model.module
        return getattr(model, "_orig_mod", model)

    def _strip_wrapper_prefixes(key):
        # remove leading wrapper prefixes only; never touch the middle of a key
        changed = True
        while changed:
            changed = False
            for prefix in wrapper_prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        return key

    def _restore(model, state_dict):
        target = _unwrap(model)
        # Rewrite keys only when they do not already match, so a model with
        # a genuine submodule called "module" is left alone.
        if set(state_dict) != set(target.state_dict()):
            state_dict = {_strip_wrapper_prefixes(k): v for k, v in state_dict.items()}
        target.load_state_dict(state_dict)

    try:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") from err

    if 'model_state_dict_diff_net' not in checkpoint:
        raise KeyError("Checkpoint missing 'model_state_dict_diff_net' key")
    _restore(self.diff_net, checkpoint['model_state_dict_diff_net'])

    if self.cond_net is not None:
        cond_state_dict = checkpoint.get('model_state_dict_cond')
        if cond_state_dict is not None:
            _restore(self.cond_net, cond_state_dict)
        else:
            warnings.warn(
                "Checkpoint contains no 'model_state_dict_cond' or it is None, "
                "skipping conditional model loading"
            )

    if 'scheduler_model' not in checkpoint:
        raise KeyError("Checkpoint missing 'scheduler_model' key")
    try:
        # Handle the forward and reverse processes independently: each one
        # either loads the state into its scheduler module or, if it holds
        # a non-module scheduler, takes the stored object as-is.
        for process in (self.fwd_ddpm, self.rwd_ddpm):
            if isinstance(process.vs, nn.Module):
                process.vs.load_state_dict(checkpoint['scheduler_model'])
            else:
                process.vs = checkpoint['scheduler_model']
    except Exception as e:
        # best effort: the freshly constructed scheduler stays in place
        warnings.warn(
            f"Scheduler loading failed: {e}. Continuing with current scheduler.")

    if 'optim_state_dict' not in checkpoint:
        raise KeyError("Checkpoint missing 'optim_state_dict' key")
    try:
        self.optim.load_state_dict(checkpoint['optim_state_dict'])
    except ValueError as e:
        warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.")

    epoch = checkpoint.get('epoch', -1)
    loss = checkpoint.get('loss', float('inf'))

    if self.master_process:
        print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}")

    return epoch, loss
@staticmethod
def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) -> torch.optim.lr_scheduler.LambdaLR:
    """Build the learning-rate warmup schedule.

    The returned scheduler multiplies the optimizer's base learning rate by a
    factor that grows linearly from 0.1 (at step 0) towards 1.0 over
    ``warmup_steps`` scheduler steps and stays at 1.0 afterwards.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer to apply the scheduler to.
    warmup_steps : int
        Number of steps for the warmup phase.

    Returns
    -------
    torch.optim.lr_scheduler.LambdaLR
        Learning rate scheduler for warmup.
    """

    def scale_at(step):
        # past the warmup window the base learning rate is used unchanged
        return 1.0 if not step < warmup_steps else 0.1 + (0.9 * step / warmup_steps)

    return LambdaLR(optimizer, scale_at)
def _wrap_models_for_ddp(self) -> None: """Wrap models with DistributedDataParallel for multi-GPU training.""" if self.use_ddp: ddp_kwargs = dict(find_unused_parameters=False) if self._device_type == 'cuda': ddp_kwargs['device_ids'] = [self.ddp_local_rank] self.diff_net = DDP(self.diff_net, **ddp_kwargs) if self.cond_net is not None: self.cond_net = DDP(self.cond_net, **ddp_kwargs)
def forward(self) -> Dict:
    """Trains the DDPM model to predict the target of the forward diffusion process.

    Executes the training loop with support for distributed training, gradient
    accumulation, mixed precision, gradient clipping, and learning rate scheduling.
    Includes validation, early stopping, and checkpointing functionality.

    Notes
    -----
    - Gradients left over from an incomplete accumulation group at the end of
      an epoch are applied before the epoch closes, so they never leak into
      the next epoch.
    - Under DDP the early-stopping decision taken on rank 0 is broadcast, so
      every rank leaves the loop in the same epoch.
    - Each batch from ``train_loader`` is unpacked as an ``(x, y)`` pair.

    Returns
    -------
    losses : a dictionary contains train and validation losses
    """
    self.diff_net.train()
    if self.cond_net is not None:
        self.cond_net.train()

    if self.use_comp:
        try:
            self.diff_net = torch.compile(self.diff_net)
            if self.cond_net is not None:
                self.cond_net = torch.compile(self.cond_net)
        except Exception as e:
            # compilation is an optimization only; fall back to eager mode
            if self.master_process:
                print(f"Model compilation failed: {e}. Continuing without compilation.")

    self._wrap_models_for_ddp()

    # AMP is only used on CUDA, whatever the flag says
    use_amp = self.use_amp and self._device_type == 'cuda'
    scaler = torch.amp.GradScaler(self._device_type, enabled=use_amp)

    def apply_update() -> None:
        # one optimizer update on the accumulated gradients
        scaler.unscale_(self.optim)  # clip on true (unscaled) gradient norms
        torch.nn.utils.clip_grad_norm_(self.diff_net.parameters(), max_norm=1.0)
        if self.cond_net is not None:
            torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0)
        scaler.step(self.optim)
        scaler.update()
        self.optim.zero_grad(set_to_none=True)
        if self.global_step < self.warmup_steps:
            self.warmup_lr_scheduler.step()
        self.global_step += 1

    wait = 0
    for epoch in range(self.max_epochs):
        # validate() switches the networks to eval mode, so re-assert
        # train mode at the start of every epoch
        self.diff_net.train()
        if self.cond_net is not None:
            self.cond_net.train()

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.max_epochs}",
                    disable=not self.master_process)
        if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'):
            self.train_loader.sampler.set_epoch(epoch)

        train_losses_epoch = []
        pending_grads = False
        for step, (x, y) in enumerate(pbar):
            x = x.to(self.device)
            if self.cond_net is not None:
                y_encoded = self._process_conditional_input(y)
            else:
                y_encoded = None

            with torch.autocast(device_type=self._device_type, enabled=use_amp):
                noise = torch.randn_like(x)
                t = torch.randint(0, self.fwd_ddpm.vs.time_steps, (x.shape[0],), device=x.device)
                xt, target = self.fwd_ddpm(x, t, noise)
                pred = self.diff_net(xt, t, y_encoded, clip_embeddings=None)
                # scale so the accumulated gradient averages over the group
                loss = self.loss_fn(pred, target) / self.grad_acc

            scaler.scale(loss).backward()
            pending_grads = True

            if (step + 1) % self.grad_acc == 0:
                apply_update()
                pending_grads = False

            # read the loss once per step (each .item() synchronizes)
            step_loss = loss.item() * self.grad_acc
            pbar.set_postfix({'Loss': f'{step_loss:.4f}'})
            train_losses_epoch.append(step_loss)

        if pending_grads:
            # The batch count was not a multiple of grad_acc: apply the
            # partial group now rather than carrying it into the next epoch.
            apply_update()

        mean_train_loss = torch.tensor(train_losses_epoch).mean().item()
        self.losses['train_losses'].append(mean_train_loss)

        if self.use_ddp:
            loss_tensor = torch.tensor(mean_train_loss, device=self.device)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            mean_train_loss = loss_tensor.item()

        if self.master_process and (epoch + 1) % self.log_freq == 0:
            current_lr = self.optim.param_groups[0]['lr']
            print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}")

        if self.val_loader is not None and (epoch + 1) % self.val_freq == 0:
            val_metrics = self.validate()
            val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics

            if self.master_process:
                print(f" | Val Loss: {val_loss:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid:
                    print(f" | FID: {fid:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics:
                    print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="")
                if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips:
                    print(f" | LPIPS: {lpips_score:.4f}", end="")
                print()

            self.scheduler.step(val_loss)
            self.losses['val_losses'].append((val_loss, fid, mse, psnr, ssim, lpips_score))
        else:
            if self.master_process:
                print()
            self.scheduler.step(mean_train_loss)

        stop_training = False
        if self.master_process:
            if mean_train_loss < self.best_loss:
                self.best_loss = mean_train_loss
                wait = 0
                self._save_checkpoint(epoch + 1, self.best_loss, "best_")
            else:
                wait += 1
                if wait >= self.patience:
                    print("Early stopping triggered")
                    self._save_checkpoint(epoch + 1, mean_train_loss, "early_stop_")
                    stop_training = True

            if not stop_training and (epoch + 1) % self.val_freq == 0:
                self._save_checkpoint(epoch + 1, mean_train_loss, "")

        if self.use_ddp:
            # Only rank 0 tracks patience; share its decision so the other
            # ranks do not block in the next collective call.
            stop_flag = torch.tensor(int(stop_training), device=self.device)
            dist.broadcast(stop_flag, src=0)
            stop_training = bool(stop_flag.item())

        if stop_training:
            break

    if self.use_ddp:
        destroy_process_group()

    return self.losses
def _process_conditional_input(self, y: Union[torch.Tensor, List, str]) -> torch.Tensor:
    """Encode conditional input (text prompts) with the conditional network.

    Every item of ``y`` is converted to ``str``, tokenized with
    ``self.tokenizer`` (padded/truncated to ``self.max_token_length``) and
    passed, together with the tokenizer's attention mask, to ``self.cond_net``.

    Parameters
    ----------
    y : torch.Tensor, list or str
        Conditional input. A tensor is converted with ``tolist()``; a single
        string (or a 0-d tensor) is treated as a batch of one prompt instead
        of being iterated element by element.

    Returns
    -------
    torch.Tensor
        Output of ``self.cond_net(input_ids, attention_mask)``.
    """
    prompts = y.detach().cpu().tolist() if isinstance(y, torch.Tensor) else y
    # A bare string would otherwise be split into one prompt per character,
    # and a scalar (from a 0-d tensor) is not iterable at all.
    if isinstance(prompts, (str, int, float)):
        prompts = [prompts]
    prompts = [str(item) for item in prompts]

    tokens = self.tokenizer(
        prompts,
        padding="max_length",
        truncation=True,
        max_length=self.max_token_length,
        return_tensors="pt",
    ).to(self.device)

    return self.cond_net(tokens["input_ids"], tokens["attention_mask"])


def _save_checkpoint(self, epoch: int, loss: float, pref: str = "") -> None:
    """Save a model checkpoint (only called by the master process).

    The checkpoint is written to
    ``<self.store_path>/<pref>model_epoch_<epoch>.pth`` and contains the
    network and optimizer state dicts, the loss history, the noise-scheduler
    state and ``max_epochs``.

    Parameters
    ----------
    epoch : int
        Current epoch number.
    loss : float
        Current loss value.
    pref : str, optional
        Prefix to add to the checkpoint filename.
    """

    def _unwrapped(net):
        # Under DDP the networks are accessed through ``.module`` so the saved
        # keys carry no wrapper prefix.
        return net.module if self.use_ddp else net

    # Saving is best effort: a failure is reported but must not abort training.
    try:
        cond_state = None
        if self.cond_net is not None:
            cond_state = _unwrapped(self.cond_net).state_dict()

        checkpoint = {
            'epoch': epoch,
            'model_state_dict_diff_net': _unwrapped(self.diff_net).state_dict(),
            'model_state_dict_cond': cond_state,
            'optim_state_dict': self.optim.state_dict(),
            'loss': loss,
            'losses': self.losses,
            'scheduler_model': self.fwd_ddpm.vs.state_dict(),
            'max_epochs': self.max_epochs,
        }

        os.makedirs(self.store_path, exist_ok=True)
        filepath = os.path.join(self.store_path, f"{pref}model_epoch_{epoch}.pth")
        torch.save(checkpoint, filepath)
        print(f"Model saved at epoch {epoch} with loss: {loss:.4f}")
    except Exception as e:
        print(f"Failed to save model: {e}")
def validate(self) -> Tuple[float, float, Optional[float], Optional[float], Optional[float], Optional[float]]:
    """Validate the diffusion network and compute evaluation metrics.

    For every batch of ``self.val_loader`` a random time step is drawn, the
    forward process produces the noisy input and the training target, and
    ``self.loss_fn`` is evaluated between the network prediction and that
    target. When both ``self.metrics_`` and ``self.rwd_ddpm`` are set, a
    full reverse-diffusion pass from pure noise is additionally run per
    batch and the result is compared against the original batch with
    ``self.metrics_.forward``.

    Both networks are put in eval mode for the duration of the call and are
    switched back to train mode afterwards, even if validation raises.

    Returns
    -------
    tuple
        ``(val_loss, fid, mse, psnr, ssim, lpips_score)``. ``val_loss`` is
        the mean batch loss (averaged across processes under DDP). ``fid`` is
        ``float('inf')`` and the remaining metrics are ``None`` when they
        were not computed.
    """
    self.diff_net.eval()
    if self.cond_net is not None:
        self.cond_net.eval()

    try:
        val_losses = []
        fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], []
        compute_metrics = self.metrics_ is not None and self.rwd_ddpm is not None

        with torch.no_grad():
            for x, y in self.val_loader:
                x = x.to(self.device)
                x_orig = x.clone()
                y_encoded = (
                    self._process_conditional_input(y) if self.cond_net is not None else None
                )

                # Single-step denoising loss at a random time step per sample.
                noise = torch.randn_like(x)
                t = torch.randint(0, self.fwd_ddpm.vs.time_steps, (x.shape[0],), device=x.device)
                xt, target = self.fwd_ddpm(x, t, noise)
                pred = self.diff_net(xt, t, y_encoded, clip_embeddings=None)
                val_losses.append(self.loss_fn(pred, target).item())

                if not compute_metrics:
                    continue

                # Full reverse process starting from pure noise.
                xt = torch.randn_like(x)
                for step in reversed(range(self.fwd_ddpm.vs.time_steps)):
                    time_steps = torch.full(
                        (xt.shape[0],), step, device=self.device, dtype=torch.long
                    )
                    pred = self.diff_net(xt, time_steps, y_encoded, clip_embeddings=None)
                    xt, _ = self.rwd_ddpm(xt, pred, time_steps)

                lo, hi = self.norm_range[0], self.norm_range[1]
                x_hat = torch.clamp(xt, min=lo, max=hi)
                if self.norm_output:
                    # Map both generated and reference images to [0, 1].
                    x_hat = (x_hat - lo) / (hi - lo)
                    x_orig = (x_orig - lo) / (hi - lo)

                fid, mse, psnr, ssim, lpips_score = self.metrics_.forward(x_orig, x_hat)
                # Only keep the scores whose flag is enabled on the metrics object.
                if getattr(self.metrics_, 'fid', False):
                    fid_scores.append(fid)
                if getattr(self.metrics_, 'metrics', False):
                    mse_scores.append(mse)
                    psnr_scores.append(psnr)
                    ssim_scores.append(ssim)
                if getattr(self.metrics_, 'lpips', False):
                    lpips_scores.append(lpips_score)

        # NOTE: an empty validation loader yields NaN here (mean of an empty tensor).
        val_loss = torch.tensor(val_losses).mean().item()
        if self.use_ddp:
            val_loss_tensor = torch.tensor(val_loss, device=self.device)
            dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG)
            val_loss = val_loss_tensor.item()

        def _mean_or(scores, default):
            return torch.tensor(scores).mean().item() if scores else default

        fid_avg = _mean_or(fid_scores, float('inf'))
        mse_avg = _mean_or(mse_scores, None)
        psnr_avg = _mean_or(psnr_scores, None)
        ssim_avg = _mean_or(ssim_scores, None)
        lpips_avg = _mean_or(lpips_scores, None)
    finally:
        # Always hand the networks back in train mode, even after an error,
        # so a failed validation cannot leave training running in eval mode.
        self.diff_net.train()
        if self.cond_net is not None:
            self.cond_net.train()

    return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
###==================================================================================================================###
class SampleDDPM(nn.Module):
    """Image generation using a trained Denoising Diffusion Probabilistic Model (DDPM).

    Implements the sampling process for DDPM, generating images by iteratively
    denoising random noise using a trained network and a reverse diffusion
    process. Supports conditional generation with text prompts via a
    conditional model, as inspired by Ho et al. (2020).

    Parameters
    ----------
    rwd_ddpm : nn.Module
        Reverse diffusion module (e.g., ReverseDDPM) for the reverse process.
    diff_net : nn.Module
        Trained model whose prediction is fed to ``rwd_ddpm`` at each time step.
    img_size : tuple
        Tuple of (height, width) specifying the generated image dimensions.
    cond_model : nn.Module, optional
        Model for conditional generation (e.g., text embeddings), default None.
    tokenizer : str, optional
        Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased").
    max_token_length : int, optional
        Maximum length for tokenized prompts (default: 77).
    batch_size : int, optional
        Number of images to generate per batch (default: 1).
    in_channels : int, optional
        Number of input channels for generated images (default: 3).
    device : str or torch.device, optional
        Device for computation (default: 'cuda').
    norm_range : tuple, optional
        Tuple of (min, max) for clamping generated images (default: (-1, 1)).

    Raises
    ------
    ValueError
        If ``img_size``, ``batch_size`` or ``norm_range`` is invalid.
    """

    def __init__(
            self,
            rwd_ddpm: torch.nn.Module,
            diff_net: torch.nn.Module,
            img_size: Tuple[int, int],
            cond_model: Optional[torch.nn.Module] = None,
            tokenizer: str = "bert-base-uncased",
            max_token_length: int = 77,
            batch_size: int = 1,
            in_channels: int = 3,
            device: Union[str, torch.device] = 'cuda',
            norm_range: Tuple[float, float] = (-1.0, 1.0)
    ) -> None:
        super().__init__()

        # Validate cheap arguments first so a bad call fails before any model
        # is moved to the device or a tokenizer is loaded.
        if not isinstance(img_size, (tuple, list)) or len(img_size) != 2 or not all(
                isinstance(s, int) and s > 0 for s in img_size):
            raise ValueError("img_size must be a tuple of two positive integers (height, width)")
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not isinstance(norm_range, (tuple, list)) or len(norm_range) != 2 or norm_range[0] >= norm_range[1]:
            raise ValueError("norm_range must be a tuple (min, max) with min < max")

        self.device = torch.device(device) if isinstance(device, str) else device
        self.rwd_ddpm = rwd_ddpm.to(self.device)
        self.diff_net = diff_net.to(self.device)
        # ``is not None`` rather than truthiness: an nn.Module container that
        # defines ``__len__`` can be falsy.
        self.cond_model = cond_model.to(self.device) if cond_model is not None else None
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer)
        self.max_token_length = max_token_length
        self.in_channels = in_channels
        self.img_size = img_size
        self.batch_size = batch_size
        self.norm_range = norm_range

    def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize text prompts for conditional generation.

        Converts input prompts into tokenized input IDs and attention masks
        using ``self.tokenizer`` with padding/truncation to
        ``self.max_token_length``.

        Parameters
        ----------
        prompts : str or list
            A single text prompt or a list of text prompts.

        Returns
        -------
        input_ids : torch.Tensor
            Tokenized input IDs, moved to ``self.device``.
        attention_mask : torch.Tensor
            Attention mask, moved to ``self.device``.

        Raises
        ------
        TypeError
            If ``prompts`` is neither a string nor a list of strings.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
            raise TypeError("prompts must be a string or list of strings")

        encoded = self.tokenizer(
            prompts,
            padding="max_length",
            truncation=True,
            max_length=self.max_token_length,
            return_tensors="pt"
        )
        return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)

    def forward(
            self,
            conds: Optional[Union[str, List]] = None,
            norm_output: bool = True,
            save_imgs: bool = True,
            save_path: str = "ddpm_samples"
    ) -> torch.Tensor:
        """Generate images using the DDPM sampling process.

        Starts from Gaussian noise and iterates over all time steps of the
        reverse process in descending order, feeding the network prediction to
        ``self.rwd_ddpm``. Supports conditional generation with text prompts
        and optionally saves the generated images.

        Parameters
        ----------
        conds : str or list, optional
            Text prompt(s) for conditional generation, default None.
        norm_output : bool, optional
            If True, normalizes output images to [0, 1] (default: True).
        save_imgs : bool, optional
            If True, saves generated images to `save_path` (default: True).
            Saved images are always scaled to [0, 1], independent of
            `norm_output`.
        save_path : str, optional
            Directory to save generated images (default: "ddpm_samples").

        Returns
        -------
        samps : torch.Tensor
            Generated images, shape (batch_size, in_channels, height, width).
            If `norm_output` is True, images are normalized to [0, 1];
            otherwise, they are clamped to `norm_range`.

        Raises
        ------
        ValueError
            If prompts are given without a conditional model, or a conditional
            model is configured but no prompts are given.
        """
        if conds is not None and self.cond_model is None:
            raise ValueError("Conditions provided but no conditional model specified")
        if conds is None and self.cond_model is not None:
            raise ValueError("Conditions must be provided for conditional model")

        init_samps = torch.randn(
            self.batch_size, self.in_channels, self.img_size[0], self.img_size[1],
            device=self.device
        )

        self.diff_net.eval()
        if self.cond_model is not None:
            self.cond_model.eval()

        num_steps = self.rwd_ddpm.vs.time_steps
        lo, hi = self.norm_range[0], self.norm_range[1]
        iterator = tqdm(
            reversed(range(num_steps)),
            total=num_steps,
            desc="Sampling",
            dynamic_ncols=True,
            leave=True,
        )

        with torch.no_grad():
            if conds is not None:
                input_ids, attention_mask = self.tokenize(conds)
                # Pass the tokenizer's attention mask unchanged (1 = real
                # token), exactly as the training path does; handing the
                # conditional model an inverted padding mask here would
                # condition sampling on embeddings unlike those seen in
                # training.
                y = self.cond_model(input_ids, attention_mask)
            else:
                y = None

            xt = init_samps
            for step in iterator:
                time_steps = torch.full(
                    (self.batch_size,), step, device=self.device, dtype=torch.long
                )
                pred = self.diff_net(xt, time_steps, y, clip_embeddings=None)
                xt, _ = self.rwd_ddpm(xt, pred, time_steps)

            samps = torch.clamp(xt, min=lo, max=hi)
            if norm_output:
                samps = (samps - lo) / (hi - lo)

            if save_imgs:
                # save_image expects values in [0, 1] and clips everything
                # else, so un-normalized samples are rescaled for writing only.
                to_save = samps if norm_output else (samps - lo) / (hi - lo)
                os.makedirs(save_path, exist_ok=True)
                for idx, img in enumerate(to_save, start=1):
                    save_image(img, os.path.join(save_path, f"img_{idx}.png"))

        return samps

    def to(self, device: Union[str, torch.device]) -> "Self":
        """Move the module and its components to the specified device.

        Updates ``self.device`` (stored as a ``torch.device`` when a string is
        given) and moves the module. The reverse diffusion module, the
        network and the conditional model are registered submodules, so
        ``nn.Module.to`` moves them as well.

        Parameters
        ----------
        device : str or torch.device
            Target device for the module and its components.

        Returns
        -------
        sample_ddpm : SampleDDPM
            This instance, moved to the specified device.
        """
        self.device = torch.device(device) if isinstance(device, str) else device
        return super().to(device)