Source code for torchdiff.ddim

"""
**Denoising Diffusion Implicit Models (DDIM)**

This module provides a complete implementation of DDIM, as described in Song et al.
(2021, "Denoising Diffusion Implicit Models"). It includes components for forward and
reverse diffusion processes, hyperparameter management, training, and image sampling.
Supports both unconditional and conditional generation with text prompts, using a
subsampled time step schedule for faster sampling compared to DDPM.

**Components**

- **ForwardDDIM**: Forward diffusion process to add noise.
- **ReverseDDIM**: Reverse diffusion process to denoise with subsampled steps.
- **SchedulerDDIM**: Noise schedule management with subsampled (tau) schedule.
- **TrainDDIM**: Training loop with mixed precision and scheduling.
- **SampleDDIM**: Image generation from trained models with subsampled steps.

**Notes**

- The subsampled time step schedule (tau) enables faster sampling, controlled by the
  `tau_num_steps` parameter in VarianceSchedulerDDIM.

**References**:

- Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020).

-------------------------------------------------------------------------------
"""


###==================================================================================================================###


import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from tqdm import tqdm
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
from transformers import BertTokenizer
import warnings
from torchvision.utils import save_image
from typing import Optional, Tuple, Callable, List, Any, Union, Dict
from typing_extensions import Self
from .utils import LossAdapter
import os

__all__ = [
    "ForwardDDIM",
    "ReverseDDIM",
    "SchedulerDDIM",
    "TrainDDIM",
    "SampleDDIM",
]


###==================================================================================================================###


[docs] class ForwardDDIM(nn.Module): """ Implements the forward (noising) process of DDIM. This module samples x_t from the forward diffusion distribution: q(x_t | x_0) = N(x_t; sqrt(alphā_t) * x_0, (1 - alphā_t) * I) It also computes the appropriate training target depending on the prediction parameterization (noise, x0, or v-prediction). Args: scheduler: Noise scheduler containing precomputed diffusion coefficients. pred_type: Type of model prediction. One of ["noise", "x0", "v"]. """ def __init__( self, scheduler: nn.Module, pred_type: str = "noise" ): super().__init__() valid_types = ["noise", "x0", "v"] if pred_type not in valid_types: raise ValueError(f"prediction_type must be one of {valid_types}, got {pred_type}") self.vs = scheduler self.pred_type = pred_type
[docs] def forward(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Perform the forward diffusion step and compute the training target. Samples x_t by adding noise to the clean input x_0 at timestep t, and returns the corresponding supervision target for training. Args: x0: Clean input data of shape (batch, ...). t: Discrete diffusion timesteps of shape (batch,). noise: Gaussian noise of same shape as x0. Returns: xt: Noised data x_t of shape (batch, ...). target: Training target corresponding to pred_type: - "noise": the added noise ε - "x0": the original clean input x0 - "v": the velocity parameterization """ sqrt_alpha_cumprod_t = self.vs.sqrt_alphas_cumprod[t] sqrt_one_minus_alpha_cumprod_t = self.vs.sqrt_one_minus_alphas_cumprod[t] sqrt_alpha_cumprod_t = self.vs.get_index(sqrt_alpha_cumprod_t, x0.shape) sqrt_one_minus_alpha_cumprod_t = self.vs.get_index(sqrt_one_minus_alpha_cumprod_t, x0.shape) xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise if self.pred_type == "noise": target = noise elif self.pred_type == "x0": target = x0 elif self.pred_type == "v": target = sqrt_alpha_cumprod_t * noise - sqrt_one_minus_alpha_cumprod_t * x0 return xt, target
###==================================================================================================================###
[docs] class ReverseDDIM(nn.Module): """ Implements the reverse (denoising) process of DDIM. Computes x_{t_prev} from x_t using the DDIM update rule: x_{t_prev} = sqrt(alphā_{t_prev}) * x̂_0 + sqrt(1 - alphā_{t_prev} - σ_t²) * ε̂_t + σ_t * z where σ_t controls stochasticity via eta. Setting eta=0 results in deterministic DDIM sampling. Args: scheduler: Noise scheduler containing diffusion coefficients. pred_type: Model prediction type ["noise", "x0", "v"]. eta: Controls stochasticity of sampling (0 = deterministic). clip_: Whether to clip predicted x0 to [-1, 1]. """ def __init__(self, scheduler: nn.Module, pred_type: str = "noise", eta: float = 0.0, clip_: bool = True): super().__init__() valid_pred_types = ["noise", "x0", "v"] if pred_type not in valid_pred_types: raise ValueError(f"prediction_type must be one of {valid_pred_types}") self.vs = scheduler self.pred_type = pred_type self.eta = eta self.clip_ = clip_
[docs] def predict_x0(self, xt: torch.Tensor, t: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: """ Convert model output into a prediction of the clean sample x0. The conversion depends on the chosen prediction parameterization (noise, x0, or v-prediction). Args: xt: Noisy input at timestep t of shape (batch, ...). t: Current timesteps of shape (batch,). pred: Model output of shape (batch, ...). Returns: x0_pred: Predicted clean sample x0. """ sqrt_alpha_cumprod_t = self.vs.sqrt_alphas_cumprod[t] sqrt_one_minus_alpha_cumprod_t = self.vs.sqrt_one_minus_alphas_cumprod[t] sqrt_alpha_cumprod_t = self.vs.get_index(sqrt_alpha_cumprod_t, xt.shape) sqrt_one_minus_alpha_cumprod_t = self.vs.get_index(sqrt_one_minus_alpha_cumprod_t, xt.shape) if self.pred_type == "noise": x0_pred = (xt - sqrt_one_minus_alpha_cumprod_t * pred) / sqrt_alpha_cumprod_t elif self.pred_type == "x0": x0_pred = pred elif self.pred_type == "v": x0_pred = sqrt_alpha_cumprod_t * xt - sqrt_one_minus_alpha_cumprod_t * pred if self.clip_: x0_pred = torch.clamp(x0_pred, -1.0, 1.0) return x0_pred
[docs] def predict_noise(self, xt: torch.Tensor, t: torch.Tensor, x0_pred: torch.Tensor) -> torch.Tensor: """ Compute the predicted noise ε̂_t from x_t and predicted x0. Uses the identity: ε̂_t = (x_t - sqrt(alphā_t) * x̂_0) / sqrt(1 - alphā_t) Args: xt: Noisy input at timestep t. t: Current timesteps. x0_pred: Predicted clean sample x0. Returns: pred_noise: Predicted noise ε̂_t. """ sqrt_alpha_cumprod_t = self.vs.sqrt_alphas_cumprod[t] sqrt_one_minus_alpha_cumprod_t = self.vs.sqrt_one_minus_alphas_cumprod[t] sqrt_alpha_cumprod_t = self.vs.get_index(sqrt_alpha_cumprod_t, xt.shape) sqrt_one_minus_alpha_cumprod_t = self.vs.get_index(sqrt_one_minus_alpha_cumprod_t, xt.shape) pred_noise = (xt - sqrt_alpha_cumprod_t * x0_pred) / sqrt_one_minus_alpha_cumprod_t return pred_noise
[docs] def forward( self, xt: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor, pred: torch.Tensor ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ Perform one DDIM reverse diffusion step. Computes x_{t_prev} from x_t using the DDIM update equation. Allows non-adjacent timesteps, enabling accelerated sampling. Args: xt: Current noisy sample x_t of shape (batch, ...). t: Current timestep indices of shape (batch,). t_prev: Previous timestep indices of shape (batch,). pred: Model prediction at timestep t. Returns: x_prev: Sample at timestep t_prev. pred_x0: Predicted clean sample x0. """ pred_x0 = self.predict_x0(xt, t, pred) # predict noise from x_0 pred_noise = self.predict_noise(xt, t, pred_x0) alpha_cumprod_t = self.vs.alphas_cumprod[t] alpha_cumprod_t_prev = self.vs.alphas_cumprod[t_prev] alpha_cumprod_t = self.vs.get_index(alpha_cumprod_t, xt.shape) alpha_cumprod_t_prev = self.vs.get_index(alpha_cumprod_t_prev, xt.shape) # compute variance σ_t # eta=0: fully deterministic (σ_t=0) # eta=1: maximum stochasticity (similar to ddpm) sigma_t = self.eta * torch.sqrt( (1 - alpha_cumprod_t_prev) / (1 - alpha_cumprod_t) * (1 - alpha_cumprod_t / alpha_cumprod_t_prev) ) # dir_xt = √(1 - ᾱ_{t_prev} - σ_t²) * ε̂_t sqrt_one_minus_alpha_cumprod_t_prev_minus_sigma = torch.sqrt( 1.0 - alpha_cumprod_t_prev - sigma_t ** 2 ) dir_xt = sqrt_one_minus_alpha_cumprod_t_prev_minus_sigma * pred_noise noise = torch.randn_like(xt) mask = (t_prev != 0).float().view(-1, *([1] * (len(xt.shape) - 1))) # x_{t_prev} = √ᾱ_{t_prev} * x̂_0 + dir_xt + σ_t * z x_prev = (torch.sqrt(alpha_cumprod_t_prev) * pred_x0 + dir_xt + sigma_t * mask * noise) return x_prev, pred_x0
###==================================================================================================================###
[docs] class SchedulerDDIM(nn.Module): """ Noise scheduler for DDIM. Responsible for constructing the diffusion noise schedule and precomputing all coefficients required for both training and sampling. Supports multiple beta schedules and allows using fewer inference steps than training steps. Args: schedule_type: Type of beta schedule ("linear", "cosine", etc.). train_steps: Number of diffusion steps used during training. sample_steps: Number of steps used during inference. beta_min: Minimum beta value. beta_max: Maximum beta value. cosine_s: Offset parameter for cosine schedule. clip_min: Minimum clipping value for betas. clip_max: Maximum clipping value for betas. learn_var: Whether posterior variance is learnable. """ def __init__( self, schedule_type: str = "linear", train_steps: int = 1000, # total timesteps for training sample_steps: Optional[int] = None, # can use fewer steps for sampling beta_min: float = 0.0001, beta_max: float = 0.02, cosine_s: float = 0.008, clip_min: float = 0.0001, clip_max: float = 0.9999, learn_var: bool = False ): super().__init__() valid_schedules = ["linear", "cosine", "quadratic", "sigmoid"] if schedule_type not in valid_schedules: raise ValueError(f"schedule_type must be one of {valid_schedules}, got {schedule_type}") self.schedule_type = schedule_type self.train_steps = train_steps self.sample_steps = sample_steps or train_steps self.beta_min = beta_min self.beta_max = beta_max self.cosine_s = cosine_s self.clip_min = clip_min self.clip_max = clip_max self.learn_var = learn_var self._setup_schedule() self._setup_inference_timesteps() def _setup_schedule(self): """ Construct the diffusion noise schedule and precompute coefficients. Computes betas, alphas, cumulative products, square roots, and posterior variances required for forward and reverse diffusion. """ if self.schedule_type == "linear": betas = torch.linspace(self.beta_min, self.beta_max, self.train_steps) elif self.schedule_type == "cosine": steps = self.train_steps + 1 t = torch.linspace(0, self.train_steps, steps) alphas_cumprod = torch.cos(((t / self.train_steps) + self.cosine_s) / (1 + self.cosine_s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) betas = torch.clip(betas, self.clip_min, self.clip_max) elif self.schedule_type == "quadratic": betas = torch.linspace(self.beta_min ** 0.5, self.beta_max ** 0.5, self.train_steps) ** 2 elif self.schedule_type == "sigmoid": betas = torch.linspace(-6, 6, self.train_steps) betas = torch.sigmoid(betas) * (self.beta_max - self.beta_min) + self.beta_min alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]]) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) posterior_log_variance = torch.log(torch.clamp(posterior_variance, min=1e-20)) if self.learn_var: self.register_parameter('log_variance', nn.Parameter(posterior_log_variance.clone())) else: self.register_buffer('log_variance', posterior_log_variance) self.register_buffer('betas', betas) self.register_buffer('alphas', alphas) self.register_buffer('alphas_cumprod', alphas_cumprod) self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) self.register_buffer('sqrt_alphas_cumprod', sqrt_alphas_cumprod) self.register_buffer('sqrt_one_minus_alphas_cumprod', sqrt_one_minus_alphas_cumprod) self.register_buffer('posterior_variance', posterior_variance) self.register_buffer('posterior_log_variance', posterior_log_variance) def _setup_inference_timesteps(self): """ Create the set of timesteps used during inference. DDIM allows skipping timesteps for faster sampling by selecting a subset of training timesteps. """ step_ratio = self.train_steps // self.sample_steps inference_timesteps = torch.arange(0, self.train_steps, step_ratio) self.register_buffer('inference_timesteps', inference_timesteps)
[docs] def set_inference_timesteps(self, num_inference_timesteps: int): """ Update the number of inference timesteps dynamically. Allows changing sampling speed and quality trade-offs at inference time without retraining the model. Args: num_inference_timesteps: Number of timesteps to use for sampling. """ self.sample_steps = num_inference_timesteps self._setup_inference_timesteps()
[docs] def get_index(self, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor: """ Reshape timestep-dependent coefficients for broadcasting. Extracts values indexed by t and reshapes them to match the dimensionality of a given tensor shape. Args: t: Timesteps tensor of shape (batch,). x_shape: Shape of the target tensor. Returns: Reshaped tensor suitable for broadcasting. """ batch_size = t.shape[0] return t.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
###==================================================================================================================###
[docs] class TrainDDIM(nn.Module): """Trainer for Denoising Diffusion Implicit Models (DDIM). Manages the training process for DDIM, optimizing a noise predictor model to learn the noise added by the forward diffusion process. Supports conditional training with text prompts, mixed precision training, learning rate scheduling, early stopping, and checkpointing, as inspired by Song et al. (2021). Parameters ---------- `diff_net` : nn.Module Main model to predict noise/v/x0 fwd_ddim : nn.Module Forward DDIM diffusion module for adding noise. rwd_ddim: nn.Module Reverse DDIM diffusion module for denoising. `data_loader` : torch.utils.data.DataLoader DataLoader for training data. `optim` : torch.optim.Optimizer Optimizer for training the noise predictor and conditional model (if applicable). `loss_fn` : callable Loss function to compute the difference between predicted and actual noise. `val_loader` : torch.utils.data.DataLoader, optional DataLoader for validation data, default None. `max_epochs` : int, optional Maximum number of training epochs (default: 100). `device` : str Device for computation (default: CUDA). `cond_net` : nn.Module, optional Model for conditional generation (e.g., text embeddings), default None. `metrics_` : object, optional Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None). `tokenizer` : BertTokenizer, optional Tokenizer for processing text prompts, default None (loads "bert-base-uncased"). `max_token_length` : int, optional Maximum length for tokenized prompts (default: 77). `store_path` : str, optional Path to save model checkpoints (default: "ddim_train"). `patience` : int, optional Number of epochs to wait for improvement before early stopping (default: 20). `warmup_steps` : int, optional Number of epochs for learning rate warmup (default: 1000). `val_freq` : int, optional Frequency (in epochs) for validation (default: 10). `norm_range` : tuple, optional Range for clamping generated images (default: (-1, 1)). `norm_output` : bool, optional Whether to normalize generated images to [0, 1] for metrics (default: True). `use_ddp` : bool, optional Whether to use Distributed Data Parallel training (default: False). `grad_acc` : int, optional Number of gradient accumulation steps before optimizer update (default: 1). `log_freq` : int, optional Number of epochs before printing loss. use_comp : bool, optional whether the model is internally compiled using torch.compile (default: false) """ def __init__( self, diff_net: torch.nn.Module, fwd_ddim: torch.nn.Module, rwd_ddim: torch.nn.Module, train_loader: torch.utils.data.DataLoader, optim: torch.optim.Optimizer, loss_fn: Callable, val_loader: Optional[torch.utils.data.DataLoader] = None, max_epochs: int = 100, device: str = 'cuda', cond_net: torch.nn.Module = None, metrics_: Optional[Any] = None, tokenizer: Optional[BertTokenizer] = None, max_token_length: int = 77, store_path: Optional[str] = None, patience: int = 20, warmup_steps: int = 1000, val_freq: int = 10, norm_range: Tuple[float, float] = (-1, 1), norm_output: bool = True, use_ddp: bool = False, grad_acc: int = 1, log_freq: int = 1, use_comp: bool = False, use_amp: bool = False, *args ) -> None: super().__init__() self.use_ddp = use_ddp self.grad_acc = grad_acc self.use_amp = use_amp if isinstance(device, str): self.device = torch.device(device) else: self.device = device if self.use_ddp: self._setup_ddp() else: self._setup_single_gpu() self.diff_net = diff_net.to(self.device) self.fwd_ddim = fwd_ddim.to(self.device) self.rwd_ddim = rwd_ddim.to(self.device) self.cond_net = cond_net.to(self.device) if cond_net else None self.metrics_ = metrics_ self.optim = optim self.loss_fn = LossAdapter(loss_fn) self.store_path = store_path or "ddim_train" self.train_loader = train_loader self.val_loader = val_loader self.max_epochs = max_epochs self.max_token_length = max_token_length self.patience = patience self.val_freq = val_freq self.norm_range = norm_range self.norm_output = norm_output self.log_freq = log_freq self.use_comp = use_comp self.global_step = 0 self.warmup_steps = warmup_steps self.best_loss = float('inf') self.losses = {'train_losses': [], 'val_losses': []} self.scheduler = ReduceLROnPlateau( self.optim, patience=self.patience, factor=0.5 ) self.warmup_lr_scheduler = self.warmup_scheduler(self.optim, warmup_steps) self._device_type = self.device.type if hasattr(self.device, 'type') else ('cuda' if 'cuda' in str(self.device) else 'cpu') if tokenizer is None: try: self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") except Exception as e: raise ValueError(f"Failed to load default tokenizer: {e}. Please provide a tokenizer.") def _setup_ddp(self) -> None: """Setup Distributed Data Parallel training configuration. Initializes process group, determines rank information, and sets up CUDA device for the current process. """ if "RANK" not in os.environ: raise ValueError("DDP enabled but RANK environment variable not set") if "LOCAL_RANK" not in os.environ: raise ValueError("DDP enabled but LOCAL_RANK environment variable not set") if "WORLD_SIZE" not in os.environ: raise ValueError("DDP enabled but WORLD_SIZE environment variable not set") if not torch.distributed.is_initialized(): backend = "nccl" if torch.cuda.is_available() else "gloo" init_process_group(backend=backend) self.ddp_rank = int(os.environ["RANK"]) # global rank across all nodes self.ddp_local_rank = int(os.environ["LOCAL_RANK"]) # local rank on current node self.ddp_world_size = int(os.environ["WORLD_SIZE"]) # total number of processes if torch.cuda.is_available(): self.device = torch.device(f"cuda:{self.ddp_local_rank}") torch.cuda.set_device(self.device) else: self.device = torch.device("cpu") self.master_process = self.ddp_rank == 0 if self.master_process: print(f"DDP initialized with world_size={self.ddp_world_size}") def _setup_single_gpu(self) -> None: """Setup single GPU or CPU training configuration.""" self.ddp_rank = 0 self.ddp_local_rank = 0 self.ddp_world_size = 1 self.master_process = True
[docs] def load_checkpoint(self, checkpoint_path: str) -> Tuple[int, float]: """Loads a training checkpoint to resume training. Restores the state of the noise predictor, conditional model (if applicable), and optimizer from a saved checkpoint. Handles DDP model state dict loading. Parameters ---------- checkpoint_path : str Path to the checkpoint file. Returns ------- epoch : int The epoch at which the checkpoint was saved. loss : float The loss at the checkpoint. """ try: checkpoint = torch.load(checkpoint_path, map_location=self.device) except FileNotFoundError: raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}") if 'model_state_dict_diff_net' not in checkpoint: raise KeyError("Checkpoint missing 'model_state_dict_diff_net' key") state_dict = checkpoint['model_state_dict_diff_net'] if self.use_ddp and not any(key.startswith('module.') for key in state_dict.keys()): state_dict = {f'module.{k}': v for k, v in state_dict.items()} elif not self.use_ddp and any(key.startswith('module.') for key in state_dict.keys()): state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} self.diff_net.load_state_dict(state_dict) if self.cond_net is not None: if 'model_state_dict_cond' in checkpoint and checkpoint['model_state_dict_cond'] is not None: cond_state_dict = checkpoint['model_state_dict_cond'] if self.use_ddp and not any(key.startswith('module.') for key in cond_state_dict.keys()): cond_state_dict = {f'module.{k}': v for k, v in cond_state_dict.items()} elif not self.use_ddp and any(key.startswith('module.') for key in cond_state_dict.keys()): cond_state_dict = {k.replace('module.', ''): v for k, v in cond_state_dict.items()} self.cond_net.load_state_dict(cond_state_dict) else: warnings.warn( "Checkpoint contains no 'model_state_dict_cond' or it is None, " "skipping conditional model loading" ) if 'scheduler_model' not in checkpoint: raise KeyError("Checkpoint missing 'scheduler_model' key") try: if isinstance(self.fwd_ddim.vs, nn.Module): self.fwd_ddim.vs.load_state_dict( checkpoint['scheduler_model']) if isinstance(self.rwd_ddim.vs, nn.Module): self.rwd_ddim.vs.load_state_dict( checkpoint['scheduler_model']) else: self.fwd_ddim.vs = checkpoint['scheduler_model'] self.rwd_ddim.vs = checkpoint['scheduler_model'] except Exception as e: warnings.warn(f"Scheduler loading failed: {e}. Continuing with current scheduler.") if 'optim_state_dict' not in checkpoint: raise KeyError("Checkpoint missing 'optim_state_dict' key") try: self.optim.load_state_dict(checkpoint['optim_state_dict']) except ValueError as e: warnings.warn(f"Optimizer state loading failed: {e}. Continuing without optimizer state.") epoch = checkpoint.get('epoch', -1) loss = checkpoint.get('loss', float('inf')) if self.master_process: print(f"Loaded checkpoint from {checkpoint_path} at epoch {epoch} with loss {loss:.4f}") return epoch, loss
[docs] @staticmethod def warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) -> torch.optim.lr_scheduler.LambdaLR: """Creates a learning rate scheduler for warmup. Generates a scheduler that linearly increases the learning rate from 0 to the optimizer's initial value over the specified warmup epochs, then maintains it. Parameters ---------- `optimizer` : torch.optim.Optimizer Optimizer to apply the scheduler to. `warmup_steps` : int Number of steps for the warmup phase. Returns ------- lr_scheduler (torch.optim.lr_scheduler.LambdaLR) - Learning rate scheduler for warmup. """ def lr_lambda(step): if step < warmup_steps: return 0.1 + (0.9 * step / warmup_steps) return 1.0 return LambdaLR(optimizer, lr_lambda)
def _wrap_models_for_ddp(self) -> None: """Wrap models with DistributedDataParallel for multi-GPU training.""" if self.use_ddp: ddp_kwargs = dict(find_unused_parameters=False) if self._device_type == 'cuda': ddp_kwargs['device_ids'] = [self.ddp_local_rank] self.diff_net = DDP(self.diff_net, **ddp_kwargs) if self.cond_net is not None: self.cond_net = DDP(self.cond_net, **ddp_kwargs)
[docs] def forward(self) -> Dict: """Trains the DDIM model to predict noise added by the forward diffusion process. Executes the training loop, optimizing the noise predictor and conditional model (if applicable) using mixed precision, gradient clipping, and learning rate scheduling. Supports validation, early stopping, and checkpointing. Returns ------- losses: dictionlary contains train and validation losses """ self.diff_net.train() if self.cond_net is not None: self.cond_net.train() if self.use_comp: try: self.diff_net = torch.compile(self.diff_net) if self.cond_net is not None: self.cond_net = torch.compile(self.cond_net) except Exception as e: if self.master_process: print(f"Model compilation failed: {e}. Continuing without compilation.") self._wrap_models_for_ddp() use_amp = self.use_amp and self._device_type == 'cuda' scaler = torch.amp.GradScaler(self._device_type, enabled=use_amp) wait = 0 for epoch in range(self.max_epochs): pbar = tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not self.master_process) if self.use_ddp and hasattr(self.train_loader.sampler, 'set_epoch'): self.train_loader.sampler.set_epoch(epoch) train_losses_epoch = [] for step, (x, y) in enumerate(pbar): x = x.to(self.device) if self.cond_net is not None: y_encoded = self._process_conditional_input(y) else: y_encoded = None with torch.autocast(device_type=self._device_type, enabled=use_amp): noise = torch.randn_like(x) t = torch.randint(0, self.fwd_ddim.vs.train_steps, (x.shape[0],), device=x.device) xt, target = self.fwd_ddim(x, t, noise) pred = self.diff_net(xt, t, y_encoded, clip_embeddings=None) loss = self.loss_fn(pred, target) / self.grad_acc scaler.scale(loss).backward() if (step + 1) % self.grad_acc == 0: scaler.unscale_(self.optim) torch.nn.utils.clip_grad_norm_(self.diff_net.parameters(), max_norm=1.0) if self.cond_net is not None: torch.nn.utils.clip_grad_norm_(self.cond_net.parameters(), max_norm=1.0) scaler.step(self.optim) scaler.update() self.optim.zero_grad(set_to_none=True) if self.global_step < self.warmup_steps: self.warmup_lr_scheduler.step() self.global_step += 1 pbar.set_postfix({'Loss': f'{loss.item() * self.grad_acc:.4f}'}) train_losses_epoch.append(loss.item() * self.grad_acc) mean_train_loss = torch.tensor(train_losses_epoch).mean().item() self.losses['train_losses'].append(mean_train_loss) if self.use_ddp: loss_tensor = torch.tensor(mean_train_loss, device=self.device) dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG) mean_train_loss = loss_tensor.item() if self.master_process and (epoch + 1) % self.log_freq == 0: current_lr = self.optim.param_groups[0]['lr'] print(f"\nEpoch: {epoch + 1}/{self.max_epochs} | LR: {current_lr:.2e} | Train Loss: {mean_train_loss:.4f}") if self.val_loader is not None and (epoch + 1) % self.val_freq == 0: val_metrics = self.validate() val_loss, fid, mse, psnr, ssim, lpips_score = val_metrics if self.master_process: print(f" | Val Loss: {val_loss:.4f}", end="") if self.metrics_ and hasattr(self.metrics_, 'fid') and self.metrics_.fid: print(f" | FID: {fid:.4f}", end="") if self.metrics_ and hasattr(self.metrics_, 'metrics') and self.metrics_.metrics: print(f" | MSE: {mse:.4f} | PSNR: {psnr:.4f} | SSIM: {ssim:.4f}", end="") if self.metrics_ and hasattr(self.metrics_, 'lpips') and self.metrics_.lpips: print(f" | LPIPS: {lpips_score:.4f}", end="") print() self.scheduler.step(val_loss) self.losses['val_losses'].append((val_loss, fid, mse, psnr, ssim, lpips_score)) else: if self.master_process: print() self.scheduler.step(mean_train_loss) if self.master_process: if mean_train_loss < self.best_loss: self.best_loss = mean_train_loss wait = 0 self._save_checkpoint(epoch + 1, self.best_loss, "best_") else: wait += 1 if wait >= self.patience: print("Early stopping triggered") self._save_checkpoint(epoch + 1, mean_train_loss, "early_stop_") break if (epoch + 1) % self.val_freq == 0: self._save_checkpoint(epoch + 1, mean_train_loss, "") if self.use_ddp: destroy_process_group() return self.losses
def _process_conditional_input(self, y: Union[torch.Tensor, List]) -> torch.Tensor: """Process conditional input for text-to-image generation. Parameters ---------- y : torch.Tensor or list Conditional input (text prompts). Returns ------- torch.Tensor Encoded conditional input. """ # convert to string list y_list = y.cpu().numpy().tolist() if isinstance(y, torch.Tensor) else y y_list = [str(item) for item in y_list] y_encoded = self.tokenizer( y_list, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" ).to(self.device) input_ids = y_encoded["input_ids"] attention_mask = y_encoded["attention_mask"] y_encoded = self.cond_net(input_ids, attention_mask) return y_encoded def _save_checkpoint(self, epoch: int, loss: float, pref: str = "") -> None: """Save model checkpoint (only called by master process). Parameters ---------- epoch : int Current epoch number. loss : float Current loss value. pref : str, optional prefix to add to checkpoint filename. """ try: diff_net_state = ( self.diff_net.module.state_dict() if self.use_ddp else self.diff_net.state_dict() ) cond_state = None if self.cond_net is not None: cond_state = ( self.cond_net.module.state_dict() if self.use_ddp else self.cond_net.state_dict() ) checkpoint = { 'epoch': epoch, 'model_state_dict_diff_net': diff_net_state, 'model_state_dict_cond': cond_state, 'optim_state_dict': self.optim.state_dict(), 'loss': loss, 'losses': self.losses, 'scheduler_model': self.fwd_ddim.vs.state_dict(), 'max_epochs': self.max_epochs, } filename = f"{pref}model_epoch_{epoch}.pth" filepath = os.path.join(self.store_path, filename) os.makedirs(self.store_path, exist_ok=True) torch.save(checkpoint, filepath) print(f"Model saved at epoch {epoch} with loss: {loss:.4f}") except Exception as e: print(f"Failed to save model: {e}")
[docs] def validate(self) -> Tuple[float, float, float, float, float, float]: """Validates the noise predictor and computes evaluation Metrics. Computes validation loss (MSE between predicted and ground truth noise) and generates samples using the reverse diffusion model by manually iterating over timesteps. Decodes samples to images and computes image-domain Metrics (MSE, PSNR, SSIM, FID, LPIPS) if metrics_ is provided. Returns ------- val_loss : float Mean validation loss. fid : float, or `float('inf')` if not computed Mean FID score. mse : float, or None if not computed Mean MSE psnr : float, or None if not computed Mean PSNR ssim : float, or None if not computed Mean SSIM lpips_score : float, or None if not computed Mean LPIPS score """ self.diff_net.eval() if self.cond_net is not None: self.cond_net.eval() val_losses = [] fid_scores, mse_scores, psnr_scores, ssim_scores, lpips_scores = [], [], [], [], [] with torch.no_grad(): for x, y in self.val_loader: x = x.to(self.device) x_orig = x.clone() if self.cond_net is not None: y_encoded = self._process_conditional_input(y) else: y_encoded = None noise = torch.randn_like(x) t = torch.randint(0, self.fwd_ddim.vs.train_steps, (x.shape[0],), device=x.device) xt, target = self.fwd_ddim(x, t, noise) pred = self.diff_net(xt, t, y_encoded, clip_embeddings=None) loss = self.loss_fn(pred, target) val_losses.append(loss.item()) if self.metrics_ is not None and self.rwd_ddim is not None: xt = torch.randn_like(x) timesteps = self.fwd_ddim.vs.inference_timesteps.flip(0) for i in range(len(timesteps) - 1): t_ = timesteps[i].item() t_pre = timesteps[i + 1].item() time = torch.full((xt.shape[0],), t_, device=self.device, dtype=torch.long) prev_time = torch.full((xt.shape[0],), t_pre, device=self.device, dtype=torch.long) pred = self.diff_net(xt, time, y_encoded, clip_embeddings=None) xt, _ = self.rwd_ddim(xt, time, prev_time, pred) x_hat = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1]) if self.norm_output: x_hat = (x_hat - self.norm_range[0]) / (self.norm_range[1] - self.norm_range[0]) x_orig = (x_orig - self.norm_range[0]) / (self.norm_range[1] - self.norm_range[0]) metrics_result = self.metrics_.forward(x_orig, x_hat) fid, mse, psnr, ssim, lpips_score = metrics_result if hasattr(self.metrics_, 'fid') and self.metrics_.fid: fid_scores.append(fid) if hasattr(self.metrics_, 'metrics') and self.metrics_.metrics: mse_scores.append(mse) psnr_scores.append(psnr) ssim_scores.append(ssim) if hasattr(self.metrics_, 'lpips') and self.metrics_.lpips: lpips_scores.append(lpips_score) val_loss = torch.tensor(val_losses).mean().item() if self.use_ddp: val_loss_tensor = torch.tensor(val_loss, device=self.device) dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.AVG) val_loss = val_loss_tensor.item() fid_avg = torch.tensor(fid_scores).mean().item() if fid_scores else float('inf') mse_avg = torch.tensor(mse_scores).mean().item() if mse_scores else None psnr_avg = torch.tensor(psnr_scores).mean().item() if psnr_scores else None ssim_avg = torch.tensor(ssim_scores).mean().item() if ssim_scores else None lpips_avg = torch.tensor(lpips_scores).mean().item() if lpips_scores else None self.diff_net.train() if self.cond_net is not None: self.cond_net.train() return val_loss, fid_avg, mse_avg, psnr_avg, ssim_avg, lpips_avg
###==================================================================================================================###
[docs] class SampleDDIM(nn.Module): """Image generation using a trained DDIM model. Implements the sampling process for DDIM, generating images by iteratively denoising random noise using a trained noise predictor and reverse diffusion process with a subsampled time step schedule. Supports conditional generation with text prompts, as inspired by Song et al. (2021). Parameters ---------- `rwd_ddim` : nn.Module Reverse diffusion module (e.g., ReverseDDIM) for the reverse process. `diff_net` : nn.Module Trained model to predict noise/v/x0 at each time step. `img_size` : tuple Tuple of (height, width) specifying the generated image dimensions. `cond_net` : nn.Module, optional Model for conditional generation (e.g., text embeddings), default None. `tokenizer` : str, optional Pretrained tokenizer name from Hugging Face (default: "bert-base-uncased"). `max_length` : int, optional Maximum length for tokenized prompts (default: 77). `batch_size` : int, optional Number of images to generate per batch (default: 1). `in_channels` : int, optional Number of input channels for generated images (default: 3). `device` : str Device for computation (default: CUDA). `norm_range` : tuple, optional Tuple of (min, max) for clamping generated images (default: (-1, 1)). """ def __init__( self, rwd_ddim: torch.nn.Module, diff_net: torch.nn.Module, img_size: Tuple[int, int], cond_net: Optional[torch.nn.Module] = None, tokenizer: str = "bert-base-uncased", max_token_length: int = 77, batch_size: int = 1, in_channels: int = 3, device: Optional[str] = None, norm_range: Tuple[float, float] = (-1.0, 1.0) ) -> None: super().__init__() if isinstance(device, str): self.device = torch.device(device) else: self.device = device self.rwd_ddim = rwd_ddim.to(self.device) self.diff_net = diff_net.to(self.device) self.cond_net = cond_net.to(self.device) if cond_net else None self.tokenizer = BertTokenizer.from_pretrained(tokenizer) self.max_token_length = max_token_length self.in_channels = in_channels self.img_size = img_size self.batch_size = batch_size self.norm_range = norm_range if not isinstance(img_size, (tuple, list)) or len(img_size) != 2 or not all( isinstance(s, int) and s > 0 for s in img_size): raise ValueError("img_size must be a tuple of two positive integers (height, width)") if batch_size <= 0: raise ValueError("batch_size must be positive") if not isinstance(norm_range, (tuple, list)) or len(norm_range) != 2 or norm_range[0] >= norm_range[1]: raise ValueError("norm_range must be a tuple (min, max) with min < max")
[docs] def tokenize(self, prompts: Union[List, str]) -> Tuple[torch.Tensor, torch.Tensor]: """Tokenizes text prompts for conditional generation. Converts input prompts into tokenized input IDs and attention masks using the specified tokenizer, suitable for use with the conditional model. Parameters ---------- `prompts` : str or list A single text prompt or a list of text prompts. Returns ------- input_ids : torch.Tensor Tokenized input IDs, shape (batch_size, max_length). attention_mask : torch.Tensor Attention mask, shape (batch_size, max_length). """ if isinstance(prompts, str): prompts = [prompts] elif not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts): raise TypeError("prompts must be a string or list of strings") encoded = self.tokenizer( prompts, padding="max_length", truncation=True, max_length=self.max_token_length, return_tensors="pt" ) return encoded["input_ids"].to(self.device), encoded["attention_mask"].to(self.device)
[docs] def forward(self, conds: Optional[Union[str, List]] = None, norm_output: bool = True, save_imgs: bool = True, save_path: str = "ddim_samples") -> torch.Tensor: """Generates images using the DDIM sampling process. Iteratively denoises random noise to generate images using the reverse diffusion process with a subsampled time step schedule and noise predictor. Supports conditional generation with text prompts. Parameters ---------- `conds` : str or list, optional Text prompt(s) for conditional generation, default None. `norm_output` : bool, optional If True, normalizes output images to [0, 1] (default: True). `save_imgs` : bool, optional If True, saves generated images to `save_path` (default: True). `save_path` : str, optional Directory to save generated images (default: "ddim_samples"). Returns ------- samps (torch.Tensor) - Generated images, shape (batch_size, in_channels, height, width). """ if conds is not None and self.cond_net is None: raise ValueError("Conditions provided but no conditional model specified") if conds is None and self.cond_net is not None: raise ValueError("Conditions must be provided for conditional model") init_samps = torch.randn(self.batch_size, self.in_channels, self.img_size[0], self.img_size[1], device=self.device) self.diff_net.eval() if self.cond_net: self.cond_net.eval() timesteps = self.rwd_ddim.vs.inference_timesteps timesteps = timesteps.flip(0) iterator = tqdm( range(len(timesteps) - 1), total=len(timesteps) - 1, desc="Sampling", dynamic_ncols=True, leave=True, ) with torch.no_grad(): if self.cond_net is not None and conds is not None: input_ids, attention_masks = self.tokenize(conds) key_padding_mask = (attention_masks == 0) y = self.cond_net(input_ids, key_padding_mask) else: y = None xt = init_samps for i in iterator: t_current = timesteps[i].item() t_prev = timesteps[i + 1].item() #assert t_current > t_prev or t_prev == 0 time = torch.full((self.batch_size,), t_current, device=self.device, dtype=torch.long) prev_time = torch.full((self.batch_size,), t_prev, device=self.device, dtype=torch.long) pred = self.diff_net(xt, time, y, clip_embeddings=None) xt, _ = self.rwd_ddim(xt, time, prev_time, pred) samps = torch.clamp(xt, min=self.norm_range[0], max=self.norm_range[1]) if norm_output: samps = (samps - self.norm_range[0]) / (self.norm_range[1] - self.norm_range[0]) if save_imgs: os.makedirs(save_path, exist_ok=True) for i in range(samps.size(0)): img_path = os.path.join(save_path, f"img_{i+1}.png") save_image(samps[i], img_path) return samps
[docs] def to(self, device: torch.device) -> Self: """Moves the module and its components to the specified device. Updates the device attribute and moves the reverse diffusion, noise predictor, and conditional model (if present) to the specified device. Parameters ---------- `device` : torch.device Target device for the module and its components. Returns ------- sample_ddim (SampleDDIM) - moved to the specified device. """ self.device = device self.diff_net.to(device) self.rwd_ddim.to(device) if self.cond_net: self.cond_net.to(device) return super().to(device)