UnCLIP

UnCLIP Diffusion Model

This module provides a comprehensive implementation of the UnCLIP diffusion model, as described in Ramesh et al. (2022, “Hierarchical Text-Conditional Image Generation with CLIP Latents”). It integrates CLIP embeddings with diffusion processes for high-quality image generation conditioned on text prompts or image embeddings. The module supports training, sampling, and upsampling processes, leveraging components from CLIP, GLIDE, and DDIM, with classifier-free guidance and text dropout for robust generation.

Components

  • SchedulerUnCLIP: Manages noise schedules with support for linear, sigmoid, quadratic, constant, inverse_time, and cosine beta schedules, including subsampled (tau) schedules for efficient sampling.

  • ForwardUnCLIP: Forward diffusion process to add noise to image or latent embeddings.

  • ReverseUnCLIP: Reverse diffusion process for denoising, supporting noise or clean image predictions with subsampled steps.

  • CLIPEncoder: Encodes images or text into embeddings using a pre-trained CLIP model.

  • UnClipDecoder: Generates low-resolution images (64x64) from CLIP embeddings, incorporating GLIDE text encoding and classifier-free guidance.

  • UnCLIPTransformerPrior: Transformer-based prior to predict clean image embeddings from noisy embeddings and text conditions.

  • CLIPContextProjection: Projects CLIP image embeddings into context tokens for the decoder.

  • CLIPEmbeddingProjection: Reduces and reconstructs embedding dimensionality for efficient processing.

  • TrainUnClipDecoder: Orchestrates training of the decoder with mixed precision, gradient accumulation, and DDP support.

  • SampleUnCLIP: Generates images from text prompts or noise, scaling from 64x64 to 256x256 or 1024x1024 with upsamplers.

  • UpsamplerUnCLIP: U-Net-based upsampler for scaling images (64x64 to 256x256 or 256x256 to 1024x1024), conditioned on low-resolution inputs.

  • TrainUpsamplerUnCLIP: Trains the upsampler with noise prediction, low-resolution conditioning, and optional image corruption (Gaussian blur or BSR degradation).

Notes

  • The model uses a subsampled time step schedule (tau) for faster sampling, controlled by the tau_num_steps parameter in SchedulerUnCLIP.

  • Classifier-free guidance and text dropout enhance generation quality, with tunable parameters classifier_free_prop and drop_caption.

  • Upsampling stages use corrupted low-resolution inputs (Gaussian blur for 64x64→256x256, BSR degradation for 256x256→1024x1024) to improve robustness.

  • Supports distributed training with DDP, mixed precision via autocast, and learning rate scheduling with warmup and plateau reduction.

References

  • Ramesh, Aditya, et al. “Hierarchical Text-Conditional Image Generation with CLIP Latents.” arXiv preprint arXiv:2204.06125 (2022).

  • Radford, Alec, et al. “Learning Transferable Visual Models From Natural Language Supervision.” arXiv preprint arXiv:2103.00020 (2021).

  • Nichol, Alexander, et al. “GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models.” arXiv preprint arXiv:2112.10741 (2021).

  • Song, Jiaming, et al. “Denoising Diffusion Implicit Models.” arXiv preprint arXiv:2010.02502 (2020).


class torchdiff.unclip.SchedulerUnCLIP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Variance scheduler for UnCLIP supporting multiple schedule types

Manages noise schedule parameters with support for both full training schedule and subsampled inference schedule for faster sampling.

set_inf_timesteps(num_inf_timesteps: int)[source]#

Dynamically change the number of inference steps

Allows using different numbers of steps at inference time.
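The relationship between the full training schedule and the subsampled (tau) inference schedule can be sketched in plain Python. This is an illustrative sketch only: the helper names, the linear beta endpoints, and the even spacing of the tau indices are assumptions, not the library's confirmed implementation.

```python
def linear_betas(train_steps, beta_start=1e-4, beta_end=2e-2):
    """Linearly spaced betas; the endpoint values are assumed for illustration."""
    step = (beta_end - beta_start) / (train_steps - 1)
    return [beta_start + i * step for i in range(train_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t is the running product of (1 - beta_s) for s <= t."""
    out, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        out.append(running)
    return out


def subsample_timesteps(train_steps, num_inf_timesteps):
    """Pick evenly spaced training timesteps to form the tau schedule (assumed spacing)."""
    if num_inf_timesteps == 1:
        return [train_steps - 1]
    stride = (train_steps - 1) / (num_inf_timesteps - 1)
    return [round(i * stride) for i in range(num_inf_timesteps)]


betas = linear_betas(1000)
alpha_bars = cumulative_alphas(betas)

# Changing the number of inference steps only changes which indices are kept.
tau = subsample_timesteps(1000, 50)
tau_alpha_bars = [alpha_bars[t] for t in tau]
```

Calling `subsample_timesteps` again with a different step count mirrors what `set_inf_timesteps` is described as doing: the training schedule stays fixed and only the subsampled view changes.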

get_index(t: torch.Tensor, x_shape: torch.Size) torch.Tensor[source]#

Extract coefficients at timestep t and reshape for broadcasting

class torchdiff.unclip.ForwardUnCLIP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Forward diffusion process for UnCLIP

Applies Gaussian noise to input data according to the forward diffusion process. Supports both 2D (latent embeddings) and 4D (images) inputs.

q(x_t | x_0) = N(x_t; √ᾱ_t x_0, (1 - ᾱ_t)I)

forward(x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) Tuple[torch.Tensor, torch.Tensor][source]#

Sample from q(x_t | x_0) and compute prediction target

Parameters:
  • x0 – (batch, …) clean data (2D or 4D)

  • noise – (batch, …) Gaussian noise

  • t – (batch,) discrete timesteps in [0, train_steps-1]

Returns:

  • xt – (batch, …) noised data

  • target – (batch, …) prediction target (noise or x0)
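A minimal sketch of sampling from q(x_t | x_0) for a single flattened example, written in plain Python rather than with tensors. The function name `q_sample` and the `pred_type` strings "noise" and "x0" are assumptions made for illustration.

```python
import math
import random


def q_sample(x0, noise, alpha_bar_t, pred_type="noise"):
    """Return (x_t, target) where x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * noise."""
    sqrt_ab = math.sqrt(alpha_bar_t)
    sqrt_one_minus_ab = math.sqrt(1.0 - alpha_bar_t)
    xt = [sqrt_ab * a + sqrt_one_minus_ab * e for a, e in zip(x0, noise)]
    # The training target depends on what the network is asked to predict.
    target = noise if pred_type == "noise" else x0
    return xt, target


rng = random.Random(0)
x0 = [0.5, -0.25, 1.0]
noise = [rng.gauss(0.0, 1.0) for _ in x0]
xt, target = q_sample(x0, noise, alpha_bar_t=0.9)
```

At ᾱ_t = 1 the output equals the clean input, and at ᾱ_t = 0 it is pure noise, which matches the mean and variance in the formula above.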

class torchdiff.unclip.ReverseUnCLIP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Reverse diffusion process for UnCLIP

Denoises input using DDIM-style sampling with the tau (subsampled) schedule. Supports both noise prediction and x0 prediction modes. Works with both 2D (latent embeddings) and 4D (images) inputs.

predict_x0(xt: torch.Tensor, t: torch.Tensor, pred: torch.Tensor) torch.Tensor[source]#

Convert model output to x0 prediction based on prediction type

predict_noise(xt: torch.Tensor, t: torch.Tensor, x0_pred: torch.Tensor) torch.Tensor[source]#

Predict noise from x0

ε̂ = (x_t - √ᾱ_t * x̂_0) / √(1 - ᾱ_t)
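The two conversions are inverses of each other, which the following plain-Python sketch checks on a single flattened example. The helper names are hypothetical and the code illustrates the algebra only, not the module's tensor implementation.

```python
import math


def x0_from_noise(xt, eps, alpha_bar_t):
    """x0_hat = (x_t - sqrt(1 - a_bar) * eps) / sqrt(a_bar)."""
    s, r = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [(x - r * e) / s for x, e in zip(xt, eps)]


def noise_from_x0(xt, x0_pred, alpha_bar_t):
    """eps_hat = (x_t - sqrt(a_bar) * x0_hat) / sqrt(1 - a_bar)."""
    s, r = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [(x - s * x0) / r for x, x0 in zip(xt, x0_pred)]


xt = [0.8, -0.1, 0.35]
eps = [0.2, -0.6, 1.1]
x0_hat = x0_from_noise(xt, eps, alpha_bar_t=0.7)
eps_roundtrip = noise_from_x0(xt, x0_hat, alpha_bar_t=0.7)
```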

forward(xt: torch.Tensor, t: torch.Tensor, t_pre: torch.Tensor, pred: torch.Tensor) Tuple[torch.Tensor, torch.Tensor | None][source]#

UnCLIP reverse step from x_t to x_{t_prev}

Uses tau schedule (subsampled timesteps) for faster sampling.

Parameters:
  • xt – (batch, …) current state (2D or 4D)

  • t – (batch,) current tau timestep indices in [0, sample_steps-1]

  • t_pre – (batch,) previous tau timestep indices

  • pred – (batch, …) model prediction

Returns:

  • x_prev – (batch, …) previous state x_{t_prev}

  • pred_x0 – (batch, …) predicted x0 (if return_pred_x0=True), otherwise None
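One DDIM-style update can be sketched as follows, assuming a noise-predicting model and the deterministic case (η = 0, no added noise). The function name and the choice of η are assumptions; the library's step may differ in detail.

```python
import math


def ddim_step(xt, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0) from x_t to x_{t_prev}."""
    sqrt_ab_t = math.sqrt(alpha_bar_t)
    sqrt_1m_t = math.sqrt(1.0 - alpha_bar_t)
    # Recover the clean-sample estimate from the noise prediction.
    x0_pred = [(x - sqrt_1m_t * e) / sqrt_ab_t for x, e in zip(xt, eps_pred)]
    sqrt_ab_prev = math.sqrt(alpha_bar_prev)
    sqrt_1m_prev = math.sqrt(1.0 - alpha_bar_prev)
    # Re-noise the estimate to the (less noisy) previous timestep.
    x_prev = [sqrt_ab_prev * x0 + sqrt_1m_prev * e for x0, e in zip(x0_pred, eps_pred)]
    return x_prev, x0_pred


x0 = [0.5, -1.0]
eps = [0.3, 0.8]
xt = [math.sqrt(0.5) * a + math.sqrt(0.5) * e for a, e in zip(x0, eps)]
x_prev, x0_pred = ddim_step(xt, eps, alpha_bar_t=0.5, alpha_bar_prev=0.9)
```

Because the step only needs ᾱ at the current and previous tau indices, it can jump across many training timesteps at once.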

set_pred_type(pred_type: str)[source]#

Change the prediction type after initialization

class torchdiff.unclip.CLIPEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

Encodes images or text using a pre-trained CLIP model.

Loads a CLIP model and processor from the transformers library, providing methods to encode images or text into embeddings and compute similarity scores between them.

Parameters:
  • model_name (str, optional) – Name of the CLIP model to load (default: ‘openai/clip-vit-base-patch32’).

  • device (str, optional) – Device to run the model on (default: ‘cuda’ if available, else ‘cpu’).

  • use_fast (bool, optional) – Whether to use the fast image processor (torchvision-based) (default: False).

forward(data: torch.Tensor | List[str] | str | Image | List[Image], data_type: str, normalize: bool = True) torch.Tensor[source]#

Encodes input data (image or text) using the CLIP model.

Processes input data (images or text) to produce embeddings, with optional L2 normalization.

Parameters:
  • data (Union[torch.Tensor, List[str], str, Image.Image, List[Image.Image]]) –

    Input data to encode:
    • torch.Tensor: Preprocessed image tensor (batch_size, channels, height, width).

    • List[str] or str: Text or list of texts.

    • PIL.Image.Image or List[PIL.Image.Image]: Single or list of PIL images.

  • data_type (str) – Type of input data (‘img’ or ‘text’).

  • normalize (bool, optional) – Whether to L2-normalize the output embeddings (default: True).

Returns:

outputs – Encoded embeddings, shape (batch_size, embedding_dim).

Return type:

torch.Tensor

compute_similarity(image_features: torch.Tensor, text_features: torch.Tensor) torch.Tensor[source]#

Computes cosine similarity between image and text embeddings.

Calculates the cosine similarity matrix between batches of image and text embeddings.

Parameters:
  • image_features (torch.Tensor) – Image embeddings, shape (batch_size, embedding_dim).

  • text_features (torch.Tensor) – Text embeddings, shape (batch_size, embedding_dim).

Returns:

similarity – Cosine similarity scores, shape (batch_size, batch_size).

Return type:

torch.Tensor
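The shape of the result follows from comparing every image embedding with every text embedding. A plain-Python sketch of that computation is given here; whether the library additionally applies a learned temperature or logit scale is not stated in this documentation, so none is applied.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def cosine_similarity_matrix(image_features, text_features):
    """Entry [i][j] is the cosine similarity of image i and text j."""
    imgs = [l2_normalize(v) for v in image_features]
    txts = [l2_normalize(v) for v in text_features]
    return [[sum(a * b for a, b in zip(i, t)) for t in txts] for i in imgs]


image_features = [[1.0, 0.0], [0.0, 2.0]]
text_features = [[3.0, 0.0], [1.0, 1.0]]
similarity = cosine_similarity_matrix(image_features, text_features)
```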

class torchdiff.unclip.UnClipDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

Decoder for UnCLIP diffusion models.

Combines CLIP image embeddings and text embeddings to guide the denoising process, using a noise predictor and diffusion processes. Incorporates classifier-free guidance, text caption dropout, and projection of CLIP embeddings into context tokens.

Parameters:
  • clip_embed_dim (int) – Dimensionality of the input embeddings.

  • diff_net (nn.Module) – Model to predict noise/x0 during the denoising process.

  • fwd_unclip (nn.Module) – Forward diffusion module (e.g., ForwardUnCLIP) for adding noise.

  • rwd_unclip (nn.Module) – Reverse diffusion module (e.g., ReverseUnCLIP) for denoising.

  • glide_text_encoder (nn.Module, optional) – GLIDE text encoder for processing text prompts, default None.

  • tokenizer (BertTokenizer, optional) – Tokenizer for processing text prompts, default None (loads “bert-base-uncased”).

  • device (str, optional) – Device for computation (default: CUDA).

  • norm_range (Tuple[float, float], optional) – Range for clamping output images (default: (-1.0, 1.0)).

  • norm_clip_embed (bool, optional) – Whether to normalize outputs (default: True).

  • classifier_free_prop (float, optional) – Probability of dropping the CLIP image-embedding conditioning during training, as required for classifier-free guidance (default: 0.1, per paper).

  • drop_caption (float, optional) – Probability for text caption dropout (default: 0.5, per paper).

  • max_token_length (int, optional) – Maximum length for tokenized prompts (default: 77).

forward(img_embed: torch.Tensor, text_embed: torch.Tensor, imgs: torch.Tensor, texts: torch.Tensor) Tuple[torch.Tensor, torch.Tensor][source]#

Processes embeddings and images to predict noise for training.

Applies classifier-free guidance and text dropout, projects CLIP image embeddings into context tokens, encodes text with GLIDE, and predicts noise for the diffusion process.

Parameters:
  • img_embed (torch.Tensor) – CLIP image embeddings, shape (batch_size, embed_dim).

  • text_embed (torch.Tensor) – CLIP text embeddings, shape (batch_size, embed_dim).

  • imgs (torch.Tensor) – Input images, shape (batch_size, channels, height, width).

  • texts (torch.Tensor) – Text prompts for conditional generation.

Returns:

  • pred (torch.Tensor) – Predicted noise/x0 tensor, shape (batch_size, channels, height, width).

  • target (torch.Tensor) – Ground truth noise/x0 tensor, shape (batch_size, channels, height, width).
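The training-time conditioning dropout can be sketched for a single example as below. This is a hedged illustration: replacing the image embedding with zeros, replacing the caption with an empty string, and drawing the two decisions independently are assumptions, and the function name is hypothetical.

```python
import random


def apply_conditioning_dropout(img_embed, caption,
                               classifier_free_prop=0.1,
                               drop_caption=0.5,
                               rng=random):
    """Randomly drop the image embedding and/or the caption for one example."""
    if rng.random() < classifier_free_prop:
        # Assumed convention: an all-zero embedding stands for "no image condition".
        img_embed = [0.0] * len(img_embed)
    if rng.random() < drop_caption:
        # Assumed convention: an empty caption stands for "no text condition".
        caption = ""
    return img_embed, caption


rng = random.Random(0)
embed, caption = apply_conditioning_dropout([0.2, -0.4, 0.9], "a red bicycle", rng=rng)
```

Training on a mix of conditioned and unconditioned examples is what allows a guidance scale to be applied at sampling time.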

inference_forward(img_embed, prompt_embed)[source]#

class torchdiff.unclip.UnCLIPTransformerPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

Transformer-based prior model for UnCLIP diffusion.

Predicts clean image embeddings from noisy image embeddings and text embeddings using a Transformer architecture, incorporating time embeddings and optional projection layers for text and image inputs.

Parameters:
  • fwd_unclip (nn.Module) – Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.

  • rwd_unclip (nn.Module) – Reverse diffusion module (e.g., ReverseUnCLIP) for denoising during training.

  • clip_text_proj (nn.Module, optional) – Projection module for text embeddings, default None.

  • clip_img_proj (nn.Module, optional) – Projection module for image embeddings, default None.

  • trans_embed_dim (int, optional) – Dimensionality of embeddings (default: 320).

  • num_layers (int, optional) – Number of Transformer layers (default: 12).

  • num_att_heads (int, optional) – Number of attention heads in each Transformer layer (default: 8).

  • ff_dim (int, optional) – Dimensionality of the feedforward network in Transformer layers (default: 768).

  • max_sequence_length (int, optional) – Maximum sequence length for input embeddings (default: 2).

  • dropout (float, optional) – Dropout probability for regularization (default: 0.2).

  • use_flash (bool, optional) – Enable flash attention if available (default: True).

  • grad_check (bool, optional) – Whether to apply gradient checkpointing (default: False).

  • check_every_n_layers (int, optional) – Interval, in layers, at which gradient checkpointing is applied (default: 2).

forward(text_embed: torch.Tensor, noisy_img_embed: torch.Tensor, timesteps: torch.Tensor) torch.Tensor[source]#

Predicts clean image embeddings from noisy inputs and text embeddings.

Processes text and noisy image embeddings through a Transformer architecture, conditioned on time embeddings, to predict the clean image embeddings.

Parameters:
  • text_embed (torch.Tensor) – Text embeddings, shape (batch_size, embed_dim).

  • noisy_img_embed (torch.Tensor) – Noisy image embeddings, shape (batch_size, embed_dim).

  • timesteps (torch.Tensor) – Tensor of time step indices (long), shape (batch_size,).

Returns:

pred_clean_embed – Predicted clean image embeddings, shape (batch_size, embed_dim).

Return type:

torch.Tensor

enable_grad_check()[source]#

Enable gradient checkpointing for memory savings

disable_grad_check()[source]#

Disable gradient checkpointing

class torchdiff.unclip.TransformerBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

Single Transformer block with multi-head attention and feedforward layers.

Implements a Transformer block with multi-head self-attention, layer normalization, and a feedforward network with residual connections for processing sequences in the UnCLIPTransformerPrior model.

Parameters:
  • embed_dim (int) – Dimensionality of input and output embeddings.

  • num_heads (int) – Number of attention heads in the multi-head attention layer.

  • ff_dim (int) – Dimensionality of the feedforward network.

  • dropout (float) – Dropout probability for regularization.

  • use_flash (bool) – Whether to use flash attention (default: True).

forward(x: torch.Tensor) torch.Tensor[source]#

Processes input sequence through the Transformer block.

Applies multi-head self-attention followed by a feedforward network, with residual connections and layer normalization.

Parameters:

x (torch.Tensor) – Input sequence tensor, shape (batch_size, sequence_length, embed_dim).

Returns:

x – Processed sequence tensor, shape (batch_size, sequence_length, embed_dim).

Return type:

torch.Tensor

class torchdiff.unclip.FusedGELU(*args: Any, **kwargs: Any)[source]#

Bases: Module

Fused GELU activation for better efficiency on some hardware

forward(x: torch.Tensor) torch.Tensor[source]#

class torchdiff.unclip.CLIPContextProjection(*args: Any, **kwargs: Any)[source]#

Bases: Module

Projects CLIP image embeddings into multiple context tokens.

Transforms a single CLIP image embedding into a specified number of context tokens using a linear projection followed by layer normalization.

Parameters:
  • clip_embed_dim (int) – Dimensionality of the input CLIP embedding (e.g., 320 or 512).

  • num_tokens (int, optional) – Number of context tokens to generate (default: 4).

  • output_dim (int, optional) – Dimensionality of each output context token. If None, defaults to clip_embed_dim. Use this when the input embedding has been reduced in dimensionality but the output tokens need to match a different dimension (e.g., GLIDE text encoder output).

forward(z_i)[source]#

Projects CLIP image embedding into context tokens.

Applies a linear projection to transform the input embedding into multiple tokens, reshapes the output, and applies layer normalization.

Parameters:

z_i (torch.Tensor) – Input CLIP image embedding, shape (batch_size, input_dim).

Returns:

c – Context tokens, shape (batch_size, num_tokens, output_dim).

Return type:

torch.Tensor
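The reshape-then-normalize step for one example can be sketched in plain Python. The sketch starts from the already projected vector of length num_tokens × output_dim and omits the learnable scale and bias that a layer normalization usually carries; both simplifications, and the function name, are assumptions.

```python
import math


def to_context_tokens(projected, num_tokens, output_dim, eps=1e-5):
    """Split a flat projected vector into tokens and layer-normalize each token."""
    if len(projected) != num_tokens * output_dim:
        raise ValueError("projected vector must have num_tokens * output_dim entries")
    tokens = []
    for i in range(num_tokens):
        tok = projected[i * output_dim:(i + 1) * output_dim]
        mean = sum(tok) / output_dim
        var = sum((v - mean) ** 2 for v in tok) / output_dim
        tokens.append([(v - mean) / math.sqrt(var + eps) for v in tok])
    return tokens


# One example with num_tokens=4 and output_dim=2.
context = to_context_tokens([float(i) for i in range(8)], num_tokens=4, output_dim=2)
```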

class torchdiff.unclip.CLIPEmbeddingProjection(*args: Any, **kwargs: Any)[source]#

Bases: Module

Projection module for dimensionality reduction and reconstruction.

Implements a neural network with forward and inverse projections to reduce and restore input dimensionality, supporting customizable hidden layers, dropout, and layer normalization.

Parameters:
  • clip_embed_dim (int, optional) – Input dimensionality (default: 1024).

  • trans_embed_dim (int, optional) – Output dimensionality for forward projection (default: 320).

  • hidden_dim (int, optional) – Inner dimension of projection (default: 512).

  • dropout (float) – Dropout rate (default: 0.2).

  • use_layer_norm (bool) – Whether to apply layer normalization to the output (default: True).

forward(x: torch.Tensor) torch.Tensor[source]#

Projects input to a lower-dimensional space.

Applies the forward projection network to reduce the dimensionality of the input tensor.

Parameters:

x (torch.Tensor) – Input tensor to be projected, shape (batch_size, input_dim).

Returns:

x_reduced – Projected tensor, shape (batch_size, output_dim).

Return type:

torch.Tensor

inverse_transform(x: torch.Tensor) torch.Tensor[source]#

Reconstructs input from lower-dimensional space.

Applies the inverse projection network to restore the original dimensionality of the input tensor.

Parameters:

x (torch.Tensor) – Reduced-dimensionality tensor, shape (batch_size, output_dim).

Returns:

x – Reconstructed tensor, shape (batch_size, input_dim).

Return type:

torch.Tensor

rec_loss(x: torch.Tensor) torch.Tensor[source]#

Computes the reconstruction loss for the projection.

Calculates the mean squared error between the original input and its reconstruction after forward and inverse projections.

Parameters:

x (torch.Tensor) – Original input tensor, shape (batch_size, input_dim).

Returns:

loss – Mean squared error loss between the original and reconstructed tensors.

Return type:

torch.Tensor

class torchdiff.unclip.TrainUnClipDecoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

Trainer for the UnCLIP decoder model.

Orchestrates the training of the UnCLIP decoder model, integrating CLIP embeddings, forward and reverse diffusion processes, and optional dimensionality reduction. Supports mixed precision, gradient accumulation, DDP, and comprehensive evaluation metrics.

Parameters:
  • clip_embed_dim (int) – Dimensionality of the input embeddings.

  • decoder_net (nn.Module) – The UnCLIP decoder model (e.g., UnClipDecoder) to be trained.

  • clip_net (nn.Module) – CLIP model for generating text and image embeddings.

  • train_loader (torch.utils.data.DataLoader) – DataLoader for training data.

  • optim (torch.optim.Optimizer) – Optimizer for training the decoder model.

  • loss_fn (Callable) – Loss function to compute the difference between predicted and target noise.

  • clip_text_proj (nn.Module, optional) – Projection module for text embeddings, default None.

  • clip_img_proj (nn.Module, optional) – Projection module for image embeddings, default None.

  • val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data, default None.

  • metrics_ (Any, optional) – Object providing evaluation metrics (e.g., FID, MSE, PSNR, SSIM, LPIPS), default None.

  • max_epochs (int, optional) – Maximum number of training epochs (default: 100).

  • device (str, optional) – Device for computation (default: CUDA).

  • store_path (str, optional) – Directory to save model checkpoints (default: “unclip_decoder”).

  • patience (int, optional) – Number of epochs to wait for improvement before early stopping (default: 20).

  • warmup_steps (int, optional) – Number of steps for learning rate warmup (default: 10000).

  • val_freq (int, optional) – Frequency (in epochs) for validation (default: 10).

  • use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).

  • grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).

  • log_freq (int, optional) – Frequency (in epochs) for printing progress (default: 1).

  • use_comp (bool, optional) – Whether to compile the model using torch.compile (default: False).

  • norm_range (Tuple[float, float], optional) – Range for clamping output images (default: (-1.0, 1.0)).

  • reduce_clip_embed_dim (bool, optional) – Whether to apply dimensionality reduction to embeddings (default: True).

  • trans_embed_dim (int, optional) – Output dimensionality for reduced embeddings (default: 312).

  • norm_clip_embed (bool, optional) – Whether to normalize CLIP embeddings (default: True).

  • finetune_clip_proj (bool, optional) – Whether to fine-tune projection layers (default: False).

  • use_autocast (bool) – Whether to use mixed precision for efficiency (default: True).

forward() Dict[source]#

Trains the UnCLIP decoder model to predict noise for denoising.

Executes the training loop, optimizing the decoder model using CLIP embeddings, mixed precision, gradient clipping, and learning rate scheduling. Supports validation, early stopping, and checkpointing.

Returns:

losses – Dictionary of training and validation losses.

Return type:

Dict

static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) torch.optim.lr_scheduler.LambdaLR[source]#

Creates a learning rate scheduler for warmup.

Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.

  • warmup_steps (int) – Number of steps for the warmup phase.

Returns:

lr_scheduler – Learning rate scheduler for warmup.

Return type:

torch.optim.lr_scheduler.LambdaLR
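LambdaLR multiplies the optimizer's initial learning rate by the value a user-supplied function returns for the current step. A linear warmup multiplier of the kind described above might look like the following sketch; the function name and the handling of a non-positive warmup_steps are assumptions.

```python
def warmup_factor(step, warmup_steps):
    """Multiplier that rises linearly from 0 to 1, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0  # assumed behaviour: no warmup requested
    return min(1.0, step / warmup_steps)


base_lr = 1e-4
# Effective learning rate at a few steps of a 100-step warmup.
lrs = [base_lr * warmup_factor(s, 100) for s in (0, 50, 100, 500)]
```

With PyTorch available, a function of this shape would typically be passed as `lr_lambda=lambda step: warmup_factor(step, warmup_steps)`.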

load_checkpoint(check_path: str) Tuple[int, float][source]#

Loads model checkpoint.

Restores the state of the decoder model, its submodules, optimizer, and schedulers from a saved checkpoint, handling DDP compatibility.

Parameters:

check_path (str) – Path to the checkpoint file.

Returns:

  • epoch (int) – The epoch at which the checkpoint was saved.

  • loss (float) – The loss at the checkpoint.

validate() Tuple[float, float | None, float | None, float | None, float | None, float | None][source]#

Validates the UnCLIP decoder model.

Computes validation loss and optional metrics (FID, MSE, PSNR, SSIM, LPIPS) by encoding images and texts, applying forward diffusion, predicting noise, and reconstructing images through reverse diffusion.

Returns:

  • val_loss (float) – Mean validation loss.

  • fid_avg (float or None) – Average FID score, if computed.

  • mse_avg (float or None) – Average MSE score, if computed.

  • psnr_avg (float or None) – Average PSNR score, if computed.

  • ssim_avg (float or None) – Average SSIM score, if computed.

  • lpips_avg (float or None) – Average LPIPS score, if computed.

class torchdiff.unclip.TrainUnCLIPPrior(*args: Any, **kwargs: Any)[source]#

Bases: Module

Trainer for the UnCLIPTransformerPrior model.

Handles the training of the UnCLIP prior model to predict clean image embeddings from noisy image embeddings and text embeddings, with support for dimension reduction, mixed precision training, and distributed training.

Parameters:
  • prior_net (nn.Module) – The UnCLIP prior model to be trained (e.g., UnCLIPTransformerPrior).

  • clip_net (nn.Module) – CLIP model for encoding text and images.

  • train_loader (torch.utils.data.DataLoader) – DataLoader for training data.

  • optim (torch.optim.Optimizer) – Optimizer for training the prior model.

  • loss_fn (Callable) – Loss function to compute the difference between predicted and target embeddings.

  • val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data, default None.

  • max_epochs (int, optional) – Maximum number of training epochs (default: 100).

  • device (str, optional) – Device for computation (default: CUDA).

  • store_path (str, optional) – Directory to save model checkpoints (default: ‘unclip_prior_train’).

  • patience (int, optional) – Number of epochs to wait for improvement before early stopping (default: 20).

  • warmup_steps (int, optional) – Number of steps for learning rate warmup (default: 10000).

  • val_freq (int, optional) – Frequency (in epochs) for validation (default: 10).

  • use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).

  • grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).

  • log_freq (int, optional) – Frequency (in epochs) for printing training progress (default: 1).

  • use_comp (bool, optional) – Whether to compile models for optimization (default: False).

  • norm_range (Tuple[float, float], optional) – Range for clamping output embeddings (default: (-1.0, 1.0)).

  • reduce_clip_embed_dim (bool, optional) – Whether to apply dimension reduction to embeddings (default: True).

  • trans_embed_dim (int, optional) – Target dimensionality for reduced embeddings (default: 319).

  • norm_clip_embed (bool) – Whether to normalize CLIP embeddings (default: True).

  • use_autocast (bool) – Whether to use mixed precision (default: True).

static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) torch.optim.lr_scheduler.LambdaLR[source]#

Creates a learning rate scheduler for warmup.

Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.

  • warmup_steps (int) – Number of steps for the warmup phase.

Returns:

lr_scheduler – Learning rate scheduler for warmup.

Return type:

torch.optim.lr_scheduler.LambdaLR

forward() Dict[source]#

Trains the UnCLIP prior model.

Executes the training loop, optimizing the prior model to predict clean image embeddings from noisy embeddings and text conditions, with support for validation, early stopping, and checkpointing.

Returns:

losses – Dictionary of training and validation losses.

Return type:

Dict

validate() float[source]#

Validates the UnCLIP prior model.

Computes the validation loss by encoding images and text, applying forward diffusion, predicting clean embeddings, and comparing with target embeddings.

Returns:

val_loss – Mean validation loss, synchronized across processes if DDP is enabled.

Return type:

float

load_checkpoint(check_path: str) Tuple[int, float][source]#

Loads a model checkpoint to resume training.

Restores the prior model and optimizer states from a saved checkpoint, handling DDP compatibility for state dictionaries.

Parameters:

check_path (str) – Path to the checkpoint file.

Returns:

  • epoch (int) – The epoch at which the checkpoint was saved.

  • loss (float) – The loss value at the checkpoint.

class torchdiff.unclip.SampleUnCLIP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Generates images using the UnCLIP model pipeline.

Combines a prior model, decoder model, CLIP model, and upsampler models to generate images from text prompts or noise. Performs diffusion-based sampling with classifier-free guidance in both prior and decoder stages, followed by upsampling to higher resolutions.

Parameters:
  • prior_net (nn.Module) – The UnCLIP prior model for generating image embeddings from text.

  • decoder_net (nn.Module) – The UnCLIP decoder model for generating low-resolution images from embeddings.

  • clip_net (nn.Module) – CLIP model for encoding text prompts into embeddings.

  • low_res_upsampler (nn.Module) – First upsampler model for scaling images from 64x64 to 256x256.

  • high_res_upsampler (nn.Module, optional) – Second upsampler model for scaling images from 256x256 to 1024x1024, default None.

  • device (str, optional) – Device for computation (default: CUDA).

  • offload_device (str) – Device to offload idle models to (default: CPU).

  • clip_embed_dim (int, optional) – Dimensionality of CLIP embeddings (default: 512).

  • prior_guidance_scale (float, optional) – Classifier-free guidance scale for the prior model (default: 4.0).

  • decoder_guidance_scale (float, optional) – Classifier-free guidance scale for the decoder model (default: 8.0).

  • batch_size (int, optional) – Number of images to generate per batch (default: 1).

  • norm_clip_embed (bool, optional) – Whether to normalize CLIP embeddings (default: True).

  • prior_dim_reduction (bool, optional) – Whether to apply dimensionality reduction in the prior model (default: True).

  • init_img_size (Tuple[int, int, int], optional) – Size of the initial generated images (default: (3, 64, 64) for RGB 64x64).

  • use_high_res_upsampler (bool, optional) – Whether to use the second upsampler for 1024x1024 output (default: True).

  • norm_range (Tuple[float, float], optional) – Range for clamping output images (default: (-1.0, 1.0)).

  • use_model_offloading (bool) – Whether to use model offloading (default: True).

forward(prompts: str | List | None = None, norm_output: bool = True, save_imgs: bool = True, save_path: str = 'unclip_samples')#

Generates images from text prompts or noise using the UnCLIP pipeline.

Executes the full UnCLIP generation process: prior model generates image embeddings, decoder model generates 64x64 images, first upsampler scales to 256x256, and optional second upsampler scales to 1024x1024. Supports classifier-free guidance and saves generated images if requested.

Parameters:
  • prompts (Union[str, List], optional) – Text prompt(s) for conditional generation, default None (unconditional).

  • norm_output (bool, optional) – Whether to normalize output images to [0, 1] range (default: True).

  • save_imgs (bool, optional) – Whether to save generated images to disk (default: True).

  • save_path (str, optional) – Directory to save generated images (default: “unclip_samples”).

Returns:

final_images – Generated images, shape (batch_size, channels, height, width), either 256x256 or 1024x1024 depending on use_high_res_upsampler.

Return type:

torch.Tensor
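Two small pieces of the sampling pipeline lend themselves to a sketch: combining conditional and unconditional predictions under a guidance scale, and mapping clamped outputs from norm_range to [0, 1]. The combination rule shown is the commonly used classifier-free guidance form; whether the library uses exactly this expression is an assumption, as are the helper names.

```python
def guided_prediction(pred_cond, pred_uncond, guidance_scale):
    """Extrapolate from the unconditional prediction toward the conditional one."""
    return [u + guidance_scale * (c - u) for c, u in zip(pred_cond, pred_uncond)]


def to_unit_range(pixels, norm_range=(-1.0, 1.0)):
    """Clamp values to norm_range, then rescale them linearly to [0, 1]."""
    lo, hi = norm_range
    return [(min(max(p, lo), hi) - lo) / (hi - lo) for p in pixels]


pred_cond = [1.0, 2.0]
pred_uncond = [0.0, 1.0]
guided = guided_prediction(pred_cond, pred_uncond, guidance_scale=4.0)
display = to_unit_range([-1.0, 0.0, 1.0, 2.0])
```

A scale of 1 reproduces the conditional prediction and a scale of 0 the unconditional one; larger values such as the defaults of 4.0 and 8.0 push further in the conditional direction.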

class torchdiff.unclip.UpsamplerUnCLIP(*args: Any, **kwargs: Any)[source]#

Bases: Module

Diffusion-based upsampler for UnCLIP models.

A U-Net-like model that denoises a noisy high-resolution image, conditioned on a low-resolution image and the diffusion timestep, using residual blocks, downsampling, and upsampling layers.

Parameters:
  • fwd_unclip (nn.Module) – Forward diffusion module (e.g., ForwardUnCLIP) for adding noise during training.

  • rwd_unclip (nn.Module) – Reverse diffusion module (e.g., ReverseUnCLIP) for removing noise during sampling.

  • in_channels (int, optional) – Number of input channels (default: 3, for RGB images).

  • out_channels (int, optional) – Number of output channels (default: 3, for RGB noise prediction).

  • model_channels (int, optional) – Base number of channels in the model (default: 192).

  • num_res_blocks (int, optional) – Number of residual blocks per resolution level (default: 2).

  • channel_mult (Tuple[int, ...], optional) – Channel multiplier for each resolution level (default: (1, 2, 4, 8)).

  • dropout (float, optional) – Dropout probability for regularization (default: 0.1).

  • time_embed_dim (int, optional) – Dimensionality of time embeddings (default: 768).

  • low_res_size (int, optional) – Spatial size of low-resolution input (default: 64).

  • high_res_size (int, optional) – Spatial size of high-resolution output (default: 256).

forward(x_high: torch.Tensor, t: torch.Tensor, x_low: torch.Tensor) torch.Tensor[source]#

Predicts noise for the upsampling process.

Processes a noisy high-resolution image and a low-resolution conditioning image, conditioned on timesteps, to predict the noise component for denoising.

Parameters:
  • x_high (torch.Tensor) – Noisy high-resolution image, shape (batch_size, in_channels, high_res_size, high_res_size).

  • t (torch.Tensor) – Timestep indices, shape (batch_size,).

  • x_low (torch.Tensor) – Low-resolution conditioning image, shape (batch_size, in_channels, low_res_size, low_res_size).

Returns:

out – Predicted noise, shape (batch_size, out_channels, high_res_size, high_res_size).

Return type:

torch.Tensor

class torchdiff.unclip.SinusoidalPositionalEmbedding(*args: Any, **kwargs: Any)[source]#

Bases: Module

Sinusoidal positional embedding for timesteps.

Generates sinusoidal embeddings for timesteps to condition the upsampler on the diffusion process stage.

Parameters:

dim (int) – Dimensionality of the embedding.

forward(t: torch.Tensor) → torch.Tensor

Generates sinusoidal embeddings for timesteps.

Parameters:

t (torch.Tensor) – Timestep indices, shape (batch_size,).

Returns:

embeddings – Sinusoidal embeddings, shape (batch_size, dim).

Return type:

torch.Tensor
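
A minimal sketch of a sinusoidal timestep embedding, in the style used by Transformer and DDPM implementations, is given below. It assumes an even dim and a max_period of 10000; both are illustrative choices, and the exact frequency layout used by SinusoidalPositionalEmbedding may differ.

```python
import math

import torch


def sinusoidal_embedding(
    t: torch.Tensor, dim: int, max_period: float = 10000.0
) -> torch.Tensor:
    """Map timestep indices of shape (batch,) to embeddings of shape (batch, dim).

    Assumes dim is even: the first half holds sines, the second half cosines.
    """
    half = dim // 2
    # Geometrically spaced frequencies from 1 down towards 1 / max_period.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=8)
print(emb.shape)
```

At t = 0 every sine entry is 0 and every cosine entry is 1, which is a quick sanity check for any implementation of this kind.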

class torchdiff.unclip.ResBlock(*args: Any, **kwargs: Any)

Bases: Module

Residual block with time embedding and conditioning.

A convolutional residual block with group normalization, time embedding conditioning, and optional scale-shift normalization, used in the UnCLIP upsampler.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • time_embed_dim (int) – Dimensionality of time embeddings.

  • dropout (float, optional) – Dropout probability (default: 0.1).

  • use_scale_shift_norm (bool, optional) – Whether to use scale-shift normalization for time embeddings (default: True).

forward(x: torch.Tensor, time_emb: torch.Tensor) → torch.Tensor

Processes input through the residual block with time conditioning.

Parameters:
  • x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).

  • time_emb (torch.Tensor) – Time embeddings, shape (batch_size, time_embed_dim).

Returns:

out – Output tensor, shape (batch_size, out_channels, height, width).

Return type:

torch.Tensor
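
The following sketch shows how a residual block with group normalization and scale-shift time conditioning is typically put together. TinyResBlock is a hypothetical stand-in written for illustration; the layer order, the SiLU activation, and the group count of 8 are assumptions and are not taken from the ResBlock source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyResBlock(nn.Module):
    """Illustrative residual block with scale-shift time conditioning."""

    def __init__(self, in_channels, out_channels, time_embed_dim,
                 dropout=0.1, groups=8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # One projection yields both the scale and the shift.
        self.time_proj = nn.Linear(time_embed_dim, 2 * out_channels)
        self.norm2 = nn.GroupNorm(groups, out_channels)
        self.drop = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # A 1x1 convolution matches channels on the skip path when needed.
        self.skip = (nn.Conv2d(in_channels, out_channels, 1)
                     if in_channels != out_channels else nn.Identity())

    def forward(self, x, time_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        proj = self.time_proj(F.silu(time_emb))[:, :, None, None]
        scale, shift = proj.chunk(2, dim=1)
        # Scale-shift normalization: modulate the normalized features.
        h = self.norm2(h) * (1 + scale) + shift
        h = self.conv2(self.drop(F.silu(h)))
        return self.skip(x) + h


block = TinyResBlock(16, 32, time_embed_dim=64)
out = block(torch.randn(2, 16, 8, 8), torch.randn(2, 64))
print(out.shape)
```

The spatial size is unchanged and only the channel count moves from in_channels to out_channels, matching the documented input and output shapes.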

class torchdiff.unclip.UpsampleBlock(*args: Any, **kwargs: Any)

Bases: Module

Upsampling block using transposed convolution.

Increases the spatial resolution of the input tensor using a transposed convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

forward(x: torch.Tensor) → torch.Tensor

Upsamples the input tensor.

Parameters:

x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).

Returns:

out – Upsampled tensor, shape (batch_size, out_channels, height*2, width*2).

Return type:

torch.Tensor
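
As a quick check of the shape contract, a transposed convolution with kernel_size=4, stride=2, and padding=1 doubles height and width exactly. These hyperparameters are an assumption chosen for illustration; UpsampleBlock may use different ones.

```python
import torch
import torch.nn as nn

# Output size: (H - 1) * stride - 2 * padding + kernel_size = 2 * H
up = nn.ConvTranspose2d(
    in_channels=32, out_channels=16, kernel_size=4, stride=2, padding=1
)

x = torch.randn(2, 32, 16, 16)
up_out = up(x)
print(up_out.shape)
```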

class torchdiff.unclip.DownsampleBlock(*args: Any, **kwargs: Any)

Bases: Module

Downsampling block using strided convolution.

Reduces the spatial resolution of the input tensor using a strided convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

forward(x: torch.Tensor) → torch.Tensor

Downsamples the input tensor.

Parameters:

x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).

Returns:

out – Downsampled tensor, shape (batch_size, out_channels, height//2, width//2).

Return type:

torch.Tensor
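
The halving behaviour can be reproduced with a strided convolution using kernel_size=3, stride=2, and padding=1. These values are assumed for illustration and may not match DownsampleBlock exactly. With this configuration the output size is floor((H - 1) / 2) + 1, which equals H // 2 only for even H.

```python
import torch
import torch.nn as nn

down = nn.Conv2d(
    in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1
)

x = torch.randn(2, 16, 16, 16)
down_out = down(x)
print(down_out.shape)
```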

class torchdiff.unclip.TrainUpsamplerUnCLIP(*args: Any, **kwargs: Any)

Bases: Module

Trainer for the UnCLIP upsampler model.

Orchestrates the training of the UnCLIP upsampler model, integrating forward diffusion, noise prediction, and low-resolution image conditioning with optional corruption (Gaussian blur or BSR degradation). Supports mixed precision, gradient accumulation, DDP, validation, early stopping, and checkpointing.

Parameters:
  • up_net (nn.Module) – The UnCLIP upsampler model (e.g., UpsamplerUnCLIP) to be trained.

  • train_loader (torch.utils.data.DataLoader) – DataLoader for training data, providing low- and high-resolution image pairs.

  • optim (torch.optim.Optimizer) – Optimizer for training the upsampler model.

  • loss_fn (Callable) – Loss function to compute the difference between predicted and target noise.

  • val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data, default None.

  • max_epochs (int, optional) – Maximum number of training epochs (default: 100).

  • device (str, optional) – Device for computation (default: CUDA).

  • store_path (str, optional) – Directory to save model checkpoints (default: “unclip_upsampler”).

  • patience (int, optional) – Number of epochs to wait for improvement before early stopping (default: 20).

  • warmup_steps (int, optional) – Number of steps for learning rate warmup (default: 10000).

  • val_freq (int, optional) – Frequency (in epochs) for validation (default: 10).

  • use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).

  • grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).

  • log_freq (int, optional) – Frequency (in epochs) for printing progress (default: 1).

  • use_comp (bool, optional) – Whether to compile the model using torch.compile (default: False).

  • norm_range (Tuple[float, float], optional) – Range for clamping output images (default: (-1.0, 1.0)).

  • norm_out (bool, optional) – Whether to normalize inputs/outputs (default: True).

  • use_autocast (bool, optional) – Whether to use automatic mixed precision training (default: True).

forward() → Dict

Trains the UnCLIP upsampler model to predict noise for denoising.

Executes the training loop, optimizing the upsampler model using low- and high-resolution image pairs, mixed precision, gradient clipping, and learning rate scheduling. Supports validation, early stopping, and checkpointing.

Returns:

losses – Dictionary containing the train and validation losses.

Return type:

Dict
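
The interaction between grad_acc and gradient clipping can be sketched with a toy model. This is not the trainer's actual loop: the linear model, the random data, and max_norm=1.0 are placeholders, and mixed precision, DDP, and the diffusion-specific steps are left out to keep the example short.

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # toy stand-in for the upsampler
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
grad_acc = 2             # micro-batches per optimizer update

# Four random (input, target) micro-batches as placeholder data.
data = [(torch.randn(8, 4), torch.randn(8, 4)) for _ in range(4)]

updates = 0
optimizer.zero_grad()
for step, (inputs, target) in enumerate(data, start=1):
    # Divide so the accumulated gradient is the mean over micro-batches.
    loss = loss_fn(model(inputs), target) / grad_acc
    loss.backward()
    if step % grad_acc == 0:
        # Clip once per update, after all micro-batch gradients are summed.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        updates += 1

print(updates)
```

With four micro-batches and grad_acc=2 the optimizer is stepped twice, so the effective batch size is grad_acc times the loader's batch size.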

static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) → torch.optim.lr_scheduler.LambdaLR

Creates a learning rate scheduler for warmup.

Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps.

Parameters:
  • optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.

  • warmup_steps (int) – Number of steps for the warmup phase.

Returns:

lr_scheduler – Learning rate scheduler for warmup.

Return type:

torch.optim.lr_scheduler.LambdaLR
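
A linear warmup of this kind can be written with LambdaLR as in the sketch below. The function name linear_warmup and the exact multiplier are assumptions made for illustration; the library's warmup_scheduler may define its factor slightly differently.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup(optimizer, warmup_steps):
    """Scale the base learning rate linearly from 0 to 1 over warmup_steps."""
    def factor(step):
        return min(1.0, step / max(1, warmup_steps))
    return LambdaLR(optimizer, lr_lambda=factor)


params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = linear_warmup(optimizer, warmup_steps=4)

lrs = []
for _ in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()   # no gradients here, so parameters stay unchanged
    scheduler.step()   # advance the warmup by one step

print(lrs)
```

The recorded learning rates start at 0, rise in equal increments, and stay at the optimizer's base value once warmup_steps is reached.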

corrupt_cond_img(x_low: torch.Tensor, corr_type: str = 'gaussian_blur') → torch.Tensor

Corrupts the low-resolution conditioning image for robustness.

Applies Gaussian blur or BSR degradation to the low-resolution image to simulate real-world degradation, as specified in the UnCLIP paper.

Parameters:
  • x_low (torch.Tensor) – Low-resolution input image, shape (batch_size, channels, low_res_size, low_res_size).

  • corr_type (str, optional) – Type of corruption to apply: “gaussian_blur” or “bsr_degradation” (default: “gaussian_blur”).

Returns:

x_degraded – Corrupted low-resolution image, same shape as input.

Return type:

torch.Tensor
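
The Gaussian blur branch can be approximated with a depthwise convolution, as in the sketch below. The kernel size, sigma, and reflect padding are illustrative assumptions, and the function gaussian_blur is a hypothetical helper rather than the code behind corrupt_cond_img. BSR degradation is not shown.

```python
import torch
import torch.nn.functional as F


def gaussian_blur(x: torch.Tensor, kernel_size: int = 5,
                  sigma: float = 1.0) -> torch.Tensor:
    """Blur each channel of x (batch, channels, H, W) independently."""
    coords = torch.arange(kernel_size, dtype=x.dtype, device=x.device)
    coords = coords - (kernel_size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()                      # 1D kernel that sums to 1
    kernel = torch.outer(g, g)           # separable 2D Gaussian
    channels = x.shape[1]
    weight = kernel.expand(channels, 1, kernel_size, kernel_size).contiguous()
    # Reflect padding keeps the output the same size as the input.
    x_pad = F.pad(x, [kernel_size // 2] * 4, mode="reflect")
    return F.conv2d(x_pad, weight, groups=channels)


x_low = torch.rand(2, 3, 64, 64)
x_degraded = gaussian_blur(x_low)
print(x_degraded.shape)
```

Because the kernel is normalized, a constant image passes through unchanged, while high-frequency detail is attenuated.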

validate() → float

Validates the UnCLIP upsampler model.

Computes the validation loss by applying forward diffusion to high-resolution images, predicting noise with the upsampler model conditioned on corrupted low-resolution images, and comparing predicted noise to ground truth.

Returns:

val_loss – Mean validation loss.

Return type:

float
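
The validation procedure described above can be sketched as follows, assuming the standard DDPM forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. The function validation_loss, the linear beta schedule, and the zero-predicting dummy model are all hypothetical stand-ins; corruption of the low-resolution image is omitted for brevity.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def validation_loss(model, batches, alpha_bar):
    """Mean noise-prediction MSE over an iterable of (x_low, x_high) pairs."""
    total, count = 0.0, 0
    for x_low, x_high in batches:
        t = torch.randint(0, alpha_bar.numel(), (x_high.shape[0],))
        noise = torch.randn_like(x_high)
        a = alpha_bar[t].view(-1, 1, 1, 1)
        # Forward diffusion: mix the clean image with Gaussian noise.
        x_noisy = a.sqrt() * x_high + (1 - a).sqrt() * noise
        pred = model(x_noisy, t, x_low)
        total += F.mse_loss(pred, noise).item()
        count += 1
    return total / max(1, count)


# Assumed linear beta schedule with 1000 steps.
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)


def dummy_model(x_high, t, x_low):
    # Placeholder network that always predicts zero noise.
    return torch.zeros_like(x_high)


batches = [(torch.randn(2, 3, 8, 8), torch.randn(2, 3, 32, 32))]
val_loss = validation_loss(dummy_model, batches, alpha_bar)
print(val_loss)
```

A model that always predicts zero gives a loss close to 1, the variance of the target noise, which is a useful baseline when reading validation curves.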

load_checkpoint(check_path: str) → Tuple[int, float]

Loads model checkpoint.

Restores the state of the upsampler model, its variance scheduler, optimizer, and schedulers from a saved checkpoint, handling DDP compatibility.

Parameters:

check_path (str) – Path to the checkpoint file.

Returns:

  • epoch (int) – The epoch at which the checkpoint was saved.

  • loss (float) – The loss at the checkpoint.
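
Handling DDP compatibility usually means dealing with the "module." prefix that DistributedDataParallel adds to parameter names. The sketch below shows one way to do that. The checkpoint keys ("epoch", "loss", "model", "optim") and the helper strip_ddp_prefix are hypothetical; the actual checkpoint layout written by TrainUpsamplerUnCLIP is not documented here.

```python
import os
import tempfile

import torch
import torch.nn as nn


def strip_ddp_prefix(state_dict):
    """Remove a leading 'module.' from every key, if present."""
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Simulate a checkpoint that was saved from a DDP-wrapped model.
ddp_style_state = {"module." + k: v for k, v in model.state_dict().items()}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pth")
torch.save(
    {"epoch": 7, "loss": 0.123, "model": ddp_style_state,
     "optim": optimizer.state_dict()},
    path,
)

# Restore into a plain (non-DDP) model.
ckpt = torch.load(path, map_location="cpu")
restored = nn.Linear(4, 2)
restored.load_state_dict(strip_ddp_prefix(ckpt["model"]))
optimizer.load_state_dict(ckpt["optim"])
epoch, loss = ckpt["epoch"], ckpt["loss"]
print(epoch, loss)
```

Loading with map_location="cpu" keeps the example device-independent; a trainer would typically map to its configured device instead.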