LDM
Latent Diffusion Models (LDM)
This module provides a framework for training and sampling Latent Diffusion Models, as described in Rombach et al. (2022, “High-Resolution Image Synthesis with Latent Diffusion Models”). It supports diffusion in the latent space using a variational autoencoder (compressor model), includes utilities for training the autoencoder, noise predictor, and conditional model, and provides metrics for evaluating generated images. The framework is compatible with DDPM, DDIM, and SDE diffusion models, supporting both unconditional and conditional generation with text prompts.
Components
AutoencoderLDM: Variational autoencoder for compressing images to latent space and decoding back to image space.
TrainAE: Trainer for AutoencoderLDM, optimizing reconstruction and regularization losses with evaluation metrics.
TrainLDM: Training loop with mixed precision, warmup, and scheduling for the noise predictor and conditional model (e.g., TextEncoder with projection layers) in latent space, with image-domain evaluation metrics using a reverse diffusion model.
SampleLDM: Image generation from trained models, decoding from latent to image space.
Notes
The scheduler parameter expects an external hyperparameter module (e.g., SchedulerDDPM, SchedulerSDE) as an nn.Module for noise schedule management.
AutoencoderLDM serves as the comp_net in TrainLDM and SampleLDM, providing encode and decode methods for latent space conversion. It supports KL-divergence or vector quantization (VQ) regularization, using internal components (DownBlock, UpBlock, Conv3, DownSampling, UpSampling, Attention, VectorQuantizer).
TrainAE trains AutoencoderLDM, optimizing reconstruction (MSE), regularization (KL or VQ), and optional perceptual (LPIPS) losses, with metrics (MSE, PSNR, SSIM, FID, LPIPS) computed via the Metrics class, KL warmup, early stopping, and learning rate scheduling.
TrainLDM trains the noise predictor and conditional model, optimizing MSE between predicted and ground truth noise, with optional validation metrics (MSE, PSNR, SSIM, FID, LPIPS) on generated images decoded from latents sampled using a reverse diffusion model (e.g., ReverseDDPM).
SampleLDM supports multiple diffusion models (“ddpm”, “ddim”, “sde”) via the model parameter, requiring compatible reverse_diffusion modules (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
References
Rombach, Robin, et al. “High-resolution image synthesis with latent diffusion models.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
Esser, Patrick, Robin Rombach, and Björn Ommer. “Taming transformers for high-resolution image synthesis.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
- class torchdiff.ldm.TrainLDM(*args: Any, **kwargs: Any)[source]#
Bases: Module
Trainer for the noise/score/v predictor in Latent Diffusion Models.
Optimizes the noise predictor and conditional model (e.g., TextEncoder) to predict noise in the latent space of AutoencoderLDM, using a diffusion model (e.g., DDPM, DDIM, SDE). Supports mixed precision, conditional generation with text prompts, and evaluation metrics (MSE, PSNR, SSIM, FID, LPIPS) for generated images during validation, using a specified reverse diffusion model.
- Parameters:
diff_type (str) – Diffusion model type (“ddpm”, “ddim”, “sde”).
fwd_diff (ForwardDDPM, ForwardDDIM, or ForwardSDE) – Forward diffusion model defining the noise schedule.
rwd_diff (ReverseDDPM, ReverseDDIM, or ReverseSDE) – Reverse diffusion model for sampling during validation (default: None).
diff_net (torch.nn.Module) – Model to predict noise/score/v in the latent space (e.g., DiffusionNetwork).
comp_net (torch.nn.Module) – Variational autoencoder for encoding/decoding latents.
optim (torch.optim.Optimizer) – Optimizer for the noise predictor and conditional model (e.g., Adam).
loss_fn (Callable) – Loss function for noise prediction (e.g., MSELoss).
train_loader (torch.utils.data.DataLoader) – DataLoader for training data.
val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data (default: None).
cond_net (TextEncoder, optional) – Text encoder with projection layers for conditional generation (default: None).
metrics (object, optional) – Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
max_epochs (int, optional) – Maximum number of training epochs (default: 100).
device (str, optional) – Device for computation (e.g., ‘cuda’, ‘cpu’) (default: ‘cuda’).
store_path (str, optional) – Path to save model checkpoints (default: None, uses ‘ldm_train’).
patience (int, optional) – Number of epochs to wait for early stopping if validation loss doesn’t improve (default: 20).
warmup_steps (int, optional) – Number of steps for learning rate warmup (default: 1000).
tokenizer (BertTokenizer, optional) – Tokenizer for processing text prompts, default None (loads “bert-base-uncased”).
max_token_length (int, optional) – Maximum sequence length for tokenized text (default: 77).
val_freq (int, optional) – Frequency (in epochs) for validation and metric computation (default: 10).
norm_range (tuple, optional) – Range for clamping generated images (default: (-1, 1)).
norm_output (bool, optional) – Whether to normalize generated images to [0, 1] for metrics (default: True).
use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).
grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).
log_freq (int, optional) – Number of epochs before printing loss.
use_comp (bool, optional) – Whether the model is compiled internally with torch.compile (default: False).
time_eps (float, optional) – Lower bound for diffusion time sampling; times are drawn from (time_eps, 1.0) (default: 1e-5).
num_steps (int, optional) – Number of time steps for sampling during validation (default: 400).
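The grad_acc parameter describes gradient accumulation. A minimal sketch of that pattern in plain PyTorch follows; the toy model, data, and loop are illustrative assumptions and not TrainLDM's actual training loop.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)                      # toy stand-in for the noise predictor
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
grad_acc = 4                                 # micro-batches per optimizer update
batches = [(torch.randn(2, 4), torch.randn(2, 1)) for _ in range(8)]

num_updates = 0
optimizer.zero_grad()
for i, (x, y) in enumerate(batches, start=1):
    # Divide by grad_acc so the summed gradients average over the window.
    loss = nn.functional.mse_loss(model(x), y) / grad_acc
    loss.backward()                          # gradients add up across micro-batches
    if i % grad_acc == 0:
        optimizer.step()
        optimizer.zero_grad()
        num_updates += 1
```

With 8 micro-batches and grad_acc=4, the optimizer is stepped twice, which emulates a batch four times larger at the same memory cost.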
- load_checkpoint(checkpoint_path: str) Tuple[int, float][source]#
Loads a training checkpoint to resume training.
Restores the state of the noise predictor, conditional model (if applicable), and optimizer from a saved checkpoint. Handles DDP model state dict loading.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file.
- Returns:
epoch (int) – The epoch at which the checkpoint was saved.
loss (float) – The loss at the checkpoint.
- static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) torch.optim.lr_scheduler.LambdaLR[source]#
Creates a learning rate scheduler for warmup.
Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps, then holds it constant.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.
warmup_steps (int) – Number of steps for the warmup phase.
- Returns:
Learning rate scheduler for warmup.
- Return type:
torch.optim.lr_scheduler.LambdaLR
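A linear warmup of this kind can be sketched with torch.optim.lr_scheduler.LambdaLR as below. The helper name linear_warmup and the handling of warmup_steps <= 0 are assumptions made for illustration; the library's own implementation may differ.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def linear_warmup(optimizer: torch.optim.Optimizer, warmup_steps: int) -> LambdaLR:
    """Scale the LR linearly from 0 to its base value, then hold it constant."""
    def lr_lambda(step: int) -> float:
        if warmup_steps <= 0:
            return 1.0
        return min(1.0, step / warmup_steps)
    return LambdaLR(optimizer, lr_lambda)

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = linear_warmup(optimizer, warmup_steps=4)

lrs = []
for _ in range(6):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()       # step the optimizer before the scheduler
    scheduler.step()
```

The recorded learning rates rise in equal increments for the first four steps and then stay at the base value of 1e-3.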
- forward() Dict[source]#
Trains the noise/score/v/x0 predictor and conditional model with mixed precision and evaluation metrics.
Optimizes the noise predictor and conditional model (e.g., TextEncoder with projection layers) using the forward diffusion model’s noise schedule, with text conditioning. Performs validation with image-domain metrics (MSE, PSNR, SSIM, FID, LPIPS) using the reverse diffusion model, saves checkpoints for the best validation loss, and supports early stopping.
- Returns:
losses – Dictionary of train and validation losses.
- Return type:
Dict
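To make the objective concrete, here is a minimal sketch of one noise-prediction step in latent space for the DDPM case. The strided-convolution encoder, the single-layer noise predictor, the linear beta schedule, and the frozen compressor are all assumptions chosen for brevity; they stand in for comp_net, diff_net, and fwd_diff and do not reproduce TrainLDM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for comp_net.encode and diff_net.
encoder = nn.Conv2d(3, 4, kernel_size=4, stride=4)        # image -> latent (4x downsampling)
noise_net = nn.Conv2d(4, 4, kernel_size=3, padding=1)     # ignores the time step for brevity

num_steps = 100
betas = torch.linspace(1e-4, 0.02, num_steps)             # assumed linear schedule
alpha_bar = torch.cumprod(1.0 - betas, dim=0)

optimizer = torch.optim.Adam(noise_net.parameters(), lr=1e-3)
images = torch.randn(8, 3, 32, 32)

with torch.no_grad():                                     # compressor assumed frozen
    z0 = encoder(images)

t = torch.randint(0, num_steps, (z0.shape[0],))           # one time step per sample
eps = torch.randn_like(z0)                                # ground-truth noise
a = alpha_bar[t].view(-1, 1, 1, 1)
z_t = a.sqrt() * z0 + (1.0 - a).sqrt() * eps              # forward diffusion in latent space

loss = nn.functional.mse_loss(noise_net(z_t), eps)        # MSE between predicted and true noise
optimizer.zero_grad()
loss.backward()
optimizer.step()
```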
- validate() Tuple[float, float, float, float, float, float][source]#
Validates the noise predictor and computes evaluation metrics.
Computes validation loss (MSE between predicted and ground truth noise) and generates samples using the reverse diffusion model. Evaluates image quality metrics if available.
- Returns:
(val_loss, fid, mse, psnr, ssim, lpips_score) where metrics may be None if not computed.
- Return type:
tuple
- class torchdiff.ldm.SampleLDM(*args: Any, **kwargs: Any)[source]#
Bases: Module
Sampler for generating images using Latent Diffusion Models (LDM).
Generates images by iteratively denoising random noise in the latent space using a reverse diffusion process, decoding the result back to the image space with a pre-trained compressor, as described in Rombach et al. (2022). Supports DDPM, DDIM, and SDE diffusion models, as well as conditional generation with text prompts.
- Parameters:
diff_type (str) – Diffusion model type. Supported: “ddpm”, “ddim”, “sde”.
rwd_diff (nn.Module) – Reverse diffusion module (e.g., ReverseDDPM, ReverseDDIM, ReverseSDE).
diff_net (nn.Module) – Model to predict noise added during the forward diffusion process.
comp_net (nn.Module) – Pre-trained model to encode/decode between image and latent spaces (e.g., AutoencoderLDM).
img_size (tuple) – Shape of generated images as (height, width).
cond_net (nn.Module, optional) – Model for conditional generation (e.g., TextEncoder), default None.
tokenizer (str or BertTokenizer, optional) – Tokenizer for processing text prompts, default “bert-base-uncased”.
batch_size (int, optional) – Number of images to generate per batch (default: 1).
in_channels (int, optional) – Number of input channels for latent representations (default: 3).
device (str) – Device for computation (default: CUDA).
max_token_length (int, optional) – Maximum length for tokenized prompts (default: 77).
norm_range (tuple, optional) – Range for clamping generated images (min, max), default (-1, 1).
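The overall sampling flow (start from Gaussian noise in latent space, denoise step by step, then decode) can be sketched as follows for DDPM-style ancestral sampling. The untrained convolution used as noise predictor, the transposed-convolution decoder, and the linear beta schedule are hypothetical placeholders for diff_net, comp_net, and rwd_diff.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
noise_net = nn.Conv2d(4, 4, kernel_size=3, padding=1)         # hypothetical, untrained
decoder = nn.ConvTranspose2d(4, 3, kernel_size=4, stride=4)   # hypothetical latent -> image

num_steps = 50
betas = torch.linspace(1e-4, 0.02, num_steps)                 # assumed linear schedule
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(batch_size: int = 2, latent_hw: int = 8) -> torch.Tensor:
    z = torch.randn(batch_size, 4, latent_hw, latent_hw)      # pure noise in latent space
    for t in reversed(range(num_steps)):
        eps = noise_net(z)                                     # predicted noise
        mean = (z - betas[t] / (1.0 - alpha_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        noise = torch.randn_like(z) if t > 0 else torch.zeros_like(z)
        z = mean + betas[t].sqrt() * noise                     # one reverse step
    return decoder(z).clamp(-1.0, 1.0)                         # decode, clamp to norm_range

images = sample()
```

Because the networks here are untrained, the output is noise; the sketch only shows the order of operations and the resulting tensor shapes.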
- tokenize(prompts: List | str)[source]#
Tokenizes text prompts for conditional generation.
Converts input prompts into tokenized tensors using the specified tokenizer.
- Parameters:
prompts (str or list) – Text prompt(s) for conditional generation. Can be a single string or a list of strings.
- Returns:
input_ids (torch.Tensor) – Tokenized input IDs, shape (batch_size, max_length).
attention_mask (torch.Tensor) – Attention mask, shape (batch_size, max_length).
- forward(conds: str | List | None = None, norm_output: bool = True, save_imgs: bool = True, save_path: str = 'ldm_samples') torch.Tensor[source]#
Generates images using the reverse diffusion process in the latent space.
Iteratively denoises random noise in the latent space using the specified reverse diffusion model (DDPM, DDIM, SDE), then decodes the result to the image space with the compressor model. Supports conditional generation with text prompts.
- Parameters:
conds (str or list, optional) – Text prompt(s) for conditional generation, default None.
norm_output (bool, optional) – If True, normalizes output images to [0, 1] (default: True).
save_imgs (bool, optional) – If True, saves generated images to save_path (default: True).
save_path (str, optional) – Directory to save generated images (default: “ldm_samples”).
- Returns:
samps (torch.Tensor) - Generated images, shape (batch_size, channels, height, width).
If norm_output is True, images are normalized to [0, 1]; otherwise, they are clamped to norm_range.
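One plausible reading of the norm_output and norm_range behavior is sketched below: clamp to norm_range, then optionally rescale linearly to [0, 1]. The helper postprocess is hypothetical and the exact order of operations in SampleLDM is an assumption.

```python
import torch

def postprocess(x: torch.Tensor, norm_range=(-1.0, 1.0), norm_output: bool = True) -> torch.Tensor:
    lo, hi = norm_range
    x = x.clamp(lo, hi)                 # always clamp to the expected range
    if norm_output:
        x = (x - lo) / (hi - lo)        # map [lo, hi] linearly onto [0, 1]
    return x

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 3.0])
normalized = postprocess(x)                      # values in [0, 1]
clamped = postprocess(x, norm_output=False)      # values in [-1, 1]
```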
- to(device: torch.device) Self[source]#
Moves the module and its components to the specified device.
- Parameters:
device (torch.device) – Target device for computation.
- Returns:
sample (SampleLDM) – The module moved to the specified device.
- class torchdiff.ldm.AutoencoderLDM(*args: Any, **kwargs: Any)[source]#
Bases: Module
Variational autoencoder for latent space compression in Latent Diffusion Models.
Encodes images into a latent space and decodes them back to the image space; used as comp_net in TrainLDM and SampleLDM. Supports KL-divergence or vector quantization (VQ) regularization for the latent representation.
- Parameters:
in_channels (int) – Number of input channels (e.g., 3 for RGB images).
down_channels (list) – List of channel sizes for encoder downsampling blocks (e.g., [32, 64, 128, 256]).
up_channels (list) – List of channel sizes for decoder upsampling blocks (e.g., [256, 128, 64, 16]).
out_channels (int) – Number of output channels, typically equal to in_channels.
dropout_rate (float) – Dropout rate for regularization in convolutional and attention layers.
num_heads (int) – Number of attention heads in self-attention layers.
num_groups (int) – Number of groups for group normalization in attention layers.
num_layers_per_block (int) – Number of convolutional layers in each downsampling and upsampling block.
total_down_sampling_factor (int) – Total downsampling factor across the encoder (e.g., 8 for 8x reduction).
latent_channels (int) – Number of channels in the latent representation for diffusion models.
num_embeddings (int) – Number of discrete embeddings in the VQ codebook (if use_vq=True).
use_vq (bool, optional) – If True, uses vector quantization (VQ) regularization; otherwise, uses KL-divergence (default: False).
beta (float, optional) – Weight for KL-divergence loss (if use_vq=False) (default: 1.0).
use_flash (bool, optional) – If True and available, flash attention is used to improve training efficiency (default: True).
use_grad_check (bool, optional) – If True, gradient checkpointing is used (default: False).
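The relationship between image size, total_down_sampling_factor, and latent_channels can be worked out with simple arithmetic. The helper latent_shape below is hypothetical, and rejecting sizes that are not divisible by the factor is an assumption of this sketch rather than documented library behavior.

```python
def latent_shape(batch_size: int, latent_channels: int, height: int, width: int,
                 total_down_sampling_factor: int) -> tuple:
    f = total_down_sampling_factor
    if height % f or width % f:
        raise ValueError("height and width should be divisible by the downsampling factor")
    return (batch_size, latent_channels, height // f, width // f)

# A batch of four 256x256 images with an 8x encoder and 4 latent channels.
shape = latent_shape(4, 4, 256, 256, 8)
```

Here the latent tensor has shape (4, 4, 32, 32), so the diffusion model operates on 64 times fewer spatial positions than the image.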
- reparameterize(mu: torch.Tensor, logvar: torch.Tensor) torch.Tensor[source]#
Applies reparameterization trick for variational autoencoding.
Samples from a Gaussian distribution using the mean and log-variance to enable differentiable training.
- Parameters:
mu (torch.Tensor) – Mean of the latent distribution, shape (batch_size, channels, height, width).
logvar (torch.Tensor) – Log-variance of the latent distribution, same shape as mu.
- Return type:
reparam (torch.Tensor) - Sampled latent representation, same shape as mu.
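The reparameterization trick itself is standard and can be sketched in a few lines: sample standard Gaussian noise, then shift and scale it so that gradients flow through mu and logvar. This is a generic version, not a copy of the method above.

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)       # log-variance -> standard deviation
    eps = torch.randn_like(std)         # noise carries no gradient
    return mu + eps * std               # differentiable w.r.t. mu and logvar

torch.manual_seed(0)
mu = torch.zeros(2, 4, 8, 8, requires_grad=True)
logvar = torch.full((2, 4, 8, 8), -2.0, requires_grad=True)
z = reparameterize(mu, logvar)
z.sum().backward()                      # gradients reach both distribution parameters
```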
- encode(x: torch.Tensor) Tuple[torch.Tensor, float][source]#
Encodes images into a latent representation.
Processes input images through the encoder, applying convolutions, downsampling, self-attention, and latent projection (VQ or KL-based).
- Parameters:
x (torch.Tensor) – Input images, shape (batch_size, in_channels, height, width).
- Returns:
z (torch.Tensor) – Latent representation, shape (batch_size, latent_channels, height/down_sampling_factor, width/down_sampling_factor).
reg_loss (float) – Regularization loss (VQ loss if use_vq=True, KL-divergence loss if use_vq=False).
**Notes**
- The VQ loss is computed by VectorQuantizer if use_vq=True.
- The KL-divergence loss is normalized by batch size and latent size, weighted by current_beta.
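For the KL case, the closed-form divergence between a diagonal Gaussian and the standard normal can be sketched as below. Dividing by the total number of latent elements is one reading of "normalized by batch size and latent size"; the function name kl_loss and that normalization are assumptions.

```python
import torch

def kl_loss(mu: torch.Tensor, logvar: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    # KL(N(mu, sigma^2) || N(0, 1)) summed over all elements.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # Normalize by batch size times latent size, then apply the weight.
    return beta * kl / mu.numel()

zero = kl_loss(torch.zeros(2, 4, 8, 8), torch.zeros(2, 4, 8, 8))
shifted = kl_loss(torch.ones(2, 4, 8, 8), torch.zeros(2, 4, 8, 8))
```

A latent that already matches the standard normal gives a loss of 0, while shifting every mean to 1 with unit variance gives 0.5 per element.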
- decode(z: torch.Tensor) torch.Tensor[source]#
Decodes latent representations back to images.
Processes latent representations through the decoder, applying convolutions, self-attention, upsampling, and final reconstruction.
- Parameters:
z (torch.Tensor) – Latent representation, shape (batch_size, latent_channels, height/down_sampling_factor, width/down_sampling_factor).
- Return type:
x (torch.Tensor) - Reconstructed images, shape (batch_size, out_channels, height, width).
- forward(x: torch.Tensor) Tuple[torch.Tensor, float, float, torch.Tensor][source]#
Encodes images to latent space and decodes them, computing reconstruction and regularization losses.
Performs a full autoencoding pass, encoding images to the latent space, decoding them back, and calculating MSE reconstruction loss and regularization loss (VQ or KL-based).
- Parameters:
x (torch.Tensor) – Input images, shape (batch_size, in_channels, height, width).
- Returns:
x_hat (torch.Tensor) – Reconstructed images, shape (batch_size, out_channels, height, width).
total_loss (float) – Sum of reconstruction (MSE) and regularization losses.
reg_loss (float) – Regularization loss (VQ or KL-divergence).
z (torch.Tensor) – Latent representation, shape (batch_size, latent_channels, height/down_sampling_factor, width/down_sampling_factor).
**Notes**
The reconstruction loss is computed as the mean squared error between x_hat and x.
The regularization loss depends on use_vq (VQ loss or KL-divergence).
- class torchdiff.ldm.VectorQuantizer(*args: Any, **kwargs: Any)[source]#
Bases: Module
Vector quantization layer for discretizing latent representations.
Quantizes input latent vectors to the nearest embedding in a learned codebook, used in AutoencoderLDM when use_vq=True to enable discrete latent spaces for Latent Diffusion Models. Computes commitment and codebook losses to train the codebook embeddings.
- Parameters:
num_embed (int) – Number of discrete embeddings in the codebook.
embed_dim (int) – Dimensionality of each embedding vector (matches input channel dimension).
commit_cost (float, optional) – Weight for the commitment loss, encouraging inputs to be close to quantized values (default: 0.25).
**Notes**
- The codebook embeddings are initialized uniformly in the range [-1/num_embeddings, 1/num_embeddings].
- The forward pass flattens input latents, computes Euclidean distances to codebook embeddings, and selects the nearest embedding for quantization.
- The commitment loss encourages input latents to be close to their quantized versions, while the codebook loss updates embeddings to match inputs.
- A straight-through estimator is used to pass gradients from the quantized output to the input.
- forward(z: torch.Tensor) Tuple[torch.Tensor, torch.Tensor][source]#
Quantizes latent representations to the nearest codebook embedding.
Computes the closest embedding for each input vector, applies quantization, and calculates commitment and codebook losses for training.
- Parameters:
z (torch.Tensor) – Input latent representation, shape (batch_size, embedding_dim, height, width).
- Returns:
quantized (torch.Tensor) – Quantized latent representation, same shape as z.
vq_loss (torch.Tensor) – Sum of commitment and codebook losses.
**Notes**
- The input is flattened to (batch_size * height * width, embedding_dim) for distance computation.
- Euclidean distances are computed efficiently using vectorized operations.
- The commitment loss is scaled by commit_cost, and the total VQ loss combines commitment and codebook losses.
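The quantization steps described in these notes can be sketched as a small module. TinyVQ is a hypothetical name, and details such as using torch.cdist for the distances are choices made for this illustration, not necessarily what VectorQuantizer does internally.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVQ(nn.Module):
    def __init__(self, num_embed: int, embed_dim: int, commit_cost: float = 0.25):
        super().__init__()
        self.codebook = nn.Embedding(num_embed, embed_dim)
        self.codebook.weight.data.uniform_(-1.0 / num_embed, 1.0 / num_embed)
        self.commit_cost = commit_cost

    def forward(self, z: torch.Tensor):
        b, c, h, w = z.shape
        flat = z.permute(0, 2, 3, 1).reshape(-1, c)        # (B*H*W, C)
        dist = torch.cdist(flat, self.codebook.weight)     # Euclidean distances to all codes
        idx = dist.argmin(dim=1)                           # nearest code per vector
        q = self.codebook(idx).view(b, h, w, c).permute(0, 3, 1, 2)
        codebook_loss = F.mse_loss(q, z.detach())          # pulls codes toward inputs
        commit_loss = F.mse_loss(z, q.detach())            # pulls inputs toward codes
        vq_loss = codebook_loss + self.commit_cost * commit_loss
        q = z + (q - z).detach()                           # straight-through estimator
        return q, vq_loss

torch.manual_seed(0)
vq = TinyVQ(num_embed=16, embed_dim=4)
z = torch.randn(2, 4, 3, 3, requires_grad=True)
quantized, vq_loss = vq(z)
```

The straight-through line makes the forward value equal to the selected code while the backward pass treats quantization as the identity, so the encoder still receives gradients.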
- class torchdiff.ldm.DownBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
Downsampling block for the encoder in AutoencoderLDM.
Applies multiple convolutional layers with residual connections followed by downsampling to reduce spatial dimensions in the encoder of the variational autoencoder used in Latent Diffusion Models.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels for convolutional layers.
num_layers (int) – Number of convolutional layer pairs (Conv3) per block.
down_sampling_factor (int) – Factor by which to downsample spatial dimensions.
dropout_rate (float) – Dropout rate for Conv3 layers.
use_grad_check (bool, optional) – If True, gradient checkpointing is used (default: False).
**Notes**
- Each layer pair consists of two Conv3 modules with a residual connection using a 1x1 convolution to match dimensions.
- The downsampling is applied after all convolutional layers, reducing spatial dimensions by down_sampling_factor.
- forward(x: torch.Tensor) torch.Tensor[source]#
Processes input through convolutional layers and downsampling.
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).
- Return type:
output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
- class torchdiff.ldm.Conv3(*args: Any, **kwargs: Any)[source]#
Bases: Module
Convolutional layer with group normalization, SiLU activation, and dropout.
Used in DownBlock and UpBlock of AutoencoderLDM for feature extraction and transformation in the encoder and decoder.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
dropout_rate (float) – Dropout rate for regularization.
**Notes**
- The layer applies group normalization, SiLU activation, dropout, and a 3x3 convolution in sequence.
- Spatial dimensions are preserved due to padding=1 in the convolution.
- forward(x: torch.Tensor) torch.Tensor[source]#
Processes input through group normalization, activation, dropout, and convolution.
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).
- Return type:
x (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height, width).
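The layer order in the notes (normalization, activation, dropout, convolution) maps directly onto standard PyTorch modules. The sketch below assumes 8 normalization groups, which is not stated in the parameter list, and is not the library's Conv3 class.

```python
import torch
import torch.nn as nn

in_channels, out_channels, dropout_rate = 32, 64, 0.1

conv3_like = nn.Sequential(
    nn.GroupNorm(num_groups=8, num_channels=in_channels),   # group count is an assumption
    nn.SiLU(),
    nn.Dropout(p=dropout_rate),
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),  # padding=1 keeps H and W
)

x = torch.randn(2, in_channels, 16, 16)
y = conv3_like(x)
```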
- class torchdiff.ldm.DownSampling(*args: Any, **kwargs: Any)[source]#
Bases: Module
Downsampling module for reducing spatial dimensions in AutoencoderLDM’s encoder.
Combines convolutional downsampling and max pooling, concatenating their outputs to preserve feature information during downsampling in DownBlock.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels (sum of conv and pool paths).
down_sampling_factor (int) – Factor by which to downsample spatial dimensions.
**Notes**
- The module splits the output channels evenly between convolutional and pooling paths, concatenating them along the channel dimension.
- The convolutional path uses a stride equal to down_sampling_factor, while the pooling path uses max pooling with the same factor.
- forward(x: torch.Tensor) torch.Tensor[source]#
Downsamples input by combining convolutional and pooling paths.
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).
- Return type:
x (torch.Tensor) - Downsampled tensor, shape (batch_size, out_channels, height/down_sampling_factor, width/down_sampling_factor).
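A two-path downsampler of this kind can be sketched as follows. TinyDownSampling is a hypothetical name, and the 1x1 convolution after max pooling (needed to change the channel count on the pooling path) and the 3x3 kernel are assumptions of this sketch.

```python
import torch
import torch.nn as nn

class TinyDownSampling(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        half = out_channels // 2
        # Path 1: strided convolution.
        self.conv = nn.Conv2d(in_channels, half, kernel_size=3, stride=factor, padding=1)
        # Path 2: max pooling, then a 1x1 conv to set the channel count (assumed).
        self.pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=factor, stride=factor),
            nn.Conv2d(in_channels, out_channels - half, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cat([self.conv(x), self.pool(x)], dim=1)   # concatenate along channels

down = TinyDownSampling(in_channels=32, out_channels=64, factor=2)
y = down(torch.randn(2, 32, 16, 16))
```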
- class torchdiff.ldm.Attention(*args: Any, **kwargs: Any)[source]#
Bases: Module
Self-attention module for feature enhancement in AutoencoderLDM.
Applies multi-head self-attention to enhance features in the encoder and decoder, used after downsampling (in DownBlock) and before upsampling (in UpBlock).
- Parameters:
num_channels (int) – Number of input and output channels (embedding dimension for attention).
num_heads (int) – Number of attention heads.
num_groups (int) – Number of groups for group normalization.
dropout_rate (float) – Dropout rate for attention outputs.
use_flash (bool, optional) – If True and available, flash attention is used to improve training efficiency (default: True).
**Notes**
- The input is reshaped to (batch_size, height * width, num_channels) for attention processing, then restored to (batch_size, num_channels, height, width).
- Group normalization is applied before attention to stabilize training.
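The reshape-attend-restore pattern for feature maps can be sketched with nn.MultiheadAttention. TinyAttention is a hypothetical name, and the residual connection at the end is an assumption that the documentation does not confirm.

```python
import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    def __init__(self, num_channels: int, num_heads: int, num_groups: int):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, num_channels)
        self.attn = nn.MultiheadAttention(num_channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token.
        tokens = self.norm(x).reshape(b, c, h * w).transpose(1, 2)
        out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        # (B, H*W, C) -> (B, C, H, W)
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return x + out                                   # residual connection (assumed)

attn = TinyAttention(num_channels=32, num_heads=4, num_groups=4)
x = torch.randn(2, 32, 8, 8)
y = attn(x)
```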
- class torchdiff.ldm.UpBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
Upsampling block for the decoder in AutoencoderLDM.
Applies upsampling followed by multiple convolutional layers with residual connections to increase spatial dimensions in the decoder of the variational autoencoder used in Latent Diffusion Models.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels for convolutional layers.
num_layers (int) – Number of convolutional layer pairs (Conv3) per block.
up_sampling_factor (int) – Factor by which to upsample spatial dimensions.
dropout_rate (float) – Dropout rate for Conv3 layers.
use_grad_check (bool, optional) – If True, gradient checkpointing is used (default: False).
**Notes**
- Upsampling is applied first, followed by convolutional layer pairs with residual connections using 1x1 convolutions.
- Each layer pair consists of two Conv3 modules.
- forward(x: torch.Tensor) torch.Tensor[source]#
Processes input through upsampling and convolutional layers.
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).
- Return type:
output (torch.Tensor) - Output tensor, shape (batch_size, out_channels, height * up_sampling_factor, width * up_sampling_factor).
- class torchdiff.ldm.UpSampling(*args: Any, **kwargs: Any)[source]#
Bases: Module
Upsampling module for increasing spatial dimensions in AutoencoderLDM’s decoder.
Combines transposed convolution and nearest-neighbor upsampling, concatenating their outputs to preserve feature information during upsampling in UpBlock.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels (sum of conv and upsample paths).
up_sampling_factor (int) – Factor by which to upsample spatial dimensions.
**Notes**
- The module splits the output channels evenly between transposed convolution and upsampling paths, concatenating them along the channel dimension.
- If the spatial dimensions of the two paths differ, the upsampling path is interpolated to match the convolutional path’s size.
- forward(x: torch.Tensor) torch.Tensor[source]#
Upsamples input by combining transposed convolution and upsampling paths.
- Parameters:
x (torch.Tensor) – Input tensor, shape (batch_size, in_channels, height, width).
- Returns:
x (torch.Tensor) - Upsampled tensor, shape (batch_size, out_channels, height * up_sampling_factor, width * up_sampling_factor).
**Notes**
- Interpolation is applied if the spatial dimensions of the convolutional and upsampling paths differ, using nearest-neighbor mode.
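The two-path upsampling with a size check can be sketched as below. TinyUpSampling is a hypothetical name, and the 1x1 convolution on the nearest-neighbor path and the transposed-convolution kernel size are assumptions of this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyUpSampling(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, factor: int):
        super().__init__()
        half = out_channels // 2
        # Path 1: learned upsampling via transposed convolution.
        self.conv = nn.ConvTranspose2d(in_channels, half, kernel_size=factor, stride=factor)
        # Path 2: nearest-neighbor upsampling, then a 1x1 conv for the channel count (assumed).
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=factor, mode="nearest"),
            nn.Conv2d(in_channels, out_channels - half, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a, b = self.conv(x), self.up(x)
        if a.shape[-2:] != b.shape[-2:]:
            # Match the upsampling path to the convolutional path's size.
            b = F.interpolate(b, size=a.shape[-2:], mode="nearest")
        return torch.cat([a, b], dim=1)

up = TinyUpSampling(in_channels=64, out_channels=32, factor=2)
y = up(torch.randn(2, 64, 8, 8))
```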
- class torchdiff.ldm.TrainAE(*args: Any, **kwargs: Any)[source]#
Bases: Module
Trainer for the AutoencoderLDM variational autoencoder in Latent Diffusion Models.
Optimizes the AutoencoderLDM model to compress images into latent space and reconstruct them, using reconstruction loss (MSE), regularization (KL or VQ), and optional perceptual loss (LPIPS). Supports mixed precision, KL warmup, early stopping, and learning rate scheduling, with evaluation metrics (MSE, PSNR, SSIM, FID, LPIPS).
- Parameters:
model (nn.Module) – The variational autoencoder model (AutoencoderLDM) to train.
optim (torch.optim.Optimizer) – Optimizer for training (e.g., Adam).
train_loader (torch.utils.data.DataLoader) – DataLoader for training data.
val_loader (torch.utils.data.DataLoader, optional) – DataLoader for validation data (default: None).
max_epochs (int, optional) – Maximum number of training epochs (default: 100).
metrics (object, optional) – Metrics object for computing MSE, PSNR, SSIM, FID, and LPIPS (default: None).
device (str) – Device for computation (e.g., ‘cuda’, ‘cpu’).
store_path (str, optional) – Path to save model checkpoints (default: ‘vlc_model.pth’).
checkpoint (int, optional) – Frequency (in epochs) to save model checkpoints (default: 10).
kl_warmup_epochs (int, optional) – Number of epochs for KL loss warmup (default: 10).
patience (int, optional) – Number of epochs to wait for early stopping if validation loss doesn’t improve (default: 10).
val_freq (int, optional) – Frequency (in epochs) for validation and metric computation (default: 5).
warmup_steps (int, optional) – Number of learning rate warmup steps (default: 1000).
use_ddp (bool, optional) – Whether to use Distributed Data Parallel training (default: False).
grad_acc (int, optional) – Number of gradient accumulation steps before optimizer update (default: 1).
log_freq (int, optional) – Number of epochs before printing loss.
use_comp (bool, optional) – if true, model is compiled (default: False)
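KL warmup gradually raises the weight of the KL term so that early training focuses on reconstruction. A linear ramp over kl_warmup_epochs is sketched below; the helper kl_beta and the linear shape of the ramp are assumptions, since the documentation does not specify the schedule.

```python
def kl_beta(epoch: int, kl_warmup_epochs: int, beta: float = 1.0) -> float:
    """Return the KL weight for a zero-based epoch index."""
    if kl_warmup_epochs <= 0:
        return beta
    return beta * min(1.0, (epoch + 1) / kl_warmup_epochs)

# Weights for the first 12 epochs with a 10-epoch warmup.
schedule = [kl_beta(e, kl_warmup_epochs=10) for e in range(12)]
```

The weight grows in equal increments during warmup and stays at beta afterwards.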
- load_checkpoint(checkpoint_path: str) Tuple[float, float][source]#
Loads a training checkpoint to resume training.
Restores the state of the autoencoder model and optimizer from a saved checkpoint.
- Parameters:
checkpoint_path (str) – Path to the checkpoint file.
- Returns:
epoch (int) – The epoch at which the checkpoint was saved.
loss (float) – The loss at the checkpoint.
- static warmup_scheduler(optimizer: torch.optim.Optimizer, warmup_steps: int) torch.optim.lr_scheduler.LambdaLR[source]#
Creates a learning rate scheduler for warmup.
Generates a scheduler that linearly increases the learning rate from 0 to the optimizer’s initial value over the specified number of warmup steps, then holds it constant.
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer to apply the scheduler to.
warmup_steps (int) – Number of steps for the warmup phase.
- Returns:
Learning rate scheduler for warmup.
- Return type:
torch.optim.lr_scheduler.LambdaLR
- forward() Dict[source]#
Trains the AutoencoderLDM model with mixed precision and evaluation metrics.
Performs training with reconstruction and regularization losses, KL warmup, gradient clipping, and learning rate scheduling. Saves checkpoints for the best validation loss and supports early stopping.
- Returns:
losses – Dictionary of train and validation losses.
- Return type:
Dict
- validate() Tuple[float, float, float, float, float, float][source]#
Validates the AutoencoderLDM model and computes evaluation metrics.
Computes validation loss and optional metrics (MSE, PSNR, SSIM, FID, LPIPS) using the provided Metrics object.
- Returns:
val_loss (float) – Mean validation loss.
fid (float, or float(‘inf’) if not computed) – Mean FID score.
mse (float, or None if not computed) – Mean MSE
psnr (float, or None if not computed) – Mean PSNR
ssim (float, or None if not computed) – Mean SSIM
lpips_score (float, or None if not computed) – Mean LPIPS score
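Of the returned metrics, PSNR follows directly from MSE. The sketch below uses the standard definition for images scaled to [0, 1]; psnr_from_mse is a hypothetical helper and says nothing about how the Metrics class computes its values.

```python
import math

def psnr_from_mse(mse: float, max_val: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for a given mean squared error."""
    if mse <= 0:
        return float("inf")            # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)

psnr = psnr_from_mse(0.01)             # MSE of 0.01 on [0, 1] images
```

An MSE of 0.01 corresponds to 20 dB, and each tenfold reduction in MSE adds another 10 dB.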