Variational Autoencoder from Scratch - Torch

This example implements a Variational Autoencoder (VAE) on binarized MNIST using PyTorch. It is based on the TensorFlow Probability example but adapted to use PyTorch 2.x.
"""Trains a variational auto-encoder (VAE) on binarized MNIST using PyTorch.

    The VAE defines a generative model in which a latent code `Z` is sampled from a
    prior `p(Z)`, then used to generate an observation `X` by way of a decoder
    `p(X|Z)`. The full reconstruction follows

    ```none
    X ~ p(X)              # A random image from some dataset.
    Z ~ q(Z | X)          # A random encoding of the original image ("encoder").
    Xhat ~ p(Xhat | Z)    # A random reconstruction of the original image
                          #   ("decoder").
    ```

    To fit the VAE, we assume an approximate representation of the posterior in the
    form of an encoder `q(Z|X)`. We minimize the KL divergence between `q(Z|X)` and
    the true posterior `p(Z|X)`: this is equivalent to maximizing the evidence lower
    bound (ELBO),

    ```none
    -log p(x)
    = -log int dz p(x|z) p(z)
    = -log int dz q(z|x) p(x|z) p(z) / q(z|x)
    <= int dz q(z|x) (-log[ p(x|z) p(z) / q(z|x) ])   # Jensen's Inequality
    =: KL[q(Z|x) || p(x|Z)p(Z)]
    = -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
    ```

    -or-

    ```none
    -log p(x)
    = KL[q(Z|x) || p(x|Z)p(Z)] - KL[q(Z|x) || p(Z|x)]
    <= KL[q(Z|x) || p(x|Z)p(Z)]                       # Positivity of KL
    = -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
    ```

    The `-E_{Z~q(Z|x)}[log p(x|Z)]` term is the expected reconstruction loss
    (the "distortion"), and `KL[q(Z|x) || p(Z)]` (the "rate") is a
    distributional regularizer that pulls the encoder's posterior toward the prior.

    This implementation supports both a standard normal prior and a
    mixture-of-Gaussians prior; setting `mixture_components` to 1 selects the
    fixed standard normal prior.

    This implementation also supports using the analytic KL (KL[q(Z|x) || p(Z)]) with the
    `analytic_kl` flag. Using the analytic KL is only supported when
    `mixture_components` is set to 1 since otherwise no analytic form is known.

    We also compute a tighter bound, the importance-weighted (IWAE) bound of
    [Burda et al. (2015)](https://arxiv.org/abs/1509.00519).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import argparse
import time
from typing import Tuple, List, Callable, Optional, Dict, Any

import numpy as np
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
!nvidia-smi
Wed Apr 30 20:37:14 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX A4500 Laptop GPU    Off |   00000000:01:00.0  On |                  Off |
| N/A   53C    P3             25W /  115W |     313MiB /  16384MiB |     23%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
# Pick the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU. Any error raised while probing CUDA is reported and
# treated the same as "no GPU".
try:
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        device = torch.device("cpu")
        print("CUDA is not available. Using CPU instead.")
    else:
        device = torch.device("cuda:0")
        print(f"Using GPU device: {torch.cuda.get_device_name(0)}")
except Exception as err:
    device = torch.device("cpu")
    print(f"Error initializing CUDA: {err}")
    print("Falling back to CPU.")

print(f"Using device: {device}")
CUDA is not available. Using CPU instead.
Using device: cpu
/workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
  return torch._C._cuda_getDeviceCount() > 0
# Define constants and hyperparameters
# Image shape as (height, width, channels); the decoder builds its output
# event shape from this list and sizes its last conv from the final entry.
IMAGE_SHAPE = [28, 28, 1]

# Command-line flags mirroring the original TensorFlow Probability example.
parser = argparse.ArgumentParser(description="VAE on binarized MNIST")
parser.add_argument(
    "--learning_rate", type=float, default=0.001, help="Initial learning rate"
)
parser.add_argument(
    "--max_steps", type=int, default=5001, help="Number of training steps to run"
)
parser.add_argument(
    "--latent_size",
    type=int,
    default=16,
    help="Number of dimensions in the latent code (z)",
)
parser.add_argument("--base_depth", type=int, default=32, help="Base depth for layers")
parser.add_argument(
    "--activation",
    type=str,
    default="leaky_relu",
    choices=("leaky_relu", "relu", "elu"),  # the names get_activation accepts
    help="Activation function",
)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument(
    "--n_samples", type=int, default=16, help="Number of samples to use in encoding"
)
parser.add_argument(
    "--mixture_components",
    type=int,
    default=100,
    help="Number of mixture components in the prior",
)
parser.add_argument(
    "--analytic_kl",
    action="store_true",
    default=False,
    help="Whether to use analytic KL",
)
parser.add_argument(
    "--data_dir",
    type=str,
    default=os.path.join("/tmp", "vae/data"),
    help="Directory for data",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default=os.path.join("/tmp", "vae"),
    help="Directory to put the model's fit",
)
parser.add_argument(
    "--viz_steps",
    type=int,
    default=500,
    help="Frequency at which to save visualizations",
)
parser.add_argument(
    "--fake_data",
    action="store_true",
    default=False,
    help="If true, uses fake data instead of MNIST",
)
parser.add_argument(
    "--delete_existing",
    action="store_true",
    default=False,
    help="If true, deletes existing model_dir directory",
)

# In a notebook sys.argv holds the kernel's own arguments, so parse an empty
# argument list: every value comes from the defaults declared above. This keeps
# the parser as the single source of truth instead of re-typing each default.
args = parser.parse_args([])
# Utility functions
def _softplus_inverse(x):
    """Helper which computes the function inverse of softplus."""
    return torch.log(torch.expm1(x))


# Helper to get activation function by name
def get_activation(name):
    """Create a new activation module from its short name.

    Supported names are ``"leaky_relu"`` (negative slope 0.2), ``"relu"`` and
    ``"elu"``; anything else raises ``ValueError``.
    """
    factories = {
        "leaky_relu": lambda: nn.LeakyReLU(0.2),
        "relu": nn.ReLU,
        "elu": nn.ELU,
    }
    for key, factory in factories.items():
        if name == key:
            return factory()
    raise ValueError(f"Unsupported activation function: {name}")


def pack_images(images, rows, cols):
    """Tile a batch of channels-last images into one mosaic.

    The last three dims of ``images`` are read as (H, W, C); all leading dims
    are flattened into a single image axis. ``rows`` is clamped to the number
    of images and ``cols`` to the number that fills complete rows; only the
    first ``rows * cols`` images are used.

    Returns:
        Tensor of shape ``(1, rows * H, cols * W, C)``, with the channel axis
        squeezed away when ``C == 1``.
    """
    img_h, img_w, channels = images.shape[-3:]
    flat = images.reshape(-1, img_h, img_w, channels)
    count = flat.shape[0]
    n_rows = min(rows, count)
    n_cols = min(cols, count // n_rows)

    tiles = flat[: n_rows * n_cols].reshape(n_rows, n_cols, img_h, img_w, channels)
    # Swap the column and pixel-row axes so that each mosaic row consists of
    # H pixel rows, each holding n_cols images side by side.
    mosaic = tiles.transpose(1, 2).reshape(1, n_rows * img_h, n_cols * img_w, channels)
    return mosaic.squeeze(3)


def show_image_grid(images, rows=8, cols=8, figsize=(10, 10), title=None):
    """Display channels-last images as a single mosaic in a new figure.

    Args:
        images: Tensor whose last three dims are (H, W, C); leading dims are
            flattened by ``pack_images``.
        rows, cols: Maximum mosaic grid size (clamped by ``pack_images``).
        figsize: Matplotlib figure size.
        title: Optional figure title.
    """
    plt.figure(figsize=figsize)
    # detach() so tensors that still require grad can be converted to numpy.
    grid = pack_images(images.detach(), rows, cols)[0].cpu().numpy()
    # pack_images already drops a singleton channel, so grid is (H, W) for
    # grayscale or (H, W, C) otherwise. RGB/RGBA are shown as-is; any other
    # channel count is rendered as its channels laid side by side in grayscale.
    if grid.ndim == 3 and grid.shape[-1] not in (3, 4):
        grid = np.concatenate([grid[..., c] for c in range(grid.shape[-1])], axis=1)

    plt.imshow(grid, cmap="gray")
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.tight_layout()
    plt.show()

Encoder

# VAE Model Components
class Encoder(nn.Module):
    """Convolutional encoder producing the approximate posterior q(z|x).

    Two stride-2 convolutions reduce a 28x28 input to 7x7, and a 7x7 valid
    convolution reduces it to 1x1. A linear layer then emits
    ``2 * latent_size`` numbers per example. The first half is the mean of a
    diagonal Normal. The second half, after a softplus shifted so that a zero
    output gives scale 1, is its standard deviation.
    """

    def __init__(self, latent_size: int, base_depth: int, activation_name: str):
        super().__init__()
        act = get_activation(activation_name)
        depth = base_depth

        self.latent_size = latent_size

        # The same activation instance is reused after every layer.
        layers = [
            nn.Conv2d(1, depth, kernel_size=5, stride=1, padding=2),
            act,
            nn.Conv2d(depth, depth, kernel_size=5, stride=2, padding=2),
            act,
            nn.Conv2d(depth, 2 * depth, kernel_size=5, stride=1, padding=2),
            act,
            nn.Conv2d(2 * depth, 2 * depth, kernel_size=5, stride=2, padding=2),
            act,
            nn.Conv2d(2 * depth, 4 * latent_size, kernel_size=7, stride=1, padding=0),
            act,
            nn.Flatten(),
            nn.Linear(4 * latent_size, 2 * latent_size),
        ]
        self.encoder_net = nn.Sequential(*layers)

    def forward(self, x):
        # Map pixel values from [0, 1] onto [-1, 1] before the convolutions.
        features = self.encoder_net(x * 2 - 1)

        # First half: mean; second half: unconstrained pre-softplus scale.
        loc, raw_scale = torch.split(features, self.latent_size, dim=-1)
        shift = _softplus_inverse(torch.tensor(1.0))
        scale = F.softplus(raw_scale + shift)

        # Treat the latent dimension as the event dimension.
        return td.Independent(td.Normal(loc=loc, scale=scale), 1)

Decoder

class Decoder(nn.Module):
    """Decoder network for VAE: maps latent codes to a Bernoulli over images.

    ``output_shape`` is read as (height, width, channels). Its last entry sizes
    the final convolution, and the returned distribution's event shape is
    ``tuple(output_shape)`` (channels-last). The transposed-conv stack
    upsamples 1x1 -> 7x7 -> 14x14 -> 28x28, so it assumes 28x28 images.
    """

    def __init__(
        self,
        latent_size: int,
        output_shape: List[int],
        base_depth: int,
        activation_name: str,
    ):
        super(Decoder, self).__init__()
        activation = get_activation(activation_name)

        self.latent_size = latent_size
        self.output_shape = output_shape

        # Define the decoder network (the same activation instance is reused).
        self.decoder_net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_size, 2 * base_depth, kernel_size=7, stride=1, padding=0
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, 2 * base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth,
                2 * base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth,
                base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.Conv2d(base_depth, output_shape[-1], kernel_size=5, stride=1, padding=2),
        )

    def forward(self, z):
        """Decode latent codes into independent Bernoulli pixels.

        Args:
            z: Latent codes of shape ``(*batch_dims, latent_size)``.

        Returns:
            ``td.Independent(td.Bernoulli)`` with batch shape ``batch_dims``
            and event shape ``tuple(self.output_shape)``.
        """
        batch_dims = tuple(z.shape[:-1])

        # Collapse all leading dims so the conv stack sees (N, latent_size, 1, 1).
        logits = self.decoder_net(z.reshape(-1, self.latent_size, 1, 1))

        # The conv output is channels-first (N, C, H, W). Move channels last to
        # match output_shape = (H, W, C) before restoring the batch dims; a bare
        # reshape would interleave channels and pixels whenever C > 1.
        logits = logits.permute(0, 2, 3, 1).reshape(
            batch_dims + tuple(self.output_shape)
        )

        return td.Independent(td.Bernoulli(logits=logits), len(self.output_shape))

Prior

class MixturePrior(nn.Module):
    """Prior p(z) over latent codes.

    * ``mixture_components == 1``: a fixed standard multivariate normal.
    * ``mixture_components > 1``: a mixture of diagonal Gaussians whose
      locations, softplus-parameterized scales and mixture logits are
      ``nn.Parameter``s.

    This class subclasses ``nn.Module`` so those parameters are registered
    with the owning model. Parameters stored on a plain object never appear
    in ``model.parameters()``, so the optimizer would never update them, and
    ``model.to(...)`` would not move them.
    """

    def __init__(self, latent_size: int, mixture_components: int, device: torch.device):
        super().__init__()
        self.latent_size = latent_size
        self.mixture_components = mixture_components
        self.device = device

        if mixture_components == 1:
            # Standard normal prior (no learnable parameters).
            self.distribution = td.MultivariateNormal(
                loc=torch.zeros(latent_size, device=device),
                covariance_matrix=torch.eye(latent_size, device=device),
            )
        else:
            # Learnable mixture-of-Gaussians prior.
            self.locs = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            # Pre-softplus scales; softplus keeps the actual scales positive.
            self.raw_scales = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            self.mixture_logits = nn.Parameter(
                torch.zeros(mixture_components, device=device)
            )

            # Rebuilt from the current parameter values on every use.
            self.distribution = None
            self._create_distribution()

    def _create_distribution(self):
        """(Re)build the mixture distribution from the current parameters."""
        if self.mixture_components == 1:
            return

        # Component distributions: diagonal normals, one per mixture component.
        scales = F.softplus(self.raw_scales)
        component_distribution = td.Independent(
            td.Normal(loc=self.locs, scale=scales), 1
        )
        mixture_distribution = td.Categorical(logits=self.mixture_logits)
        self.distribution = td.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

    def log_prob(self, z):
        """Log density of ``z`` of shape ``(..., latent_size)``; returns shape ``(...)``."""
        if self.mixture_components > 1:
            # Rebuild so gradients flow to the current parameter values.
            self._create_distribution()
        return self.distribution.log_prob(z)

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + (latent_size,)``."""
        if self.mixture_components > 1:
            self._create_distribution()
        return self.distribution.sample(sample_shape)

VAE

class VAE(nn.Module):
    """Variational autoencoder: encoder q(z|x), decoder p(x|z) and prior p(z).

    ``forward`` draws ``n_samples`` reparameterized posterior samples per input.
    From them it estimates the ELBO, the IWAE bound, and their rate (KL) and
    distortion (reconstruction NLL) components.
    """

    def __init__(
        self,
        latent_size: int,
        base_depth: int,
        activation_name: str,
        mixture_components: int,
        analytic_kl: bool,
        n_samples: int,
    ):
        super(VAE, self).__init__()

        if analytic_kl and mixture_components != 1:
            raise NotImplementedError(
                "Using analytic_kl is only supported when mixture_components = 1 "
                "since there's no closed form otherwise."
            )

        self.latent_size = latent_size
        self.analytic_kl = analytic_kl
        self.n_samples = n_samples

        # Create encoder, decoder, and prior. `device` is the module-level
        # device selected when the notebook starts.
        self.encoder = Encoder(latent_size, base_depth, activation_name)
        self.decoder = Decoder(latent_size, IMAGE_SHAPE, base_depth, activation_name)
        self.prior = MixturePrior(latent_size, mixture_components, device=device)

    def forward(self, x):
        """Compute ELBO/IWAE estimates for a batch of binary images.

        Args:
            x: Images in channels-first layout ``(batch, C, H, W)``, as required
                by the encoder's ``Conv2d`` input.

        Returns:
            Dict with scalar ``"elbo"``, ``"elbo_iwae"``, ``"rate"`` and
            ``"distortion"`` (means over samples and batch), plus the
            ``"approx_posterior"`` and ``"decoder_likelihood"`` distributions.
        """
        approx_posterior = self.encoder(x)

        # (n_samples, batch, latent_size); rsample keeps gradients flowing
        # through the sampling step (reparameterization).
        approx_posterior_sample = approx_posterior.rsample([self.n_samples])

        decoder_likelihood = self.decoder(approx_posterior_sample)

        # The decoder's event shape is channels-last (IMAGE_SHAPE = H, W, C)
        # while x is channels-first, so move the channel axis to the end.
        # Otherwise (B, 1, 28, 28) silently broadcasts against the event shape
        # (28, 28, 1), and log_prob sums over a 28x28x28 grid of mismatched
        # pixel pairs.
        distortion = -decoder_likelihood.log_prob(x.movedim(-3, -1))

        if self.analytic_kl:
            # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian q.
            mean = approx_posterior.mean
            var = approx_posterior.variance
            rate = 0.5 * torch.sum(mean.pow(2) + var - 1 - torch.log(var), dim=-1)
        else:
            # Single-sample Monte Carlo estimate of the KL, per posterior sample.
            rate = approx_posterior.log_prob(
                approx_posterior_sample
            ) - self.prior.log_prob(approx_posterior_sample)

        # Per-sample ELBO, shape (n_samples, batch).
        elbo_local = -(rate + distortion)

        # IWAE bound: log of the mean importance weight over the sample axis.
        log_weight_mean = torch.logsumexp(elbo_local, dim=0) - float(
            np.log(self.n_samples)
        )

        return {
            "elbo": elbo_local.mean(),
            "elbo_iwae": log_weight_mean.mean(),
            "rate": rate.mean(),
            "distortion": distortion.mean(),
            "approx_posterior": approx_posterior,
            "decoder_likelihood": decoder_likelihood,
        }

Binarized MNIST

# Dataset loading functions
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
FILE_TEMPLATE = "binarized_mnist_{split}.amat"


def download(directory, filename):
    """Fetch ``filename`` from ``ROOT_PATH`` into ``directory`` unless already present.

    Returns:
        The local path ``os.path.join(directory, filename)``.
    """
    filepath = os.path.join(directory, filename)
    if os.path.exists(filepath):
        return filepath
    os.makedirs(directory, exist_ok=True)

    # URLs always use "/"; os.path.join would insert "\\" on Windows.
    url = ROOT_PATH.rstrip("/") + "/" + filename
    print(f"Downloading {url} to {filepath}")

    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted download would leave a truncated file that the existence
    # check above would accept on the next call.
    tmp_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, filepath)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return filepath


class BinarizedMNIST(Dataset):
    """Binarized MNIST split read from a ``binarized_mnist_{split}.amat`` file.

    The file's whitespace-separated 0/1 values are reshaped into 28x28 images.
    Items are ``(image, label)`` pairs: ``image`` is a float tensor of shape
    (1, 28, 28) (channels-first) and ``label`` is a dummy ``0``.
    """

    def __init__(self, directory, split_name):
        super(BinarizedMNIST, self).__init__()

        amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))

        # np.loadtxt parses the text straight into a float32 array instead of
        # building a Python list holding one int object per pixel.
        pixels = np.loadtxt(amat_file, dtype=np.float32)

        # Reshape directly into channels-first (N, 1, 28, 28).
        self.data = torch.from_numpy(pixels).reshape(-1, 1, 28, 28)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], torch.tensor(0, dtype=torch.long)  # dummy label


def build_fake_data(batch_size):
    """Build train/eval loaders over a single batch of random *binary* images.

    The decoder models pixels as Bernoulli variables. With PyTorch's default
    argument validation, ``Bernoulli.log_prob`` rejects values outside {0, 1},
    so continuous ``torch.rand`` pixels would make the first forward pass
    raise. Pixels are therefore drawn as fair coin flips.

    Returns:
        ``(train_loader, eval_loader)``, both over the same ``batch_size``
        images of shape (1, 28, 28) with dummy labels 0.
    """
    random_sample = torch.bernoulli(torch.full((batch_size, 1, 28, 28), 0.5))
    labels = torch.zeros(batch_size, dtype=torch.long)

    train_dataset = TensorDataset(random_sample, labels)
    eval_dataset = TensorDataset(random_sample, labels)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, eval_loader


def build_dataloaders(data_dir, batch_size):
    """Create the binarized-MNIST training and validation loaders.

    Training batches are shuffled and the trailing partial batch is dropped.
    Validation batches keep file order and include any partial batch.
    """
    datasets = {split: BinarizedMNIST(data_dir, split) for split in ("train", "valid")}

    train_loader = DataLoader(
        datasets["train"], batch_size=batch_size, shuffle=True, drop_last=True
    )
    valid_loader = DataLoader(datasets["valid"], batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader

Trainer

# Training and evaluation
def train(model, train_loader, optimizer, epoch, writer=None):
    """Run one optimization pass over ``train_loader``.

    Each batch is moved to the device of the model's parameters. The loss
    ``-outputs["elbo"]`` is back-propagated and ``optimizer`` takes one step.

    Args:
        model: Module whose forward returns a dict containing a scalar ``"elbo"``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        optimizer: Optimizer over ``model``'s parameters.
        epoch: Epoch index, used as the TensorBoard step.
        writer: Optional ``SummaryWriter``; when given, logs ``Loss/train``.

    Returns:
        Mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    # Follow the model's own device instead of relying on a module-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    n_batches = 0
    for data, _ in train_loader:
        data = data.to(target_device)
        optimizer.zero_grad()

        loss = -model(data)["elbo"]  # minimize the negative ELBO
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches")
    avg_loss = total_loss / n_batches

    if writer is not None:
        writer.add_scalar("Loss/train", avg_loss, epoch)

    return avg_loss


def evaluate(model, eval_loader, epoch, writer=None):
    """Evaluate ``model`` on ``eval_loader`` without gradients.

    Args:
        model: Module whose forward returns a dict with scalar ``"elbo"``,
            ``"elbo_iwae"``, ``"rate"`` and ``"distortion"`` entries.
        eval_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        epoch: Epoch index, used as the TensorBoard step.
        writer: Optional ``SummaryWriter``; when given, logs ``Metrics/<name>``.

    Returns:
        ``(metrics, eval_data, eval_outputs)``. ``metrics`` holds each scalar
        averaged per example: batch means are weighted by batch size, so a
        smaller final batch does not bias the result. ``eval_data`` and
        ``eval_outputs`` are the first batch and its model outputs, kept for
        visualization.

    Raises:
        ValueError: If ``eval_loader`` yields no examples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    totals = {"elbo": 0.0, "elbo_iwae": 0.0, "rate": 0.0, "distortion": 0.0}
    n_examples = 0
    eval_data = eval_outputs = None

    with torch.no_grad():
        for data, _ in eval_loader:
            data = data.to(target_device)
            outputs = model(data)

            batch_size = data.shape[0]
            for k in totals:
                totals[k] += outputs[k].item() * batch_size
            n_examples += batch_size

            # Keep only the first batch for visualization.
            if eval_data is None:
                eval_data, eval_outputs = data, outputs

    if n_examples == 0:
        raise ValueError("eval_loader yielded no examples")
    metrics = {k: v / n_examples for k, v in totals.items()}

    if writer is not None:
        for k, v in metrics.items():
            writer.add_scalar(f"Metrics/{k}", v, epoch)

    return metrics, eval_data, eval_outputs


def visualize(model, data, outputs, epoch, writer=None):
    """Collect input, reconstruction and prior-sample images, optionally logging them.

    Args:
        model: The VAE; ``model.prior`` and ``model.decoder`` generate 16
            random images.
        data: Input batch in channels-first layout ``(B, C, H, W)``.
        outputs: The dict returned by ``model(data)``. Its
            ``"decoder_likelihood"`` mean has shape ``(n_samples, B, H, W, C)``.
        epoch: TensorBoard step.
        writer: Optional ``SummaryWriter``; when given, the three grids are logged.

    Returns:
        Dict with ``"input"`` (first 16 inputs, channels-first as given),
        ``"recon"`` (decoder means for up to 3 samples x 16 inputs) and
        ``"random"`` (decoder means for 16 prior samples), both channels-last.
    """
    with torch.no_grad():
        recon_samples = outputs["decoder_likelihood"].mean[:3, :16]

        random_z = model.prior.sample(torch.Size([16]))
        random_samples = model.decoder(random_z).mean

        if writer is not None:
            # pack_images expects channels-last images, so move the channel
            # axis of the channels-first inputs to the end; the result is a
            # single row of 16 digits.
            input_grid = pack_images(data[:16].movedim(1, -1), 1, 16)[0]
            writer.add_image("input", input_grid, epoch, dataformats="HW")

            # One row per posterior sample, one column per input.
            recon_grid = pack_images(recon_samples, 3, 16)[0]
            writer.add_image("recon/mean", recon_grid, epoch, dataformats="HW")

            random_grid = pack_images(random_samples, 4, 4)[0]
            writer.add_image("random/mean", random_grid, epoch, dataformats="HW")

        return {"input": data[:16], "recon": recon_samples, "random": random_samples}

Application

def get_scheduler(optimizer, max_steps):
    """Wrap ``optimizer`` in a cosine-annealing learning-rate schedule.

    The learning rate follows half a cosine from its initial value down to the
    scheduler's minimum over ``max_steps`` calls to ``scheduler.step()``.
    """
    lr_schedulers = torch.optim.lr_scheduler
    return lr_schedulers.CosineAnnealingLR(optimizer, T_max=max_steps)


def main():
    """Train and periodically evaluate the VAE using the module-level ``args``.

    Builds the data loaders (real binarized MNIST or fake data), the model, an
    Adam optimizer with a per-epoch cosine LR schedule, and a TensorBoard
    writer in ``args.model_dir``. It then trains for the fewest whole epochs
    covering ``args.max_steps`` optimizer steps. It evaluates and visualizes
    every ``viz_epochs`` epochs and after the final epoch.

    Raises:
        ValueError: If the training loader yields no batches.
    """
    params = vars(args)

    # Create (or recreate) the log directory.
    if params["delete_existing"] and os.path.exists(params["model_dir"]):
        print(f"Deleting old log directory at {params['model_dir']}")
        import shutil

        shutil.rmtree(params["model_dir"])
    os.makedirs(params["model_dir"], exist_ok=True)

    # Setup data loaders
    if params["fake_data"]:
        train_loader, eval_loader = build_fake_data(params["batch_size"])
    else:
        train_loader, eval_loader = build_dataloaders(
            params["data_dir"], params["batch_size"]
        )

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError(
            "The training loader yields no batches; check batch_size against "
            "the dataset size."
        )

    model = VAE(
        latent_size=params["latent_size"],
        base_depth=params["base_depth"],
        activation_name=params["activation"],
        mixture_components=params["mixture_components"],
        analytic_kl=params["analytic_kl"],
        n_samples=params["n_samples"],
    ).to(device)

    # Ceil division: the fewest whole epochs that cover max_steps optimizer steps.
    epochs = max(1, -(-params["max_steps"] // steps_per_epoch))
    viz_epochs = params["viz_steps"] // steps_per_epoch + 1

    optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
    # scheduler.step() runs once per epoch below, so the cosine period must be
    # measured in epochs; with T_max=max_steps the LR would barely decay.
    scheduler = get_scheduler(optimizer, epochs)

    writer = SummaryWriter(log_dir=params["model_dir"])
    try:
        for epoch in range(epochs):
            train_loss = train(model, train_loader, optimizer, epoch, writer)
            scheduler.step()

            # Evaluate and visualize at regular intervals and after the last epoch.
            if epoch % viz_epochs == 0 or epoch == epochs - 1:
                metrics, eval_data, eval_outputs = evaluate(
                    model, eval_loader, epoch, writer
                )
                vis = visualize(model, eval_data, eval_outputs, epoch, writer)

                print(
                    f"Epoch {epoch}: Train loss = {train_loss:.4f}, ELBO = {metrics['elbo']:.4f}, "
                    f"IWAE ELBO = {metrics['elbo_iwae']:.4f}, Rate = {metrics['rate']:.4f}, "
                    f"Distortion = {metrics['distortion']:.4f}"
                )

                # show_image_grid opens (and shows) its own figure per call, so
                # panels are labelled by printing rather than via subplots.
                # Inputs are channels-first; show_image_grid expects channels-last.
                print("Original")
                show_image_grid(
                    vis["input"].movedim(1, -1), rows=2, cols=8, figsize=(6, 2)
                )
                print("Reconstruction")
                show_image_grid(vis["recon"], rows=3, cols=5, figsize=(6, 3))
                print("Generated")
                show_image_grid(vis["random"], rows=4, cols=4, figsize=(4, 4))
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()
    print("Training completed!")


if __name__ == "__main__":
    main()
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_train.amat to /tmp/vae/data/binarized_mnist_train.amat
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_valid.amat to /tmp/vae/data/binarized_mnist_valid.amat
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[11], line 2
      1 if __name__ == "__main__":
----> 2     main()

Cell In[10], line 51, in main()
     47 viz_epochs = params["viz_steps"] // len(train_loader) + 1
     49 for epoch in range(epochs):
     50     # Train for one epoch
---> 51     train_loss = train(model, train_loader, optimizer, epoch, writer)
     52     scheduler.step()
     54     # Evaluate and visualize at regular intervals

Cell In[9], line 12, in train(model, train_loader, optimizer, epoch, writer)
      9 optimizer.zero_grad()
     11 # Forward pass
---> 12 outputs = model(data)
     13 loss = -outputs["elbo"]  # Negative ELBO is the loss
     15 # Backward pass and optimize

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

Cell In[7], line 39, in VAE.forward(self, x)
     36 approx_posterior_sample = approx_posterior.rsample([self.n_samples])
     38 # Decode the samples to get p(x|z)
---> 39 decoder_likelihood = self.decoder(approx_posterior_sample)
     41 # Calculate distortion (negative log likelihood)
     42 distortion = -decoder_likelihood.log_prob(x)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

Cell In[5], line 64, in Decoder.forward(self, z)
     61 z = z.reshape(-1, self.latent_size, 1, 1)
     63 # Apply decoder network
---> 64 logits = self.decoder_net(z)
     66 # Reshape logits to match original dimensions + output shape
     67 new_shape = original_shape[:-1] + tuple(self.output_shape)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input)
    248 def forward(self, input):
    249     for module in self:
--> 250         input = module(input)
    251     return input

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:1162, in ConvTranspose2d.forward(self, input, output_size)
   1151 num_spatial_dims = 2
   1152 output_padding = self._output_padding(
   1153     input,
   1154     output_size,
   (...)   1159     self.dilation,  # type: ignore[arg-type]
   1160 )
-> 1162 return F.conv_transpose2d(
   1163     input,
   1164     self.weight,
   1165     self.bias,
   1166     self.stride,
   1167     self.padding,
   1168     output_padding,
   1169     self.groups,
   1170     self.dilation,
   1171 )

KeyboardInterrupt: