"""Trains a variational auto-encoder (VAE) on binarized MNIST using PyTorch.
The VAE defines a generative model in which a latent code `Z` is sampled from a
prior `p(Z)`, then used to generate an observation `X` by way of a decoder
`p(X|Z)`. The full reconstruction follows
```none
X ~ p(X) # A random image from some dataset.
Z ~ q(Z | X) # A random encoding of the original image ("encoder").
Xhat ~ p(Xhat | Z) # A random reconstruction of the original image
# ("decoder").
```
To fit the VAE, we assume an approximate representation of the posterior in the
form of an encoder `q(Z|X)`. We minimize the KL divergence between `q(Z|X)` and
the true posterior `p(Z|X)`: this is equivalent to maximizing the evidence lower
bound (ELBO),
```none
-log p(x)
= -log int dz p(x|z) p(z)
= -log int dz q(z|x) p(x|z) p(z) / q(z|x)
<= int dz q(z|x) (-log[ p(x|z) p(z) / q(z|x) ]) # Jensen's Inequality
=: KL[q(Z|x) || p(x|Z)p(Z)]
= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```
-or-
```none
-log p(x)
= KL[q(Z|x) || p(x|Z)p(Z)] - KL[q(Z|x) || p(Z|x)]
<= KL[q(Z|x) || p(x|Z)p(Z)] # Positivity of KL
= -E_{Z~q(Z|x)}[log p(x|Z)] + KL[q(Z|x) || p(Z)]
```
The `-E_{Z~q(Z|x)}[log p(x|Z)]` term is an expected reconstruction loss and
`KL[q(Z|x) || p(Z)]` is a kind of distributional regularizer.
This implementation supports both standard normal prior and mixture of Gaussians prior.
Using a single Gaussian component is equivalent to the fixed standard normal prior.
This implementation also supports using the analytic KL (KL[q(Z|x) || p(Z)]) with the
`analytic_kl` flag. Using the analytic KL is only supported when
`mixture_components` is set to 1 since otherwise no analytic form is known.
We also compute a tighter bound, the importance-weighted (IWAE) bound of
Burda et al. (2015), Importance Weighted Autoencoders, arXiv:1509.00519.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import time
from typing import Tuple, List, Callable, Optional, Dict, Any
import numpy as np
import urllib.request
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as td
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
Variational Autoencoder from Scratch - Torch
This example implements a Variational Autoencoder (VAE) on binarized MNIST using PyTorch. It is based on the TensorFlow Probability example but adapted to use PyTorch 2.x.
!nvidia-smi
Wed Apr 30 20:37:14 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA RTX A4500 Laptop GPU Off | 00000000:01:00.0 On | Off |
| N/A 53C P3 25W / 115W | 313MiB / 16384MiB | 23% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
# Check if GPU is available
# The module-level `device` chosen here is read by the model-construction and
# training code further down in this file.
try:
    # First check if CUDA is available through PyTorch
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        device = torch.device("cuda:0")
        print(f"Using GPU device: {torch.cuda.get_device_name(0)}")
    else:
        device = torch.device("cpu")
        print("CUDA is not available. Using CPU instead.")
except Exception as e:
    # Deliberately broad: any failure while probing the CUDA runtime should
    # degrade to CPU instead of aborting the whole script.
    device = torch.device("cpu")
    print(f"Error initializing CUDA: {e}")
    print("Falling back to CPU.")
print(f"Using device: {device}")
CUDA is not available. Using CPU instead.
Using device: cpu
/workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:129: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
return torch._C._cuda_getDeviceCount() > 0
# Define constants and hyperparameters
# Event shape of the decoder's output distribution, in (height, width,
# channels) order.
IMAGE_SHAPE = [28, 28, 1]

# Parse arguments (simulating flags in TF)
parser = argparse.ArgumentParser(description="VAE on binarized MNIST")
parser.add_argument(
    "--learning_rate", type=float, default=0.001, help="Initial learning rate"
)
parser.add_argument(
    "--max_steps", type=int, default=5001, help="Number of training steps to run"
)
parser.add_argument(
    "--latent_size",
    type=int,
    default=16,
    help="Number of dimensions in the latent code (z)",
)
parser.add_argument("--base_depth", type=int, default=32, help="Base depth for layers")
parser.add_argument(
    "--activation", type=str, default="leaky_relu", help="Activation function"
)
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument(
    "--n_samples", type=int, default=16, help="Number of samples to use in encoding"
)
parser.add_argument(
    "--mixture_components",
    type=int,
    default=100,
    help="Number of mixture components in the prior",
)
parser.add_argument(
    "--analytic_kl",
    action="store_true",
    default=False,
    help="Whether to use analytic KL",
)
parser.add_argument(
    "--data_dir",
    type=str,
    default=os.path.join("/tmp", "vae/data"),
    help="Directory for data",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default=os.path.join("/tmp", "vae"),
    help="Directory to put the model's fit",
)
parser.add_argument(
    "--viz_steps",
    type=int,
    default=500,
    help="Frequency at which to save visualizations",
)
parser.add_argument(
    "--fake_data",
    action="store_true",
    default=False,
    help="If true, uses fake data instead of MNIST",
)
parser.add_argument(
    "--delete_existing",
    action="store_true",
    default=False,
    help="If true, deletes existing model_dir directory",
)

# In notebook environment, use default values instead of parsing arguments.
# Parsing an empty argv yields exactly the defaults declared above, so the
# values live in one place and the kernel's own sys.argv is never consulted.
args = parser.parse_args([])
# Utility functions
def _softplus_inverse(x):
    """Return y such that softplus(y) == x, element-wise.

    Uses ``x + log(1 - exp(-x))``, which is algebraically the same as
    ``log(exp(x) - 1)`` but does not overflow to ``inf`` for large ``x``.

    Args:
        x: Tensor of values; only positive entries have a finite inverse
            (``0`` maps to ``-inf`` and negative entries to ``nan``).

    Returns:
        Tensor with the same shape as ``x``.
    """
    return x + torch.log(-torch.expm1(-x))
# Look up an activation module by its configuration name
def get_activation(name):
    """Build a fresh activation module for the given name.

    Recognised names are "leaky_relu" (negative slope 0.2), "relu" and "elu".

    Raises:
        ValueError: if ``name`` is not one of the recognised names.
    """
    factories = {
        "leaky_relu": lambda: nn.LeakyReLU(0.2),
        "relu": nn.ReLU,
        "elu": nn.ELU,
    }
    try:
        factory = factories[name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which simply match nothing.
        raise ValueError(f"Unsupported activation function: {name}") from None
    return factory()
def pack_images(images, rows, cols):
    """Tile a batch of channels-last images into a single grid image.

    The last three dimensions of ``images`` are read as (height, width,
    depth); all leading dimensions are flattened into one batch axis. Images
    are laid out row-major, and any images beyond ``rows * cols`` are dropped.

    Args:
        images: Tensor of shape ``(..., height, width, depth)``.
        rows: Maximum number of grid rows (clamped to the batch size).
        cols: Maximum number of grid columns (clamped to ``batch // rows``).

    Returns:
        Tensor of shape ``(1, rows * height, cols * width)`` when ``depth`` is
        1, otherwise ``(1, rows * height, cols * width, depth)``.

    Raises:
        ValueError: if the batch is empty or ``rows`` is smaller than 1.
    """
    height, width, depth = images.shape[-3], images.shape[-2], images.shape[-1]
    images = images.reshape(-1, height, width, depth)
    batch = images.shape[0]
    if batch == 0 or rows < 1:
        # Without this guard the column computation divides by zero.
        raise ValueError(
            f"pack_images needs a non-empty batch and rows >= 1 "
            f"(got batch={batch}, rows={rows})"
        )
    rows = min(rows, batch)
    cols = min(batch // rows, cols)
    images = images[: rows * cols]
    images = images.reshape(rows, cols, height, width, depth)
    # Interleave grid rows with pixel rows so that the reshape below places
    # the tiles side by side instead of concatenating them end to end.
    images = images.permute(0, 2, 1, 3, 4)
    images = images.reshape(1, rows * height, cols * width, depth)
    # squeeze(3) only removes the depth axis when depth == 1.
    return images.squeeze(3)
def show_image_grid(images, rows=8, cols=8, figsize=(10, 10)):
    """Display a batch of images as one grayscale grid in a new figure.

    Args:
        images: Tensor of images handed to ``pack_images``.
        rows: Maximum number of grid rows.
        cols: Maximum number of grid columns.
        figsize: Size of the matplotlib figure created for the grid.

    Side effects:
        Opens a new matplotlib figure and calls ``plt.show()``.
    """
    plt.figure(figsize=figsize)
    # detach() so that tensors still attached to an autograd graph can be
    # converted to NumPy.
    grid = pack_images(images, rows, cols)[0].detach().cpu().numpy()
    # Remove batch dimension if it exists
    if grid.ndim == 3 and grid.shape[0] == 1:
        grid = grid.squeeze(0)
    # Ensure we have a 2D array for imshow
    elif grid.ndim == 3:
        # If the 3D array does not start with a 3- or 4-sized axis, collapse
        # it to 2D by flattening everything but the last dimension.
        # NOTE(review): this fallback appears to exist for inputs that are
        # not channels-last; confirm against the callers.
        if grid.shape[0] not in [3, 4]:
            grid = grid.reshape(-1, grid.shape[-1])

    plt.imshow(grid, cmap="gray")
    plt.axis("off")
    plt.tight_layout()
    plt.show()
Encoder
# VAE Model Components
class Encoder(nn.Module):
    """Convolutional encoder producing the approximate posterior q(z|x).

    The network maps a single-channel image to ``2 * latent_size`` numbers:
    the first half is the mean and the second half an unconstrained scale
    parameter of a diagonal Gaussian.
    """

    def __init__(self, latent_size: int, base_depth: int, activation_name: str):
        """Build the encoder.

        Args:
            latent_size: Dimensionality of the latent code.
            base_depth: Channel count of the first convolutions; later layers
                use ``2 * base_depth``.
            activation_name: Name understood by ``get_activation``.
        """
        super().__init__()
        activation = get_activation(activation_name)

        self.latent_size = latent_size
        # Two stride-2 convolutions halve the spatial size twice; the 7x7
        # unpadded convolution then collapses a 7x7 map to 1x1, so Flatten
        # yields exactly 4 * latent_size features for a 28x28 input.
        self.encoder_net = nn.Sequential(
            nn.Conv2d(1, base_depth, kernel_size=5, stride=1, padding=2),
            activation,
            nn.Conv2d(base_depth, base_depth, kernel_size=5, stride=2, padding=2),
            activation,
            nn.Conv2d(base_depth, 2 * base_depth, kernel_size=5, stride=1, padding=2),
            activation,
            nn.Conv2d(
                2 * base_depth, 2 * base_depth, kernel_size=5, stride=2, padding=2
            ),
            activation,
            nn.Conv2d(
                2 * base_depth, 4 * latent_size, kernel_size=7, stride=1, padding=0
            ),
            activation,
            nn.Flatten(),
            nn.Linear(4 * latent_size, 2 * latent_size),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch with values in [0, 1]; the first convolution
                expects one channel in dimension 1.

        Returns:
            ``td.Independent`` diagonal Normal whose event dimension is the
            latent dimension.
        """
        # Scale images from [0, 1] to [-1, 1]
        x = 2 * x - 1
        net_output = self.encoder_net(x)

        # Split the output into the mean and an unconstrained scale. The
        # offset softplus^-1(1) makes a raw output of 0 correspond to a scale
        # of exactly 1; it is created with the network's dtype and device.
        mean = net_output[..., : self.latent_size]
        raw_scale = net_output[..., self.latent_size :] + _softplus_inverse(
            net_output.new_tensor(1.0)
        )
        scale = F.softplus(raw_scale)

        # Return the distribution
        return td.Independent(td.Normal(loc=mean, scale=scale), 1)
Decoder
class Decoder(nn.Module):
    """Transposed-convolution decoder producing the likelihood p(x|z)."""

    def __init__(
        self,
        latent_size: int,
        output_shape: List[int],
        base_depth: int,
        activation_name: str,
    ):
        """Build the decoder.

        Args:
            latent_size: Dimensionality of the latent code.
            output_shape: Event shape of the output distribution; its last
                entry is used as the number of output channels.
            base_depth: Channel count of the last layers; earlier layers use
                ``2 * base_depth``.
            activation_name: Name understood by ``get_activation``.
        """
        super().__init__()
        activation = get_activation(activation_name)

        self.latent_size = latent_size
        self.output_shape = output_shape
        # Spatial size per layer for a 1x1 input: 7, 7, 14, 14, 28, 28, 28.
        self.decoder_net = nn.Sequential(
            nn.ConvTranspose2d(
                latent_size, 2 * base_depth, kernel_size=7, stride=1, padding=0
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, 2 * base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth,
                2 * base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                2 * base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth,
                base_depth,
                kernel_size=5,
                stride=2,
                padding=2,
                output_padding=1,
            ),
            activation,
            nn.ConvTranspose2d(
                base_depth, base_depth, kernel_size=5, stride=1, padding=2
            ),
            activation,
            nn.Conv2d(base_depth, output_shape[-1], kernel_size=5, stride=1, padding=2),
        )

    def forward(self, z):
        """Decode latent codes into a Bernoulli distribution over images.

        Args:
            z: Tensor whose last dimension is ``latent_size``; all leading
                dimensions are kept as batch dimensions of the result.

        Returns:
            ``td.Independent`` Bernoulli with event shape ``output_shape``
            and batch shape ``z.shape[:-1]``.
        """
        # Remember original shape
        original_shape = z.shape

        # Reshape for convolutional decoder
        z = z.reshape(-1, self.latent_size, 1, 1)

        # Apply decoder network; the result is channels-first (N, C, H, W)
        logits = self.decoder_net(z)

        # Move channels last before reshaping so that the values land in
        # (H, W, C) order. With a single channel this is a no-op on the
        # values; with several channels a plain reshape would scramble them.
        logits = logits.permute(0, 2, 3, 1)

        # Reshape logits to match original dimensions + output shape
        new_shape = tuple(original_shape[:-1]) + tuple(self.output_shape)
        logits = logits.reshape(new_shape)

        # Return the distribution
        return td.Independent(td.Bernoulli(logits=logits), len(self.output_shape))
Prior
class MixturePrior(nn.Module):
    """Prior p(z): a standard normal or a learnable mixture of Gaussians.

    The class is an ``nn.Module`` so that the mixture parameters are picked
    up by ``parameters()`` of any module that holds the prior as an
    attribute, and are therefore seen by the optimizer.
    """

    def __init__(self, latent_size: int, mixture_components: int, device: torch.device):
        """Create the prior.

        Args:
            latent_size: Dimensionality of the latent code.
            mixture_components: Number of mixture components; 1 selects a
                fixed standard normal without learnable parameters.
            device: Device on which the prior's tensors are created.

        Raises:
            ValueError: if ``mixture_components`` is smaller than 1.
        """
        super().__init__()
        if mixture_components < 1:
            raise ValueError(
                f"mixture_components must be >= 1, got {mixture_components}"
            )
        self.latent_size = latent_size
        self.mixture_components = mixture_components
        self.device = device

        if mixture_components == 1:
            # Standard normal prior
            self.distribution = td.MultivariateNormal(
                loc=torch.zeros(latent_size, device=device),
                covariance_matrix=torch.eye(latent_size, device=device),
            )
        else:
            # Mixture of Gaussians prior
            self.locs = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            self.raw_scales = nn.Parameter(
                torch.randn(mixture_components, latent_size, device=device) * 0.1
            )
            self.mixture_logits = nn.Parameter(
                torch.zeros(mixture_components, device=device)
            )

            # Rebuilt from the current parameters before every use
            self.distribution = None
            self._create_distribution()

    def _create_distribution(self):
        """Rebuild ``self.distribution`` from the current mixture parameters."""
        if self.mixture_components == 1:
            return

        # Create component distributions (diagonal normal); softplus keeps
        # the scales positive.
        scales = F.softplus(self.raw_scales)
        component_distribution = td.Independent(
            td.Normal(loc=self.locs, scale=scales), 1
        )

        # Create mixture distribution
        mixture_distribution = td.Categorical(logits=self.mixture_logits)

        # Create the mixture
        self.distribution = td.MixtureSameFamily(
            mixture_distribution=mixture_distribution,
            component_distribution=component_distribution,
        )

    def log_prob(self, z):
        """Return log p(z) for latent codes whose last dim is ``latent_size``."""
        if self.mixture_components > 1:
            # Ensure distribution is up-to-date with current parameters
            self._create_distribution()
        return self.distribution.log_prob(z)

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + (latent_size,)``."""
        if self.mixture_components > 1:
            # Ensure distribution is up-to-date with current parameters
            self._create_distribution()
        return self.distribution.sample(sample_shape)
VAE
class VAE(nn.Module):
    """Variational autoencoder combining an encoder, a decoder and a prior."""

    def __init__(
        self,
        latent_size: int,
        base_depth: int,
        activation_name: str,
        mixture_components: int,
        analytic_kl: bool,
        n_samples: int,
    ):
        """Build the model.

        Args:
            latent_size: Dimensionality of the latent code.
            base_depth: Base channel count for encoder and decoder.
            activation_name: Name understood by ``get_activation``.
            mixture_components: Number of components of the prior.
            analytic_kl: Use the closed-form KL against a standard normal.
            n_samples: Number of posterior samples drawn per input.

        Raises:
            NotImplementedError: if ``analytic_kl`` is requested together
                with more than one mixture component.
        """
        super().__init__()

        if analytic_kl and mixture_components != 1:
            raise NotImplementedError(
                "Using analytic_kl is only supported when mixture_components = 1 "
                "since there's no closed form otherwise."
            )

        self.latent_size = latent_size
        self.analytic_kl = analytic_kl
        self.n_samples = n_samples

        # Create encoder, decoder, and prior
        self.encoder = Encoder(latent_size, base_depth, activation_name)
        self.decoder = Decoder(latent_size, IMAGE_SHAPE, base_depth, activation_name)
        self.prior = MixturePrior(latent_size, mixture_components, device=device)

        # If the prior is a plain object rather than an nn.Module, its
        # nn.Parameter attributes would be invisible to self.parameters() and
        # never optimized; register them here in that case.
        if not isinstance(self.prior, nn.Module):
            for attr in ("locs", "raw_scales", "mixture_logits"):
                param = getattr(self.prior, attr, None)
                if isinstance(param, nn.Parameter):
                    self.register_parameter(f"prior_{attr}", param)

    def forward(self, x):
        """Compute ELBO-related quantities for a batch of images.

        Args:
            x: Batch of binary images. A 4-D batch whose trailing dimensions
                do not match the decoder's event shape is treated as
                channels-first and converted to channels-last.

        Returns:
            Dict with scalar tensors "elbo", "elbo_iwae", "rate" and
            "distortion", plus the distributions "approx_posterior" and
            "decoder_likelihood".

        Raises:
            ValueError: if ``x`` cannot be matched to the decoder's event
                shape.
        """
        # Encode input to get approximate posterior q(z|x)
        approx_posterior = self.encoder(x)

        # Sample from the posterior; result is (n_samples, batch, latent)
        approx_posterior_sample = approx_posterior.rsample(
            torch.Size([self.n_samples])
        )

        # Decode the samples to get p(x|z)
        decoder_likelihood = self.decoder(approx_posterior_sample)

        # Align the target with the decoder's event layout. Passing a
        # channels-first batch straight to log_prob does not fail: it
        # broadcasts against the channels-last event shape and silently pairs
        # the wrong pixels.
        event_shape = tuple(decoder_likelihood.event_shape)
        target = x
        if x.dim() == 4 and tuple(x.shape[1:]) != event_shape:
            target = x.permute(0, 2, 3, 1)
        if tuple(target.shape[-len(event_shape):]) != event_shape:
            raise ValueError(
                f"Input of shape {tuple(x.shape)} does not match decoder event "
                f"shape {event_shape}"
            )

        # Calculate distortion (negative log likelihood)
        distortion = -decoder_likelihood.log_prob(target)

        # Calculate rate (KL divergence)
        if self.analytic_kl:
            # For standard normal prior we can use a closed form
            mean = approx_posterior.mean
            var = approx_posterior.variance
            rate = 0.5 * torch.sum(mean.pow(2) + var - 1 - torch.log(var), dim=-1)
        else:
            # Monte Carlo estimate of KL
            rate = approx_posterior.log_prob(
                approx_posterior_sample
            ) - self.prior.log_prob(approx_posterior_sample)

        # ELBO = -rate - distortion
        elbo_local = -(rate + distortion)

        # Compute IWAE bound: log-mean-exp over the sample dimension. The
        # constant is built on elbo_local's own device and dtype.
        log_weight_mean = torch.logsumexp(elbo_local, dim=0) - torch.log(
            torch.tensor(
                float(self.n_samples),
                dtype=elbo_local.dtype,
                device=elbo_local.device,
            )
        )

        # Return various quantities for training and evaluation
        return {
            "elbo": elbo_local.mean(),
            "elbo_iwae": log_weight_mean.mean(),
            "rate": rate.mean(),
            "distortion": distortion.mean(),
            "approx_posterior": approx_posterior,
            "decoder_likelihood": decoder_likelihood,
        }
Binarized MNIST
# Dataset loading functions
ROOT_PATH = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/"
FILE_TEMPLATE = "binarized_mnist_{split}.amat"


def download(directory, filename):
    """Download ``filename`` from ``ROOT_PATH`` unless it is already cached.

    Args:
        directory: Local directory that holds the cached file.
        filename: Name of the file, both remotely and locally.

    Returns:
        Path of the local file.
    """
    filepath = os.path.join(directory, filename)
    if os.path.exists(filepath):
        return filepath
    if directory:
        os.makedirs(directory, exist_ok=True)
    # URLs always use forward slashes, so join by hand, not with os.path.join.
    url = ROOT_PATH.rstrip("/") + "/" + filename
    print(f"Downloading {url} to {filepath}")
    # Download to a temporary name first: an interrupted transfer must not
    # leave a truncated file that the existence check above would accept.
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filepath)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return filepath
class BinarizedMNIST(Dataset):
    """Binarized MNIST dataset read from a whitespace-separated ``.amat`` file.

    Each item is a ``(image, label)`` pair where the image is a float tensor
    of shape (1, 28, 28) and the label is a constant dummy ``0``.
    """

    def __init__(self, directory, split_name):
        """Download (if needed) and load one split.

        Args:
            directory: Local cache directory passed to ``download``.
            split_name: Value substituted for ``{split}`` in ``FILE_TEMPLATE``.
        """
        super().__init__()
        # Download and load the data
        amat_file = download(directory, FILE_TEMPLATE.format(split=split_name))
        self.data = self._load_amat(amat_file)

    @staticmethod
    def _load_amat(path):
        """Parse an ``.amat`` file into a (N, 1, 28, 28) float tensor.

        Every non-blank line is expected to hold 784 pixel values. Parsing
        goes straight into a NumPy array rather than nested Python lists,
        which keeps memory use close to the size of the final tensor.
        """
        pixels = np.loadtxt(path, dtype=np.float32, ndmin=2)
        # Rows are flattened 28x28 images; add the channel axis last, then
        # move it to the front to obtain channels-first images.
        return (
            torch.from_numpy(pixels)
            .reshape(-1, 28, 28, 1)
            .permute(0, 3, 1, 2)
            .contiguous()
        )

    def __len__(self):
        """Return the number of images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, dummy_label)`` for index ``idx``."""
        return self.data[idx], torch.tensor(0, dtype=torch.long)  # dummy label
def build_fake_data(batch_size):
    """Build loaders over random binary MNIST-shaped images for smoke tests.

    The pixels are drawn from {0, 1} rather than uniformly from [0, 1): a
    Bernoulli likelihood only accepts binary values, and
    ``torch.distributions`` rejects anything else while argument validation
    is enabled.

    Args:
        batch_size: Number of fake images; also used as the loader batch
            size, so each loader yields a single batch.

    Returns:
        ``(train_loader, eval_loader)`` over the same fake images with
        all-zero labels.
    """
    random_sample = torch.randint(0, 2, (batch_size, 1, 28, 28)).to(torch.float)
    labels = torch.zeros(batch_size, dtype=torch.long)

    train_dataset = TensorDataset(random_sample, labels)
    eval_dataset = TensorDataset(random_sample, labels)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_loader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, eval_loader
def build_dataloaders(data_dir, batch_size):
    """Build data loaders for the "train" and "valid" splits.

    Args:
        data_dir: Directory in which the dataset files are cached.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, valid_loader)``. The training loader shuffles and
        drops the final incomplete batch; the validation loader does neither.
    """
    train_dataset = BinarizedMNIST(data_dir, "train")
    valid_dataset = BinarizedMNIST(data_dir, "valid")

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
    )
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader
Trainer
# Training and evaluation
def train(model, train_loader, optimizer, epoch, writer=None):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module whose forward pass returns a dict containing "elbo".
        train_loader: Sized iterable of ``(data, label)`` batches.
        optimizer: Optimizer over the model's parameters.
        epoch: Epoch index, used as the step when logging.
        writer: Optional object with ``add_scalar`` (e.g. a SummaryWriter).

    Returns:
        Mean negative ELBO over the batches of this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty; cannot train for an epoch")

    model.train()
    # Place each batch where the model lives instead of relying on a global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    train_loss = 0.0
    for data, _ in train_loader:
        if target_device is not None:
            data = data.to(target_device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(data)
        loss = -outputs["elbo"]  # Negative ELBO is the loss

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Track loss
        train_loss += loss.item()

    avg_loss = train_loss / num_batches

    if writer is not None:
        writer.add_scalar("Loss/train", avg_loss, epoch)

    return avg_loss
def evaluate(model, eval_loader, epoch, writer=None):
    """Evaluate ``model`` on ``eval_loader`` without tracking gradients.

    Args:
        model: Module whose forward pass returns a dict containing scalar
            tensors "elbo", "elbo_iwae", "rate" and "distortion".
        eval_loader: Sized iterable of ``(data, label)`` batches.
        epoch: Epoch index, used as the step when logging.
        writer: Optional object with ``add_scalar`` (e.g. a SummaryWriter).

    Returns:
        ``(metrics, eval_data, eval_outputs)``: the per-batch averages of the
        four metrics, plus the first batch and the model outputs for it (kept
        for visualization).

    Raises:
        ValueError: if ``eval_loader`` yields no batches.
    """
    num_batches = len(eval_loader)
    if num_batches == 0:
        raise ValueError("eval_loader is empty; nothing to evaluate")

    model.eval()
    metrics = {"elbo": 0, "elbo_iwae": 0, "rate": 0, "distortion": 0}
    # Place each batch where the model lives instead of relying on a global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    eval_data = None
    eval_outputs = None
    with torch.no_grad():
        for data, _ in eval_loader:
            if target_device is not None:
                data = data.to(target_device)
            outputs = model(data)

            for k in metrics:
                metrics[k] += outputs[k].item()

            # Only need the first batch for visualization
            if eval_data is None:
                eval_data = data
                eval_outputs = outputs

    # Average metrics
    for k in metrics:
        metrics[k] /= num_batches

    if writer is not None:
        for k, v in metrics.items():
            writer.add_scalar(f"Metrics/{k}", v, epoch)

    return metrics, eval_data, eval_outputs
def visualize(model, data, outputs, epoch, writer=None):
    """Collect inputs, reconstructions and prior samples for display.

    Args:
        model: Object exposing ``prior.sample`` and ``decoder``; the decoder
            must return a distribution with a ``mean``.
        data: Batch of input images. A 4-D batch with a single channel in
            dimension 1 is converted to channels-last so that all three
            returned tensors share the layout ``pack_images`` tiles.
        outputs: Model outputs for ``data``; "decoder_likelihood" is read.
        epoch: Epoch index, used as the step when logging images.
        writer: Optional object with ``add_image`` (e.g. a SummaryWriter).

    Returns:
        Dict with "input" (up to 16 inputs), "recon" (decoder means for up
        to 3 posterior samples of up to 16 inputs) and "random" (decoder
        means for 16 prior samples).
    """
    with torch.no_grad():
        inputs = data[:16]
        if inputs.dim() == 4 and inputs.shape[1] == 1:
            # Channels-first -> channels-last. Tiling the channels-first
            # tensor directly stacks the digits into one vertical strip.
            inputs = inputs.permute(0, 2, 3, 1)

        # Get reconstructions
        recon_samples = outputs["decoder_likelihood"].mean[:3, :16]

        # Generate random samples from prior
        random_z = model.prior.sample(torch.Size([16]))
        random_samples = model.decoder(random_z).mean

        if writer is not None:
            # pack_images returns (1, H, W) for single-channel input; [0]
            # yields the (H, W) grid expected by dataformats='HW'.
            input_grid = pack_images(inputs, 1, 16)[0]
            writer.add_image("input", input_grid, epoch, dataformats='HW')

            recon_grid = pack_images(recon_samples, 3, 16)[0]
            writer.add_image("recon/mean", recon_grid, epoch, dataformats='HW')

            random_grid = pack_images(random_samples, 4, 4)[0]
            writer.add_image("random/mean", random_grid, epoch, dataformats='HW')

    return {"input": inputs, "recon": recon_samples, "random": random_samples}
Application
def get_scheduler(optimizer, max_steps):
    """Return a cosine-annealing LR scheduler spanning ``max_steps`` steps."""
    scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR
    return scheduler_cls(optimizer, T_max=max_steps)
def main():
    """Train the VAE with the module-level ``args`` and show progress.

    Builds the data loaders, model, optimizer, scheduler and TensorBoard
    writer, then alternates training epochs with periodic evaluation,
    logging and plotting.
    """
    # Set parameters
    params = vars(args)

    # Create directories
    if params["delete_existing"] and os.path.exists(params["model_dir"]):
        print(f"Deleting old log directory at {params['model_dir']}")
        import shutil

        shutil.rmtree(params["model_dir"])
    os.makedirs(params["model_dir"], exist_ok=True)

    # Setup data loaders
    if params["fake_data"]:
        train_loader, eval_loader = build_fake_data(params["batch_size"])
    else:
        train_loader, eval_loader = build_dataloaders(
            params["data_dir"], params["batch_size"]
        )

    # Create model
    model = VAE(
        latent_size=params["latent_size"],
        base_depth=params["base_depth"],
        activation_name=params["activation"],
        mixture_components=params["mixture_components"],
        analytic_kl=params["analytic_kl"],
        n_samples=params["n_samples"],
    ).to(device)

    # Convert the step budgets into whole epochs
    epochs = params["max_steps"] // len(train_loader) + 1
    viz_epochs = params["viz_steps"] // len(train_loader) + 1

    # Setup optimizer and scheduler. The scheduler is stepped once per epoch
    # below, so its horizon must be counted in epochs; with a horizon of
    # max_steps the learning rate would barely decay during the run.
    optimizer = optim.Adam(model.parameters(), lr=params["learning_rate"])
    scheduler = get_scheduler(optimizer, epochs)

    # Setup TensorBoard writer
    writer = SummaryWriter(log_dir=params["model_dir"])

    # Training loop; the writer is closed even if training is interrupted.
    try:
        for epoch in range(epochs):
            # Train for one epoch
            train_loss = train(model, train_loader, optimizer, epoch, writer)
            scheduler.step()

            # Evaluate and visualize at regular intervals
            if epoch % viz_epochs == 0 or epoch == epochs - 1:
                metrics, eval_data, eval_outputs = evaluate(
                    model, eval_loader, epoch, writer
                )
                vis = visualize(model, eval_data, eval_outputs, epoch, writer)

                # Display metrics
                print(
                    f"Epoch {epoch}: Train loss = {train_loss:.4f}, ELBO = {metrics['elbo']:.4f}, "
                    f"IWAE ELBO = {metrics['elbo_iwae']:.4f}, Rate = {metrics['rate']:.4f}, "
                    f"Distortion = {metrics['distortion']:.4f}"
                )

                # Display visualizations
                # NOTE(review): show_image_grid opens its own figure, so the
                # subplot titles set here presumably end up on a separate,
                # empty figure - confirm the intended layout.
                plt.figure(figsize=(12, 4))
                plt.subplot(131)
                plt.title("Original")
                show_image_grid(vis["input"], rows=2, cols=8, figsize=(6, 2))

                plt.subplot(132)
                plt.title("Reconstruction")
                show_image_grid(vis["recon"], rows=3, cols=5, figsize=(6, 3))

                plt.subplot(133)
                plt.title("Generated")
                show_image_grid(vis["random"], rows=4, cols=4, figsize=(4, 4))
    finally:
        writer.close()
    print("Training completed!")


if __name__ == "__main__":
    main()
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_train.amat to /tmp/vae/data/binarized_mnist_train.amat
Downloading http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_valid.amat to /tmp/vae/data/binarized_mnist_valid.amat
--------------------------------------------------------------------------- KeyboardInterrupt Traceback (most recent call last) Cell In[11], line 2 1 if __name__ == "__main__": ----> 2 main() Cell In[10], line 51, in main() 47 viz_epochs = params["viz_steps"] // len(train_loader) + 1 49 for epoch in range(epochs): 50 # Train for one epoch ---> 51 train_loss = train(model, train_loader, optimizer, epoch, writer) 52 scheduler.step() 54 # Evaluate and visualize at regular intervals Cell In[9], line 12, in train(model, train_loader, optimizer, epoch, writer) 9 optimizer.zero_grad() 11 # Forward pass ---> 12 outputs = model(data) 13 loss = -outputs["elbo"] # Negative ELBO is the loss 15 # Backward pass and optimize File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() Cell In[7], line 39, in VAE.forward(self, x) 36 approx_posterior_sample = approx_posterior.rsample([self.n_samples]) 38 # Decode the samples to get p(x|z) ---> 39 decoder_likelihood = self.decoder(approx_posterior_sample) 41 # Calculate distortion (negative log likelihood) 42 distortion = -decoder_likelihood.log_prob(x) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() Cell In[5], line 64, in Decoder.forward(self, z) 61 z = z.reshape(-1, self.latent_size, 1, 1) 63 # Apply decoder network ---> 64 logits = self.decoder_net(z) 66 # Reshape logits to match original dimensions + output shape 67 new_shape = original_shape[:-1] + tuple(self.output_shape) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/container.py:250, in Sequential.forward(self, input) 248 def forward(self, input): 249 for module in self: --> 250 input = module(input) 251 return input File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs) 1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1738 else: -> 1739 return self._call_impl(*args, **kwargs) File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs) 1745 # If we don't have any hooks, we want to skip the rest of the logic in 1746 # this function, and just call forward. 1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1748 or _global_backward_pre_hooks or _global_backward_hooks 1749 or _global_forward_hooks or _global_forward_pre_hooks): -> 1750 return forward_call(*args, **kwargs) 1752 result = None 1753 called_always_called_hooks = set() File /workspaces/engineering-ai-agents/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:1162, in ConvTranspose2d.forward(self, input, output_size) 1151 num_spatial_dims = 2 1152 output_padding = self._output_padding( 1153 input, 1154 output_size, (...) 
1159 self.dilation, # type: ignore[arg-type] 1160 ) -> 1162 return F.conv_transpose2d( 1163 input, 1164 self.weight, 1165 self.bias, 1166 self.stride, 1167 self.padding, 1168 output_padding, 1169 self.groups, 1170 self.dilation, 1171 ) KeyboardInterrupt: