"""Trains a variational auto-encoder (VAE) on binarized MNIST using PyTorch.
The VAE defines a generative model in which a latent code `Z` is sampled from a
prior `p(Z)`, then used to generate an observation `X` by way of a decoder
`p(X|Z)`. The full reconstruction follows
```none
X ~ p(X) # A random image from some dataset.
Z ~ q(Z | X) # A random encoding of the original image ("encoder").
Xhat ~ p(Xhat | Z) # A random reconstruction of the original image
# ("decoder").
```
To fit the VAE, we assume an approximate representation of the posterior in the
form of an encoder `q(Z|X)`. We minimize the KL divergence between `q(Z|X)` and
the true posterior `p(Z|X)`: this is equivalent to maximizing the evidence lower
bound (ELBO),
```none
-log p(x)
= -log int dz p(x|z) p(z)
= -log int dz q(z|x) p(x|z) p(z) / q(z|x)
<= int dz q(z|x) (-log[ p(x|z) p(z) / q(z|x) ]) # Jensen's Inequality
=: KL[q(Z|x) || p(x|Z)p(Z)]
= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```
-or-
```none
-log p(x)
= KL[q(Z|x) || p(x|Z)p(Z)] - KL[q(Z|x) || p(Z|x)]
<= KL[q(Z|x) || p(x|Z)p(Z)] # Positivity of KL
= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```
The `-E_{Z~q(Z|x)}[log p(x|Z)]` term is an expected reconstruction loss and
`KL[q(Z|x) || p(Z)]` is a kind of distributional regularizer.
This implementation supports both standard normal prior and mixture of Gaussians prior.
Using a single Gaussian component is equivalent to the fixed standard normal prior.
This implementation also supports using the analytic KL (KL[q(Z|x) || p(Z)]) with the
`analytic_kl` flag. Using the analytic KL is only supported when
`mixture_components` is set to 1 since otherwise no analytic form is known.
We also compute a tighter bound, the importance-weighted (IWAE) bound of
Burda et al. (2015), "Importance Weighted Autoencoders", arXiv:1509.00519.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import time
from typing import Tuple, List, Callable, Optional, Dict, Any
import numpy as np
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
Variational Autoencoder from Scratch - Torch
This example implements a Variational Autoencoder (VAE) on binarized MNIST using PyTorch. It is based on the TensorFlow Probability example but adapted to use PyTorch 2.x.
!nvidia-smiWed Apr 30 20:37:14 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A4500 Laptop GPU Off | 00000000:01:00.0 On | Off |
| N/A 53C P3 25W / 115W | 313MiB / 16384MiB | 23% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
# Pick the compute device once at import time: the first CUDA GPU when PyTorch
# can see one, otherwise the CPU. Any failure while probing CUDA is reported
# and treated as "no GPU" so the rest of the notebook can still run.
try:
    cuda_available = torch.cuda.is_available()
    if not cuda_available:
        device = torch.device("cpu")
        print("CUDA is not available. Using CPU instead.")
    else:
        device = torch.device("cuda:0")
        print(f"Using GPU device: {torch.cuda.get_device_name(0)}")
except Exception as e:
    # Best-effort fallback: a broken driver/runtime must not abort the run.
    device = torch.device("cpu")
    print(f"Error initializing CUDA: {e}")
    print("Falling back to CPU.")

print(f"Using device: {device}")
# Output: CUDA is not available. Using CPU instead.
Using device: cpu
/workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
return torch._C._cuda_getDeviceCount() > 0
# Define constants and hyperparameters
# Image layout used for the decoder output: the last entry is the number of
# channels.
IMAGE_SHAPE = [28, 28, 1]

# Command-line options (mirroring the flags of the TensorFlow example). The
# defaults declared here are the single source of truth for `args` below.
parser = argparse.ArgumentParser(description="VAE on binarized MNIST")
parser.add_argument(
    "--learning_rate", type=float, default=0.001, help="Initial learning rate"
)
parser.add_argument(
    "--max_steps", type=int, default=5001, help="Number of training steps to run"
)
parser.add_argument(
    "--latent_size",
    type=int,
    default=16,
    help="Number of dimensions in the latent code (z)",
)
parser.add_argument("--base_depth", type=int, default=32, help="Base depth for layers")
parser.add_argument(
    "--activation", type=str, default="leaky_relu", help="Activation function"
)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument(
    "--n_samples", type=int, default=16, help="Number of samples to use in encoding"
)
parser.add_argument(
    "--mixture_components",
    type=int,
    default=100,
    help="Number of mixture components in the prior",
)
parser.add_argument(
    "--analytic_kl",
    action="store_true",
    default=False,
    help="Whether to use analytic KL",
)
parser.add_argument(
    "--data_dir",
    type=str,
    default=os.path.join("/tmp", "vae/data"),
    help="Directory for data",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default=os.path.join("/tmp", "vae"),
    help="Directory to put the model's fit",
)
parser.add_argument(
    "--viz_steps",
    type=int,
    default=500,
    help="Frequency at which to save visualizations",
)
parser.add_argument(
    "--fake_data",
    action="store_true",
    default=False,
    help="If true, uses fake data instead of MNIST",
)
parser.add_argument(
    "--delete_existing",
    action="store_true",
    default=False,
    help="If true, deletes existing model_dir directory",
)

# In a notebook there is no real command line, so parse an empty argument list:
# every option then takes the default declared on the parser above, instead of
# a hand-maintained copy of those defaults that could drift out of sync.
args = parser.parse_args([])
# Utility functions
def _softplus_inverse(x):
    """Return ``y`` such that ``softplus(y) == x``.

    Args:
        x: Tensor of strictly positive values.

    Returns:
        Tensor with the same shape as ``x`` holding ``log(exp(x) - 1)``.
    """
    # log(exp(x) - 1) == x + log(1 - exp(-x)) == x + log(-expm1(-x)).
    # The naive form log(expm1(x)) overflows to +inf once exp(x) exceeds the
    # float range (x of roughly 89 in float32); this form stays finite.
    return x + torch.log(-torch.expm1(-x))
# Helper to get activation function by name
def get_activation(name):
    """Build the activation module registered under ``name``.

    Args:
        name: One of ``"leaky_relu"`` (negative slope 0.2), ``"relu"`` or
            ``"elu"``.

    Returns:
        A freshly constructed ``nn.Module`` for the requested activation.

    Raises:
        ValueError: If ``name`` is not one of the supported activations.
    """
    # (name, zero-argument constructor) pairs, checked in order.
    registry = (
        ("leaky_relu", lambda: nn.LeakyReLU(0.2)),
        ("relu", nn.ReLU),
        ("elu", nn.ELU),
    )
    for known_name, build in registry:
        if name == known_name:
            return build()
    raise ValueError(f"Unsupported activation function: {name}")
def pack_images(images, rows, cols):
    """Tile a batch of images into a single mosaic.

    The trailing three dimensions of ``images`` are treated as one image with
    the depth (channel) axis last; all leading dimensions are flattened into
    the batch. At most ``rows * cols`` images are used, and the grid shrinks
    when the batch holds fewer images than requested.

    Args:
        images: Tensor whose last three dims are (first axis, second axis,
            depth).
        rows: Maximum number of tile rows.
        cols: Maximum number of tile columns.

    Returns:
        Tensor of shape ``(1, rows * first_axis, cols * second_axis)`` when
        depth is 1, otherwise ``(1, rows * first_axis, cols * second_axis,
        depth)``.
    """
    width, height, depth = (images.shape[axis] for axis in (-3, -2, -1))
    flat = images.reshape(-1, width, height, depth)
    n_images = flat.shape[0]

    # Never ask for more tiles than there are images.
    rows = min(rows, n_images)
    cols = min(n_images // rows, cols)

    tiles = flat[: rows * cols].reshape(rows, cols, width, height, depth)
    # (rows, cols, W, H, D) -> (rows, W, cols, H, D): after this the two grid
    # axes sit next to the image axes they are merged with below.
    interleaved = tiles.permute(0, 2, 1, 3, 4)
    mosaic = interleaved.reshape(1, rows * width, cols * height, depth)
    # Drops the depth axis only when it has size 1 (single-channel images).
    return mosaic.squeeze(3)
def show_image_grid(images, rows=8, cols=8, figsize=(10, 10)):
    """Render a batch of images as one grayscale mosaic in a new figure.

    Args:
        images: Tensor of images accepted by ``pack_images``.
        rows: Maximum number of tile rows.
        cols: Maximum number of tile columns.
        figsize: Matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    grid = pack_images(images, rows, cols)[0].cpu().numpy()

    if grid.ndim == 3:
        if grid.shape[0] == 1:
            # Singleton leading axis: drop it to obtain a 2D image.
            grid = grid.squeeze(0)
        elif grid.shape[0] not in [3, 4]:
            # Leading axis is not a 3/4-channel axis: stack the slices along
            # the first image axis so that imshow receives a 2D array.
            grid = grid.reshape(-1, grid.shape[-1])

    plt.imshow(grid, cmap="gray")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
# Encoder
# VAE Model Components
class Encoder(nn.Module):
    """Convolutional encoder producing the approximate posterior q(z|x).

    Five convolutions (two of them stride 2, the last one a 7x7 kernel without
    padding) are followed by a linear layer that emits ``2 * latent_size``
    values per image: the location and the unconstrained scale of a diagonal
    Gaussian.
    """

    def __init__(self, latent_size: int, base_depth: int, activation_name: str):
        """Build the network.

        Args:
            latent_size: Dimensionality of the latent code.
            base_depth: Channel count of the first convolution stages.
            activation_name: Name understood by ``get_activation``.
        """
        super().__init__()
        self.latent_size = latent_size
        act = get_activation(activation_name)

        # (in_channels, out_channels, kernel, stride, padding) per conv stage.
        conv_specs = [
            (1, base_depth, 5, 1, 2),
            (base_depth, base_depth, 5, 2, 2),
            (base_depth, 2 * base_depth, 5, 1, 2),
            (2 * base_depth, 2 * base_depth, 5, 2, 2),
            (2 * base_depth, 4 * latent_size, 7, 1, 0),
        ]
        layers = []
        for c_in, c_out, kernel, stride, pad in conv_specs:
            layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=pad)
            )
            # The same activation instance is reused after every convolution.
            layers.append(act)
        layers.append(nn.Flatten())
        layers.append(nn.Linear(4 * latent_size, 2 * latent_size))
        self.encoder_net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode images into a diagonal Gaussian over the latent code.

        Args:
            x: Image batch with values in [0, 1]; the first convolution
                expects a single channel.

        Returns:
            ``td.Independent(td.Normal(...), 1)`` whose event dimension is the
            latent dimension.
        """
        # Rescale pixel values from [0, 1] to [-1, 1] before the network.
        params = self.encoder_net(2 * x - 1)

        loc = params[..., : self.latent_size]
        raw_scale = params[..., self.latent_size :]
        # Shift so that a raw network output of 0 maps to a scale of 1.
        offset = _softplus_inverse(torch.tensor(1.0))
        scale = F.softplus(raw_scale + offset)

        return td.Independent(td.Normal(loc=loc, scale=scale), 1)
Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder producing the likelihood p(x|z).

    The latent code is treated as a 1x1 feature map and upsampled by the layer
    stack (1 -> 7 -> 14 -> 28 pixels per side). The result parameterises an
    independent Bernoulli distribution over pixels.
    """

    def __init__(
        self,
        latent_size: int,
        output_shape: List[int],
        base_depth: int,
        activation_name: str,
    ):
        """Build the network.

        Args:
            latent_size: Dimensionality of the latent code.
            output_shape: Event shape of the returned distribution, with the
                channel count last; ``output_shape[-1]`` sets the number of
                output channels of the final convolution.
            base_depth: Channel count of the last decoder stages.
            activation_name: Name understood by ``get_activation``.
        """
        super(Decoder, self).__init__()
        activation = get_activation(activation_name)
        self.latent_size = latent_size
        self.output_shape = output_shape
        # Define the decoder network
        self.decoder_net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_size, 2 * base_depth, kernel_size=7, stride=1, padding=0
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, 2 * base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth,
                2 * base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth,
                base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.Conv2d(base_depth, output_shape[-1], kernel_size=5, stride=1, padding=2),
        )

    def forward(self, z):
        """Decode latent codes into a Bernoulli distribution over images.

        Args:
            z: Tensor of shape ``(..., latent_size)``; any leading dimensions
                (e.g. samples and batch) are preserved.

        Returns:
            ``td.Independent(td.Bernoulli(logits=...))`` with batch shape
            ``z.shape[:-1]`` and event shape ``output_shape``.
        """
        # Remember the leading (sample/batch) dimensions.
        leading_shape = tuple(z.shape[:-1])
        # Flatten them and present each code as a 1x1 feature map.
        z = z.reshape(-1, self.latent_size, 1, 1)
        logits = self.decoder_net(z)
        # The network emits channels-first (N, C, H, W) while output_shape is
        # channels-last. Move the channel axis before reshaping; a plain
        # reshape would interleave channels and pixels whenever C > 1 (for
        # C == 1 the permutation leaves the element order unchanged).
        logits = logits.permute(0, 2, 3, 1)
        logits = logits.reshape(leading_shape + tuple(self.output_shape))
        return td.Independent(td.Bernoulli(logits=logits), len(self.output_shape))
# Prior
class MixturePrior(nn.Module):
    """Prior over the latent code: standard normal or mixture of Gaussians.

    With ``mixture_components == 1`` the prior is a fixed standard normal and
    has no parameters. Otherwise it is a mixture of diagonal Gaussians whose
    locations, scales and mixture weights are learnable.

    The class derives from ``nn.Module`` so that those ``nn.Parameter``s are
    registered: assigned as an attribute of a model, they show up in
    ``model.parameters()`` and are updated by the optimizer. On a plain Python
    object they are invisible to the optimizer and stay at their initial
    values.
    """

    def __init__(self, latent_size: int, mixture_components: int, device: torch.device):
        """Create the prior.

        Args:
            latent_size: Dimensionality of the latent code.
            mixture_components: Number of Gaussian components; 1 selects the
                fixed standard normal.
            device: Device on which the tensors of the prior are created.
        """
        super().__init__()
        self.latent_size = latent_size
        self.mixture_components = mixture_components
        self.device = device
        if mixture_components == 1:
            # Standard normal prior (fixed, built once).
            self.distribution = td.MultivariateNormal(
                loc=torch.zeros(latent_size, device=device),
                covariance_matrix=torch.eye(latent_size, device=device),
            )
        else:
            # Mixture of Gaussians prior with learnable parameters.
            self.locs = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            self.raw_scales = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            self.mixture_logits = nn.Parameter(
                torch.zeros(mixture_components, device=device)
            )
            # Rebuilt from the current parameters before every use.
            self.distribution = None
            self._create_distribution()

    def _create_distribution(self):
        """Rebuild the mixture distribution from the current parameters."""
        if self.mixture_components == 1:
            return
        # Softplus keeps the component scales positive.
        scales = F.softplus(self.raw_scales)
        component_distribution = td.Independent(
            td.Normal(loc=self.locs, scale=scales), 1
        )
        mixture_distribution = td.Categorical(logits=self.mixture_logits)
        self.distribution = td.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

    def log_prob(self, z):
        """Return ``log p(z)`` for latent codes of shape ``(..., latent_size)``."""
        if self.mixture_components > 1:
            # Ensure the distribution reflects the current parameter values.
            self._create_distribution()
        return self.distribution.log_prob(z)

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + (latent_size,)``."""
        if self.mixture_components > 1:
            # Ensure the distribution reflects the current parameter values.
            self._create_distribution()
        return self.distribution.sample(sample_shape)
# VAE
class VAE(nn.Module):
    """Variational autoencoder combining encoder, decoder and prior."""

    def __init__(
        self,
        latent_size: int,
        base_depth: int,
        activation_name: str,
        mixture_components: int,
        analytic_kl: bool,
        n_samples: int,
    ):
        """Build the model.

        Args:
            latent_size: Dimensionality of the latent code.
            base_depth: Base channel count of encoder and decoder.
            activation_name: Name understood by ``get_activation``.
            mixture_components: Number of components of the prior.
            analytic_kl: Use the closed-form KL to a standard normal prior
                instead of a Monte Carlo estimate.
            n_samples: Number of posterior samples drawn per input.

        Raises:
            NotImplementedError: If ``analytic_kl`` is requested together with
                more than one mixture component.
        """
        super(VAE, self).__init__()
        if analytic_kl and mixture_components != 1:
            raise NotImplementedError(
                "Using analytic_kl is only supported when mixture_components = 1 "
                "since there's no closed form otherwise."
            )
        self.latent_size = latent_size
        self.analytic_kl = analytic_kl
        self.n_samples = n_samples
        # Create encoder, decoder, and prior
        self.encoder = Encoder(latent_size, base_depth, activation_name)
        self.decoder = Decoder(latent_size, IMAGE_SHAPE, base_depth, activation_name)
        self.prior = MixturePrior(latent_size, mixture_components, device=device)

    def forward(self, x):
        """Compute the ELBO, the IWAE bound and their components.

        Args:
            x: Channels-first image batch ``(batch, channels, height, width)``
                with binary pixel values.

        Returns:
            Dict with scalar tensors ``"elbo"``, ``"elbo_iwae"``, ``"rate"``
            and ``"distortion"``, plus the distributions
            ``"approx_posterior"`` and ``"decoder_likelihood"``.
        """
        # Approximate posterior q(z|x) and n_samples reparameterised draws.
        approx_posterior = self.encoder(x)
        approx_posterior_sample = approx_posterior.rsample([self.n_samples])
        # Likelihood p(x|z) for every sample; batch shape (n_samples, batch).
        decoder_likelihood = self.decoder(approx_posterior_sample)

        # The decoder's event shape is channels-last while x is
        # channels-first. Scoring x unchanged would broadcast (B, 1, 28, 28)
        # against (n, B, 28, 28, 1) into (n, B, 28, 28, 28) without any error
        # and sum over mismatched pixel pairs. Align the layouts first.
        x_event = x.movedim(-3, -1)
        distortion = -decoder_likelihood.log_prob(x_event)

        if self.analytic_kl:
            # Closed-form KL between a diagonal Gaussian and a standard normal.
            mean = approx_posterior.mean
            var = approx_posterior.variance
            rate = 0.5 * torch.sum(mean.pow(2) + var - 1 - torch.log(var), dim=-1)
        else:
            # Monte Carlo estimate: log q(z|x) - log p(z) at the drawn samples.
            rate = approx_posterior.log_prob(
                approx_posterior_sample
            ) - self.prior.log_prob(approx_posterior_sample)

        # Per-sample ELBO = -rate - distortion
        elbo_local = -(rate + distortion)
        # IWAE bound: log of the mean importance weight over the sample axis.
        # log(n_samples) is built on the tensor's own device and dtype.
        log_n_samples = torch.log(elbo_local.new_tensor(float(self.n_samples)))
        log_weight_mean = torch.logsumexp(elbo_local, dim=0) - log_n_samples

        return {
            "elbo": elbo_local.mean(),
            "elbo_iwae": log_weight_mean.mean(),
            "rate": rate.mean(),
            "distortion": distortion.mean(),
            "approx_posterior": approx_posterior,
            "decoder_likelihood": decoder_likelihood,
        }
# Binarized MNIST
# Dataset loading functions
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
FILE_TEMPLATE = "binarized_mnist_{split}.amat"


def download(directory, filename):
    """Fetch ``filename`` from ``ROOT_PATH`` into ``directory`` unless cached.

    Args:
        directory: Local directory that holds (or will hold) the file; it is
            created when missing.
        filename: Name of the file below ``ROOT_PATH``.

    Returns:
        Path of the local file.
    """
    filepath = os.path.join(directory, filename)
    if os.path.exists(filepath):
        return filepath
    os.makedirs(directory, exist_ok=True)
    # URLs always use "/", so do not build them with os.path.join (which uses
    # the platform's path separator).
    url = ROOT_PATH.rstrip("/") + "/" + filename
    print(f"Downloading {url} to {filepath}")
    # Download to a temporary name and rename on success. An interrupted
    # transfer then never leaves a truncated file that the cache check above
    # would accept on the next run.
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filepath)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return filepath
class BinarizedMNIST(Dataset):
    """One split of the binarized MNIST dataset, held fully in memory.

    Each line of the ``.amat`` file is parsed as whitespace-separated integer
    pixel values for one 28x28 image. Images are stored channels-first with
    shape ``(N, 1, 28, 28)``.
    """

    def __init__(self, directory, split_name):
        """Download (if needed) and parse one split.

        Args:
            directory: Directory used to cache the downloaded file.
            split_name: Split identifier inserted into ``FILE_TEMPLATE``.
        """
        super().__init__()
        amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))
        with open(amat_file, "r") as f:
            rows = [[int(bit) for bit in line.strip().split()] for line in f.readlines()]
        pixels = torch.tensor(rows, dtype=torch.float)
        # (N, 784) -> (N, 28, 28, 1) -> channels-first (N, 1, 28, 28).
        self.data = pixels.reshape(-1, 28, 28, 1).permute(0, 3, 1, 2)

    def __len__(self):
        """Return the number of images in the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a constant dummy 0."""
        placeholder_label = torch.tensor(0, dtype=torch.long)
        return self.data[idx], placeholder_label
def build_fake_data(batch_size):
    """Build random binarized MNIST-shaped data for quick smoke tests.

    Args:
        batch_size: Number of fake images; also used as the loader batch size,
            so each loader yields exactly one batch.

    Returns:
        ``(train_loader, eval_loader)`` over the same random images of shape
        ``(batch_size, 1, 28, 28)`` with dummy zero labels.
    """
    # Threshold the uniform noise so every pixel is exactly 0.0 or 1.0. The
    # model scores pixels with a Bernoulli likelihood, whose support is {0, 1};
    # fractional values fail its argument validation.
    random_sample = (torch.rand(batch_size, 1, 28, 28) > 0.5).float()
    labels = torch.zeros(batch_size, dtype=torch.long)
    train_dataset = TensorDataset(random_sample, labels)
    eval_dataset = TensorDataset(random_sample, labels.clone())
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, eval_loader
def build_dataloaders(data_dir, batch_size):
    """Create loaders for the "train" and "valid" splits of binarized MNIST.

    Args:
        data_dir: Directory in which the dataset files are cached.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, valid_loader)``. The training loader shuffles and
        drops the last incomplete batch; the validation loader does neither.
    """
    splits = {name: BinarizedMNIST(data_dir, name) for name in ("train", "valid")}
    train_loader = DataLoader(
        splits["train"], batch_size=batch_size, shuffle=True, drop_last=True
    )
    valid_loader = DataLoader(splits["valid"], batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader
# Trainer
# Training and evaluation
def train(model, train_loader, optimizer, epoch, writer=None):
    """Run one pass over ``train_loader``, minimising the negative ELBO.

    Args:
        model: Module whose forward pass returns a dict with an ``"elbo"``
            entry.
        train_loader: Iterable of ``(data, label)`` batches; labels are
            ignored.
        optimizer: Optimizer over the model parameters.
        epoch: Epoch index used as the logging step.
        writer: Optional summary writer; receives ``"Loss/train"``.

    Returns:
        Mean loss over the batches of this epoch.
    """
    model.train()
    batch_losses = []
    for batch, _ in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        # The loss is the negative ELBO.
        loss = -model(batch)["elbo"]
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    avg_loss = sum(batch_losses) / len(train_loader)
    if writer is not None:
        writer.add_scalar("Loss/train", avg_loss, epoch)
    return avg_loss
def evaluate(model, eval_loader, epoch, writer=None):
    """Evaluate the model on ``eval_loader`` without gradient tracking.

    Args:
        model: Module whose forward pass returns a dict containing the keys
            ``"elbo"``, ``"elbo_iwae"``, ``"rate"`` and ``"distortion"``.
        eval_loader: Iterable of ``(data, label)`` batches; labels are
            ignored.
        epoch: Epoch index used as the logging step.
        writer: Optional summary writer; receives ``"Metrics/<name>"``.

    Returns:
        ``(metrics, eval_data, eval_outputs)``: the per-batch averages of the
        four metrics, and the data and model outputs of the first batch (kept
        for visualisation).
    """
    model.eval()
    metrics = {"elbo": 0, "elbo_iwae": 0, "rate": 0, "distortion": 0}
    # Explicit sentinels for "first batch not seen yet" (instead of probing
    # locals() for the variable name).
    eval_data = None
    eval_outputs = None
    with torch.no_grad():
        for data, _ in eval_loader:
            data = data.to(device)
            outputs = model(data)
            for k in metrics:
                metrics[k] += outputs[k].item()
            # Only the first batch is needed for visualization.
            if eval_data is None:
                eval_data = data
                eval_outputs = outputs
    # Average metrics over the number of batches.
    for k in metrics:
        metrics[k] /= len(eval_loader)
    if writer is not None:
        for k, v in metrics.items():
            writer.add_scalar(f"Metrics/{k}", v, epoch)
    return metrics, eval_data, eval_outputs
def visualize(model, data, outputs, epoch, writer=None):
    """Build reconstruction and prior-sample images, optionally logging them.

    Args:
        model: Model exposing ``prior.sample`` and ``decoder``.
        data: Channels-first image batch ``(batch, channels, height, width)``.
        outputs: Model outputs for ``data``; ``"decoder_likelihood"`` is used.
        epoch: Epoch index used as the logging step.
        writer: Optional summary writer; receives the images ``"input"``,
            ``"recon/mean"`` and ``"random/mean"``.

    Returns:
        Dict with ``"input"`` (first 16 inputs, layout unchanged), ``"recon"``
        (mean reconstructions for up to 3 posterior samples of the first 16
        inputs) and ``"random"`` (decoder means for 16 prior samples).
    """
    with torch.no_grad():
        # Mean reconstructions: up to 3 posterior samples x first 16 inputs.
        recon_samples = outputs["decoder_likelihood"].mean[:3, :16]
        # Decode 16 random draws from the prior.
        random_z = model.prior.sample(torch.Size([16]))
        random_samples = model.decoder(random_z).mean
        if writer is not None:
            # pack_images expects channels-last images, but the inputs are
            # channels-first. Convert them so the inputs are tiled side by
            # side like the reconstructions, not stacked in a vertical strip.
            inputs_hwc = data[:16].permute(0, 2, 3, 1)
            input_grid = pack_images(inputs_hwc, 1, 16)[0]
            writer.add_image("input", input_grid, epoch, dataformats='HW')
            recon_grid = pack_images(recon_samples, 3, 16)[0]
            writer.add_image("recon/mean", recon_grid, epoch, dataformats='HW')
            random_grid = pack_images(random_samples, 4, 4)[0]
            writer.add_image("random/mean", random_grid, epoch, dataformats='HW')
    return {"input": data[:16], "recon": recon_samples, "random": random_samples}
# Application
def get_scheduler(optimizer, max_steps):
    """Return a cosine-annealing learning-rate schedule.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        max_steps: Number of scheduler steps over which the rate is annealed
            (passed as ``T_max``).
    """
    schedulers = torch.optim.lr_scheduler
    return schedulers.CosineAnnealingLR(optimizer, T_max=max_steps)
def main():
    """Train the VAE with the settings in ``args`` and report progress.

    Prepares the model directory and data loaders, builds the model, then
    trains for enough epochs to cover ``max_steps`` batches. At regular
    intervals (and after the last epoch) the model is evaluated, metrics are
    printed and logged to TensorBoard, and image grids are displayed.
    """
    params = vars(args)

    # Prepare the output directory.
    if params["delete_existing"] and os.path.exists(params["model_dir"]):
        print(f"Deleting old log directory at {params['model_dir']}")
        import shutil

        shutil.rmtree(params["model_dir"])
    os.makedirs(params["model_dir"], exist_ok=True)

    # Setup data loaders
    if params["fake_data"]:
        train_loader, eval_loader = build_fake_data(params["batch_size"])
    else:
        train_loader, eval_loader = build_dataloaders(
            params["data_dir"], params["batch_size"]
        )

    # Create model
    model = VAE(
        latent_size=params["latent_size"],
        base_depth=params["base_depth"],
        activation_name=params["activation"],
        mixture_components=params["mixture_components"],
        analytic_kl=params["analytic_kl"],
        n_samples=params["n_samples"],
    ).to(device)

    # Convert the step budgets into whole epochs.
    epochs = params["max_steps"] // len(train_loader) + 1
    viz_epochs = params["viz_steps"] // len(train_loader) + 1

    optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
    # The scheduler is stepped once per epoch, so its horizon must be given in
    # epochs. With a horizon of max_steps the learning rate would barely move
    # during the few epochs that are actually run.
    scheduler = get_scheduler(optimizer, epochs)

    writer = SummaryWriter(log_dir=params["model_dir"])
    try:
        for epoch in range(epochs):
            train_loss = train(model, train_loader, optimizer, epoch, writer)
            scheduler.step()

            # Evaluate and visualize at regular intervals and at the very end.
            if epoch % viz_epochs == 0 or epoch == epochs - 1:
                metrics, eval_data, eval_outputs = evaluate(
                    model, eval_loader, epoch, writer
                )
                vis = visualize(model, eval_data, eval_outputs, epoch, writer)
                print(
                    f"Epoch {epoch}: Train loss = {train_loss:.4f}, ELBO = {metrics['elbo']:.4f}, "
                    f"IWAE ELBO = {metrics['elbo_iwae']:.4f}, Rate = {metrics['rate']:.4f}, "
                    f"Distortion = {metrics['distortion']:.4f}"
                )
                # show_image_grid opens and shows its own figure, so titles
                # set on a separate subplot figure would land on empty axes.
                # Print each caption before its grid instead. The inputs are
                # channels-first and are converted to the channels-last layout
                # that the grid helper tiles correctly.
                panels = (
                    ("Original", vis["input"].permute(0, 2, 3, 1), 2, 8, (6, 2)),
                    ("Reconstruction", vis["recon"], 3, 5, (6, 3)),
                    ("Generated", vis["random"], 4, 4, (4, 4)),
                )
                for title, images, n_rows, n_cols, figsize in panels:
                    print(title)
                    show_image_grid(images, rows=n_rows, cols=n_cols, figsize=figsize)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()
    print("Training completed!")


if __name__ == "__main__":
    main()
# Output: Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_train.amat to /tmp/vae/data/binarized_mnist_train.amat
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_valid.amat to /tmp/vae/data/binarized_mnist_valid.amat
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[11], line 2 1 if __name__ == "__main__": ----> 2 main() Cell In[10], line 51, in main() 47 viz_epochs = params["viz_steps"] // len(train_loader) + 1 49 for epoch in range(epochs): 50 # Train for one epoch ---> 51 train_loss = train(model, train_loader, optimizer, epoch, writer) 52 scheduler.step() 54 # Evaluate and visualize at regular intervals Cell In[9], line 12, in train(model, train_loader, optimizer, epoch, writer) 9 optimizer.zero_grad() 11 # Forward pass ---> 12 outputs = model(data) 13 loss = -outputs["elbo"] # Negative ELBO is the loss 15 # Backward pass and optimize File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() Cell In[7], line 39, in VAE.forward(self, x) 36 approx_posterior_sample = approx_posterior.rsample([self.n_samples]) 38 # Decode the samples to get p(x|z) ---> 39 decoder_likelihood = self.decoder(approx_posterior_sample) 41 # Calculate distortion (negative log likelihood) 42 distortion = -decoder_likelihood.log_prob(x) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() Cell In[5], line 64, in Decoder.forward(self, z) 61 z = z.reshape(-1, self.latent_size, 1, 1) 63 # Apply decoder network ---> 64 logits = self.decoder_net(z) 66 # Reshape logits to match original dimensions + output shape 67 new_shape = original_shape[:-1] + tuple(self.output_shape) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input) 248 def forward(self, input): 249 for module in self: --> 250 input = module(input) 251 return input File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:1162, in ConvTranspose2d.forward(self, input, output_size) 1151 num_spatial_dims = 2 1152 output_padding = self._output_padding( 1153 input, 1154 output_size, (...) 
1159 self.dilation, # type: ignore[arg-type] 1160 ) -> 1162 return F.conv_transpose2d( 1163 input, 1164 self.weight, 1165 self.bias, 1166 self.stride, 1167 self.padding, 1168 output_padding, 1169 self.groups, 1170 self.dilation, 1171 ) KeyboardInterrupt: