DDIM
Denoising Diffusion Implicit Models (DDIM)
This module provides a complete implementation of DDIM, as described in Song et al. (2021, “Denoising Diffusion Implicit Models”). It includes components for the forward and reverse diffusion processes, hyperparameter management, training, and image sampling. It supports both unconditional generation and conditional generation from text prompts, and uses a subsampled time-step schedule for faster sampling than DDPM.
Components
ForwardDDIM: Forward diffusion process to add noise.
ReverseDDIM: Reverse diffusion process to denoise with subsampled steps.
SchedulerDDIM: Noise schedule management with subsampled (tau) schedule.
TrainDDIM: Training loop with mixed precision, learning-rate scheduling, early stopping, and checkpointing.
SampleDDIM: Image generation from trained models with subsampled steps.
Notes
The subsampled time-step schedule (tau) enables faster sampling. Its length is set by the sample_steps parameter of SchedulerDDIM and can be changed at inference time with SchedulerDDIM.set_inference_timesteps.
References:
Song, Jiaming, Chenlin Meng, and Stefano Ermon. “Denoising diffusion implicit models.” arXiv preprint arXiv:2010.02502 (2020).
- class torchdiff.ddim.ForwardDDIM(*args: Any, **kwargs: Any)
Bases: Module
Implements the forward (noising) process of DDIM.
This module samples x_t from the forward diffusion distribution:
$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1 - \bar{\alpha}_t)\,\mathbf{I}\right)$$
It also computes the appropriate training target depending on the prediction parameterization (noise, x0, or v-prediction).
- Parameters:
scheduler – Noise scheduler containing precomputed diffusion coefficients.
pred_type – Type of model prediction. One of [“noise”, “x0”, “v”].
- forward(x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor]
Perform the forward diffusion step and compute the training target.
Samples x_t by adding noise to the clean input x_0 at timestep t, and returns the corresponding supervision target for training.
- Parameters:
x0 – Clean input data of shape (batch, …).
t – Discrete diffusion timesteps of shape (batch,).
noise – Gaussian noise of same shape as x0.
- Returns:
xt (torch.Tensor) – Noised data x_t of shape (batch, …).
target (torch.Tensor) – Training target corresponding to pred_type:
“noise”: the added noise ε
“x0”: the original clean input x0
“v”: the velocity parameterization
- Return type:
Tuple[torch.Tensor, torch.Tensor]
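The forward step and its three possible targets can be illustrated on single scalars. This is a minimal sketch, not the torchdiff implementation: it takes ᾱ_t directly instead of reading it from a scheduler, the function name forward_ddim_scalar is hypothetical, and it assumes the “v” target uses the common definition v = sqrt(ᾱ_t)·ε - sqrt(1 - ᾱ_t)·x0, which the description above does not spell out.

```python
import math


def forward_ddim_scalar(x0, alpha_bar_t, noise, pred_type="noise"):
    """Sample x_t ~ q(x_t | x_0) for one scalar and return (x_t, target)."""
    a = math.sqrt(alpha_bar_t)        # signal scale sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)  # noise scale sqrt(1 - alpha_bar_t)
    xt = a * x0 + s * noise
    if pred_type == "noise":
        target = noise
    elif pred_type == "x0":
        target = x0
    elif pred_type == "v":
        # Assumed v-parameterization: v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0
        target = a * noise - s * x0
    else:
        raise ValueError(f"unknown pred_type: {pred_type!r}")
    return xt, target


# alpha_bar_t = 0.64 gives scales 0.8 (signal) and 0.6 (noise).
xt, v_target = forward_ddim_scalar(x0=0.5, alpha_bar_t=0.64, noise=1.0, pred_type="v")
_, eps_target = forward_ddim_scalar(x0=0.5, alpha_bar_t=0.64, noise=1.0, pred_type="noise")
```

With x0 = 0.5 and ε = 1.0 this gives x_t = 0.8·0.5 + 0.6·1.0 = 1.0 and v = 0.8·1.0 - 0.6·0.5 = 0.5.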
- class torchdiff.ddim.ReverseDDIM(*args: Any, **kwargs: Any)
Bases: Module
Implements the reverse (denoising) process of DDIM.
Computes x_{t_prev} from x_t using the DDIM update rule:
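$$x_{t_{\text{prev}}} = \sqrt{\bar{\alpha}_{t_{\text{prev}}}}\,\hat{x}_0 + \sqrt{1 - \bar{\alpha}_{t_{\text{prev}}} - \sigma_t^2}\,\hat{\epsilon}_t + \sigma_t z, \qquad z \sim \mathcal{N}(0, \mathbf{I})$$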
where σ_t controls stochasticity via eta. Setting eta=0 results in deterministic DDIM sampling.
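One update of this rule can be sketched on scalars. The sketch assumes σ_t is computed as defined in Song et al. (2021), σ_t = eta · sqrt((1 - ᾱ_{t_prev}) / (1 - ᾱ_t)) · sqrt(1 - ᾱ_t / ᾱ_{t_prev}); the exact formula ReverseDDIM uses is not stated here, and ddim_sigma and ddim_step are hypothetical helper names.

```python
import math
import random


def ddim_sigma(alpha_bar_t, alpha_bar_prev, eta):
    """sigma_t as defined by Song et al. (2021), with t-1 replaced by t_prev."""
    return eta * math.sqrt((1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)) * math.sqrt(
        1.0 - alpha_bar_t / alpha_bar_prev
    )


def ddim_step(x0_hat, eps_hat, alpha_bar_t, alpha_bar_prev, eta=0.0, rng=random):
    """One DDIM update from t to t_prev on scalars (t_prev need not equal t - 1)."""
    sigma = ddim_sigma(alpha_bar_t, alpha_bar_prev, eta)
    z = rng.gauss(0.0, 1.0) if sigma > 0.0 else 0.0
    # max(0, ...) guards against tiny negative values caused by rounding.
    dir_coef = math.sqrt(max(0.0, 1.0 - alpha_bar_prev - sigma ** 2))
    return math.sqrt(alpha_bar_prev) * x0_hat + dir_coef * eps_hat + sigma * z


# eta = 0 makes sigma_t = 0, so repeated calls give the same result.
det_a = ddim_step(1.0, 0.0, alpha_bar_t=0.5, alpha_bar_prev=0.81, eta=0.0)
det_b = ddim_step(1.0, 0.0, alpha_bar_t=0.5, alpha_bar_prev=0.81, eta=0.0)
sigma_eta1 = ddim_sigma(0.5, 0.81, eta=1.0)
```

Setting eta = 1 gives a strictly positive σ_t and therefore a stochastic step; in Song et al. this choice recovers DDPM-like sampling.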
- Parameters:
scheduler – Noise scheduler containing diffusion coefficients.
pred_type – Model prediction type [“noise”, “x0”, “v”].
eta – Controls stochasticity of sampling (0 = deterministic).
clip – Whether to clip predicted x0 to [-1, 1].
- predict_x0(xt: torch.Tensor, t: torch.Tensor, pred: torch.Tensor) → torch.Tensor
Convert model output into a prediction of the clean sample x0.
The conversion depends on the chosen prediction parameterization (noise, x0, or v-prediction).
- Parameters:
xt – Noisy input at timestep t of shape (batch, …).
t – Current timesteps of shape (batch,).
pred – Model output of shape (batch, …).
- Returns:
x0_pred – Predicted clean sample x0.
- Return type:
torch.Tensor
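The three conversions follow from inverting the forward relation x_t = sqrt(ᾱ_t)·x0 + sqrt(1 - ᾱ_t)·ε. The scalar sketch below is hypothetical (predict_x0_scalar is not a torchdiff function) and assumes the same v-parameterization as the forward process, v = sqrt(ᾱ_t)·ε - sqrt(1 - ᾱ_t)·x0; the optional clamp mirrors the clip parameter.

```python
import math


def predict_x0_scalar(xt, alpha_bar_t, pred, pred_type="noise", clip=False):
    """Convert a model output into an estimate of x0 for one scalar."""
    a = math.sqrt(alpha_bar_t)
    s = math.sqrt(1.0 - alpha_bar_t)
    if pred_type == "noise":
        x0 = (xt - s * pred) / a       # invert x_t = a * x0 + s * eps
    elif pred_type == "x0":
        x0 = pred                      # the model already predicts x0
    elif pred_type == "v":
        x0 = a * xt - s * pred         # assumes v = a * eps - s * x0
    else:
        raise ValueError(f"unknown pred_type: {pred_type!r}")
    if clip:
        x0 = max(-1.0, min(1.0, x0))   # mirrors the clip option ([-1, 1])
    return x0


# x0 = 0.5, eps = 1.0, alpha_bar_t = 0.64 gives x_t = 1.0 and v = 0.5.
from_noise = predict_x0_scalar(1.0, 0.64, 1.0, pred_type="noise")
from_v = predict_x0_scalar(1.0, 0.64, 0.5, pred_type="v")
clipped = predict_x0_scalar(2.0, 0.64, -1.0, pred_type="noise", clip=True)
```

The v-case works because sqrt(ᾱ_t)·x_t - sqrt(1 - ᾱ_t)·v = (ᾱ_t + 1 - ᾱ_t)·x0 = x0.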
- predict_noise(xt: torch.Tensor, t: torch.Tensor, x0_pred: torch.Tensor) → torch.Tensor
Compute the predicted noise ε̂_t from x_t and predicted x0.
Uses the identity:
$$\hat{\epsilon}_t = \frac{x_t - \sqrt{\bar{\alpha}_t}\,\hat{x}_0}{\sqrt{1 - \bar{\alpha}_t}}$$
- Parameters:
xt – Noisy input at timestep t.
t – Current timesteps.
x0_pred – Predicted clean sample x0.
- Returns:
pred_noise – Predicted noise ε̂_t.
- Return type:
torch.Tensor
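A quick scalar round trip shows the identity at work: building x_t from a known x0 and ε and then applying the formula recovers ε. The helper name predict_noise_scalar is hypothetical. One plausible reason for recomputing ε̂_t from x̂_0 (rather than reusing the raw model output) is that it keeps the DDIM update self-consistent after x̂_0 has been clipped; the sketch only demonstrates that a shifted x̂_0 implies a different ε̂_t.

```python
import math


def predict_noise_scalar(xt, alpha_bar_t, x0_pred):
    """eps_hat = (x_t - sqrt(alpha_bar_t) * x0_hat) / sqrt(1 - alpha_bar_t)."""
    return (xt - math.sqrt(alpha_bar_t) * x0_pred) / math.sqrt(1.0 - alpha_bar_t)


alpha_bar_t = 0.3
x0, eps = -0.4, 1.7
xt = math.sqrt(alpha_bar_t) * x0 + math.sqrt(1.0 - alpha_bar_t) * eps

# With the true x0 the identity recovers the noise (up to rounding).
recovered = predict_noise_scalar(xt, alpha_bar_t, x0)
# With a different (e.g. clipped) x0 estimate, the implied noise changes too.
shifted = predict_noise_scalar(xt, alpha_bar_t, x0 + 0.1)
```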
- forward(xt: torch.Tensor, t: torch.Tensor, t_prev: torch.Tensor, pred: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor | None]
Perform one DDIM reverse diffusion step.
Computes x_{t_prev} from x_t using the DDIM update equation. Allows non-adjacent timesteps, enabling accelerated sampling.
- Parameters:
xt – Current noisy sample x_t of shape (batch, …).
t – Current timestep indices of shape (batch,).
t_prev – Previous timestep indices of shape (batch,).
pred – Model prediction at timestep t.
- Returns:
x_prev (torch.Tensor) – Sample at timestep t_prev.
pred_x0 (torch.Tensor or None) – Predicted clean sample x0.
- Return type:
Tuple[torch.Tensor, torch.Tensor | None]
- class torchdiff.ddim.SchedulerDDIM(*args: Any, **kwargs: Any)
Bases: Module
Noise scheduler for DDIM.
Responsible for constructing the diffusion noise schedule and precomputing all coefficients required for both training and sampling.
Supports multiple beta schedules and allows using fewer inference steps than training steps.
- Parameters:
schedule_type – Type of beta schedule (“linear”, “cosine”, etc.).
train_steps – Number of diffusion steps used during training.
sample_steps – Number of steps used during inference.
beta_min – Minimum beta value.
beta_max – Maximum beta value.
cosine_s – Offset parameter for cosine schedule.
clip_min – Minimum clipping value for betas.
clip_max – Maximum clipping value for betas.
learn_var – Whether posterior variance is learnable.
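The kind of precomputation described above can be sketched with plain lists. This is an illustration under stated assumptions, not SchedulerDDIM's code: “linear” is assumed to space betas evenly from beta_min to beta_max, “cosine” is assumed to follow Nichol and Dhariwal (2021), where ᾱ(t) = f(t)/f(0) with f(t) = cos²(((t/T) + s)/(1 + s)·π/2), and the default values shown (including clip_min and clip_max) are illustrative, not torchdiff's defaults. make_schedule is a hypothetical name.

```python
import math


def make_schedule(schedule_type, train_steps, beta_min=1e-4, beta_max=0.02,
                  cosine_s=0.008, clip_min=0.0, clip_max=0.999):
    """Return (betas, alpha_bars) as plain lists; defaults are illustrative."""
    if schedule_type == "linear":
        step = (beta_max - beta_min) / max(train_steps - 1, 1)
        betas = [beta_min + i * step for i in range(train_steps)]
    elif schedule_type == "cosine":
        def f(t):
            # f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2), Nichol and Dhariwal (2021)
            return math.cos((t / train_steps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2
        betas = [1.0 - f(t + 1) / f(t) for t in range(train_steps)]
    else:
        raise ValueError(f"unsupported schedule_type: {schedule_type!r}")

    betas = [min(max(b, clip_min), clip_max) for b in betas]  # clip betas
    alpha_bars, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b           # alpha_bar_t = prod over s <= t of (1 - beta_s)
        alpha_bars.append(prod)
    return betas, alpha_bars


lin_betas, lin_abar = make_schedule("linear", 1000)
cos_betas, cos_abar = make_schedule("cosine", 1000)
```

Clipping matters mainly for the cosine schedule, whose final beta approaches 1 without it.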
- set_inference_timesteps(num_inference_timesteps: int)
Update the number of inference timesteps dynamically.
Lets you trade sampling speed against sample quality at inference time without retraining the model.
- Parameters:
num_inference_timesteps – Number of timesteps to use for sampling.
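One simple way to build the subsampled schedule tau is to take evenly strided training timesteps and walk them from largest to smallest. The sketch assumes uniform spacing; the spacing rule torchdiff actually uses is not stated here, and inference_timesteps is a hypothetical name.

```python
def inference_timesteps(train_steps, num_inference_timesteps):
    """Evenly spaced subsequence tau of range(train_steps), largest t first."""
    if not 1 <= num_inference_timesteps <= train_steps:
        raise ValueError("need 1 <= num_inference_timesteps <= train_steps")
    stride = train_steps // num_inference_timesteps
    tau = list(range(0, train_steps, stride))[:num_inference_timesteps]
    return tau[::-1]


tau = inference_timesteps(1000, 50)
# Consecutive (t, t_prev) pairs are 20 training steps apart here.
pairs = list(zip(tau[:-1], tau[1:]))
```

Each (t, t_prev) pair is what a single non-adjacent reverse step consumes, so 50 inference steps replace 1000 training steps.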
- get_index(t: torch.Tensor, x_shape: torch.Size) → torch.Tensor
Reshape timestep-dependent coefficients for broadcasting.
Extracts values indexed by t and reshapes them to match the dimensionality of a given tensor shape.
- Parameters:
t – Timesteps tensor of shape (batch,).
x_shape – Shape of the target tensor.
- Returns:
Reshaped tensor suitable for broadcasting.
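The gather-then-reshape idea can be shown without tensors: pick one coefficient per batch element, then give it the shape (batch, 1, 1, …) so it broadcasts against x. The helper below is hypothetical; in PyTorch, assuming coeffs is a 1-D tensor and t a LongTensor, the equivalent would be coeffs[t].reshape(view_shape).

```python
def get_index_sketch(coeffs, t, x_shape):
    """Gather coeffs[t_i] per batch element and the shape that broadcasts against x."""
    values = [coeffs[i] for i in t]
    view_shape = (len(t),) + (1,) * (len(x_shape) - 1)
    return values, view_shape


sqrt_alpha_bar = [0.99, 0.95, 0.90]  # toy per-timestep coefficient table
vals, shape = get_index_sketch(sqrt_alpha_bar, [2, 0], (2, 3, 32, 32))
```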
- class torchdiff.ddim.TrainDDIM(*args: Any, **kwargs: Any)
Bases: Module
Trainer for Denoising Diffusion Implicit Models (DDIM).
Manages the DDIM training process, optimizing diff_net to predict the training target produced by the forward diffusion process (noise, x0, or v). Supports conditional training with text prompts, mixed-precision training, learning-rate scheduling, early stopping, and checkpointing. The underlying method is that of Song et al. (2021).
- Parameters:
diff_net (nn.Module) – Main model; predicts noise, x0, or v.
fwd_ddim (nn.Module) – Forward DDIM diffusion module for adding noise.
rwd_ddim (nn.Module) – Reverse DDIM diffusion module for denoising.
data_loader (torch.utils.data.DataLoader) – DataLoader for training data.
optim (torch.optim.Optimizer) – Optimizer for training the noise predictor and conditional model (if applicable).
loss_fn (callable) – Loss function comparing the model prediction with the training target (e.g., the added noise).
val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data, default None.
max_epochs (int, optional) – Maximum number of training epochs (default: 100).
device (str) – Device for computation (default: CUDA).
cond_net (nn.Module, optional) – Model for conditional generation (e.g., text embeddings), default None.
metrics_ (object, optional) – Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
tokenizer (BertTokenizer, optional) – Tokenizer for processing text prompts, default None (loads “bert-base-uncased”).
max_token_length (int, optional) – Maximum length for tokenized prompts (default: 77).
store_path (str, optional) – Path to save model checkpoints (default: “ddim_train”).
patience (int, optional) – Number of epochs to wait for improvement before early stopping (default: 20).
warmup_steps (int, optional) – Number of learning-rate warmup steps (default: 1000).
val_freq (int, optional) – Frequency (in epochs) for validation (default: 10).
norm_range (tuple, optional) – Range for clamping generated images (default: (-1, 1)).
norm_output (bool, optional) – Whether to normalize generated images to [0, 1] for metrics (default: True).
use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).
grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).
log_freq (int, optional) – Frequency (in epochs) at which the loss is printed.
use_comp (bool, optional) – Whether the model is compiled internally with torch.compile (default: False).
- load_checkpoint(checkpoint_path: str) → Tuple[int, float]
Loads a training checkpoint to resume training.
Restores the state of the noise predictor, conditional model (if applicable), and optimizer from a saved checkpoint. Handles DDP model state dict loading.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file.
- Returns:
epoch (int) – The epoch at which the checkpoint was saved.
loss (float) – The loss at the checkpoint.
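A model wrapped in torch.nn.parallel.DistributedDataParallel stores its parameters under keys prefixed with "module.", so a checkpoint saved from the wrapper does not load directly into an unwrapped model. The hypothetical helper below shows one common way to reconcile the two; whether load_checkpoint does exactly this is an assumption, and the integer values stand in for tensors.

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Drop the 'module.' prefix that DistributedDataParallel adds to parameter names."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


saved = {"module.conv.weight": 1, "module.conv.bias": 2, "head.weight": 3}
cleaned = strip_ddp_prefix(saved)
```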
- static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) → torch.optim.lr_scheduler.LambdaLR
Creates a learning rate scheduler for warmup.
Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over warmup_steps scheduler steps, then holds it constant.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.
warmup_steps (int) – Number of steps for the warmup phase.
- Returns:
lr_scheduler – Learning rate scheduler for warmup.
- Return type:
torch.optim.lr_scheduler.LambdaLR
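The warmup described above corresponds to a multiplicative factor that ramps linearly from 0 to 1 and then stays at 1. This is a sketch of such a factor, not the torchdiff source; warmup_lambda is a hypothetical name. In PyTorch it would be passed as torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda(warmup_steps)), and LambdaLR multiplies each parameter group's initial learning rate by the returned value.

```python
def warmup_lambda(warmup_steps):
    """Multiplicative LR factor: linear ramp 0 -> 1 over warmup_steps, then constant 1."""
    def lr_lambda(step):
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, step / warmup_steps)
    return lr_lambda


factor = warmup_lambda(1000)
ramp = [factor(s) for s in (0, 250, 1000, 5000)]
```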
- forward() → Dict
Trains the DDIM model to predict noise added by the forward diffusion process.
Executes the training loop, optimizing the noise predictor and conditional model (if applicable) using mixed precision, gradient clipping, and learning rate scheduling. Supports validation, early stopping, and checkpointing.
- Returns:
losses – Dictionary containing the training and validation losses.
- Return type:
Dict
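Patience-based early stopping, as controlled by the patience parameter, can be sketched as a small tracker that counts consecutive checks without improvement. This is a hypothetical illustration: which loss TrainDDIM monitors, and whether it counts epochs or validation rounds, is not stated here.

```python
class EarlyStopping:
    """Stop when the monitored loss has not improved for `patience` consecutive checks."""

    def __init__(self, patience=20):
        self.patience = patience
        self.best = float("inf")
        self.bad_checks = 0

    def step(self, loss):
        """Record one loss value; return True if training should stop."""
        if loss < self.best:
            self.best = loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


stopper = EarlyStopping(patience=2)
decisions = [stopper.step(loss) for loss in [1.0, 0.9, 0.95, 0.95]]
```

Training stops on the fourth check: the best loss (0.9) was seen two checks earlier and neither later value beat it.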
- validate() → Tuple[float, float, float, float, float, float]
Validates the noise predictor and computes evaluation metrics.
Computes the validation loss (MSE between predicted and ground-truth noise) and generates samples with the reverse diffusion model by manually iterating over timesteps. Decodes samples to images and computes image-domain metrics (MSE, PSNR, SSIM, FID, LPIPS) if metrics_ is provided.
- Returns:
val_loss (float) – Mean validation loss.
fid (float, or float(‘inf’) if not computed) – Mean FID score.
mse (float, or None if not computed) – Mean MSE
psnr (float, or None if not computed) – Mean PSNR
ssim (float, or None if not computed) – Mean SSIM
lpips_score (float, or None if not computed) – Mean LPIPS score
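Of the listed metrics, MSE and PSNR are simple enough to show directly; SSIM, FID, and LPIPS need image or feature-extractor libraries and are not sketched. The helpers below are hypothetical and work on flat lists of pixel values; max_val=1.0 assumes images were normalized to [0, 1], as norm_output does by default.

```python
import math


def mse(a, b):
    """Mean squared error between two equally sized flat lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr(mse_value, max_val=1.0):
    """PSNR in dB: 10 * log10(max_val**2 / MSE); infinite for identical images."""
    if mse_value == 0:
        return float("inf")
    return 10.0 * math.log10(max_val ** 2 / mse_value)


real = [0.0, 0.5, 1.0, 0.25]
fake = [0.1, 0.5, 0.9, 0.25]
err = mse(real, fake)      # (0.01 + 0 + 0.01 + 0) / 4 = 0.005
score = psnr(err)          # 10 * log10(200), about 23.01 dB
```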
- class torchdiff.ddim.SampleDDIM(*args: Any, **kwargs: Any)
Bases: Module
Image generation using a trained DDIM model.
Implements DDIM sampling: generates images by iteratively denoising random noise with a trained diff_net and a reverse diffusion module over a subsampled time-step schedule. Supports conditional generation with text prompts. The underlying method is that of Song et al. (2021).
- Parameters:
rwd_ddim (nn.Module) – Reverse diffusion module (e.g., ReverseDDIM) for the reverse process.
diff_net (nn.Module) – Trained model to predict noise/v/x0 at each time step.
img_size (tuple) – Tuple of (height, width) specifying the generated image dimensions.
cond_net (nn.Module, optional) – Model for conditional generation (e.g., text embeddings), default None.
tokenizer (str, optional) – Pretrained tokenizer name from Hugging Face (default: “bert-base-uncased”).
max_length (int, optional) – Maximum length for tokenized prompts (default: 77).
batch_size (int, optional) – Number of images to generate per batch (default: 1).
in_channels (int, optional) – Number of input channels for generated images (default: 3).
device (str) – Device for computation (default: CUDA).
norm_range (tuple, optional) – Tuple of (min, max) for clamping generated images (default: (-1, 1)).
- tokenize(prompts: List | str) → Tuple[torch.Tensor, torch.Tensor]
Tokenizes text prompts for conditional generation.
Converts input prompts into tokenized input IDs and attention masks using the specified tokenizer, suitable for use with the conditional model.
- Parameters:
prompts (str or list) – A single text prompt or a list of text prompts.
- Returns:
input_ids (torch.Tensor) – Tokenized input IDs, shape (batch_size, max_length).
attention_mask (torch.Tensor) – Attention mask, shape (batch_size, max_length).
- forward(conds: str | List | None = None, norm_output: bool = True, save_imgs: bool = True, save_path: str = 'ddim_samples') → torch.Tensor
Generates images using the DDIM sampling process.
Iteratively denoises random noise to generate images using the reverse diffusion process with a subsampled time step schedule and noise predictor. Supports conditional generation with text prompts.
- Parameters:
conds (str or list, optional) – Text prompt(s) for conditional generation, default None.
norm_output (bool, optional) – If True, normalizes output images to [0, 1] (default: True).
save_imgs (bool, optional) – If True, saves generated images to save_path (default: True).
save_path (str, optional) – Directory to save generated images (default: “ddim_samples”).
- Returns:
samps – Generated images, shape (batch_size, in_channels, height, width).
- Return type:
torch.Tensor
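The sampling loop can be sketched end to end on a single scalar with eta = 0. Everything here is an illustrative assumption rather than SampleDDIM's code: the “oracle” predictor (which knows the true x0) stands in for a trained diff_net, x̂_0 is clamped to norm_range, and norm_output is assumed to map [lo, hi] linearly onto [0, 1]. sample_ddim_scalar and oracle are hypothetical names.

```python
import math
from itertools import accumulate
from operator import mul


def sample_ddim_scalar(predict_noise, alpha_bars, timesteps, x_T,
                       norm_range=(-1.0, 1.0), norm_output=True):
    """Deterministic (eta = 0) DDIM sampling of one scalar over a subsampled schedule."""
    lo, hi = norm_range
    x = x_T
    for i, t in enumerate(timesteps):
        a_t = alpha_bars[t]
        # alpha_bar at the next, less noisy timestep; 1.0 after the last step
        a_prev = alpha_bars[timesteps[i + 1]] if i + 1 < len(timesteps) else 1.0
        eps = predict_noise(x, t)
        x0_hat = (x - math.sqrt(1.0 - a_t) * eps) / math.sqrt(a_t)
        x0_hat = min(max(x0_hat, lo), hi)                           # clamp x0 estimate
        eps = (x - math.sqrt(a_t) * x0_hat) / math.sqrt(1.0 - a_t)  # noise implied by clamped x0
        x = math.sqrt(a_prev) * x0_hat + math.sqrt(1.0 - a_prev) * eps
    x = min(max(x, lo), hi)
    return (x - lo) / (hi - lo) if norm_output else x


# Toy linear schedule and a hypothetical "oracle" predictor that knows x0 = 0.5.
betas = [1e-4 + i * (0.02 - 1e-4) / 999 for i in range(1000)]
alpha_bars = list(accumulate((1.0 - b for b in betas), mul))
TRUE_X0 = 0.5


def oracle(x, t):
    return (x - math.sqrt(alpha_bars[t]) * TRUE_X0) / math.sqrt(1.0 - alpha_bars[t])


out = sample_ddim_scalar(oracle, alpha_bars, list(range(980, -1, -20)), x_T=0.3)
```

With a perfect predictor, every step's x̂_0 equals 0.5, so the final sample is 0.5, which maps to 0.75 in [0, 1].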
- to(device: torch.device) → Self
Moves the module and its components to the specified device.
Updates the device attribute and moves the reverse diffusion, noise predictor, and conditional model (if present) to the specified device.
- Parameters:
device (torch.device) – Target device for the module and its components.
- Returns:
sample_ddim – This SampleDDIM instance, moved to the specified device.
- Return type:
SampleDDIM