The goal is to illustrate why estimating the marginal likelihood $p(\mathbf{x})$ via naive Monte Carlo sampling is computationally difficult in high-dimensional spaces such as images, especially when using a decoder-only approach (e.g., a VAE's decoder) without proper posterior inference.
The key steps are as follows:
1. Train a simple decoder network (mapping a latent $\mathbf{z}$ to an image $\mathbf{x}$) to approximate $p(\mathbf{x} \mid \mathbf{z})$.
2. Sample $\mathbf{z} \sim \mathcal{N}(0, I)$, decode it to $\hat{\mathbf{x}}$, and average $p(\mathbf{x} \mid \mathbf{z})$ over many $\mathbf{z}$ samples to obtain an approximate marginal likelihood.

This works very poorly without smart latent sampling (e.g., posterior inference via an encoder, as in a VAE). When we sample blindly from the prior $\mathcal{N}(0, I)$, very few $\mathbf{z}$ values decode to anything resembling the target digit; most yield near-random noise, so $p(\mathbf{x} \mid \mathbf{z})$ is negligible for almost every sample.
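Formally, the quantity the code below approximates is the prior-sampled Monte Carlo estimator (evaluated in log space in the implementation):

$$
p(\mathbf{x}) = \int p(\mathbf{x} \mid \mathbf{z})\, p(\mathbf{z})\, d\mathbf{z} \;\approx\; \frac{1}{m} \sum_{i=1}^{m} p(\mathbf{x} \mid \mathbf{z}_i), \qquad \mathbf{z}_i \sim \mathcal{N}(0, I).
$$

The estimator is unbiased, but its variance is enormous in high dimensions because almost all prior samples contribute essentially zero probability mass.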
```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(mnist, batch_size=64, shuffle=True)

# A simple decoder network (like VAE's decoder)
class Decoder(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()  # image pixels between 0 and 1
        )

    def forward(self, z):
        x = self.model(z)
        return x.view(-1, 1, 28, 28)

# Training a dummy decoder (from encoder features)
class Autoencoder(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)
        )
        self.decoder = Decoder(latent_dim)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

# Train autoencoder to use decoder later
def train_autoencoder():
    model = Autoencoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    for epoch in range(5):
        for imgs, _ in data_loader:
            imgs = imgs.to(device)
            output = model(imgs)
            loss = criterion(output, imgs)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch + 1}: Loss = {loss.item():.4f}")
    return model.decoder

# Estimate p(x) ≈ 1/m ∑ p(x | z_i) using L2 distance as proxy
def monte_carlo_marginal_likelihood(decoder, x_target, num_samples=10000, latent_dim=20):
    z_samples = torch.randn(num_samples, latent_dim).to(device)
    x_decoded = decoder(z_samples).view(num_samples, -1)

    # Flatten the target image
    x_target_flat = x_target.view(1, -1).to(device)

    # Using negative squared error as a proxy log likelihood (not true p(x|z))
    mse = torch.sum((x_decoded - x_target_flat) ** 2, dim=1)
    log_likelihood_proxy = -mse

    # LogSumExp trick for stability
    max_ll = torch.max(log_likelihood_proxy)
    p_x_est = max_ll + torch.log(torch.mean(torch.exp(log_likelihood_proxy - max_ll)))
    return p_x_est.item()

# Main
if __name__ == "__main__":
    print("Training autoencoder (to get decoder)...")
    decoder = train_autoencoder()

    # Pick one test image to evaluate
    test_img, _ = mnist[0]
    plt.imshow(test_img.squeeze(), cmap='gray')
    plt.title("Target Image for p(x) Estimation")
    plt.axis('off')
    plt.show()

    print("Estimating p(x) via naive Monte Carlo...")
    px_log_est = monte_carlo_marginal_likelihood(decoder, test_img, num_samples=10000)
    print(f"Estimated log p(x) ≈ {px_log_est:.2f}")
```
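For numerical stability, `monte_carlo_marginal_likelihood` averages the proxy likelihoods in log space using the log-sum-exp trick: with $\ell_i = -\lVert \mathbf{x} - \hat{\mathbf{x}}_i \rVert^2$ and $\ell_{\max} = \max_i \ell_i$,

$$
\log \frac{1}{m} \sum_{i=1}^{m} e^{\ell_i} \;=\; \ell_{\max} + \log \frac{1}{m} \sum_{i=1}^{m} e^{\ell_i - \ell_{\max}},
$$

which avoids exponentiating large negative numbers directly. Note that the squared-error proxy is not a normalized likelihood, so the reported value is only comparable within this setup, not an absolute $\log p(\mathbf{x})$. A sample run produces: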
Training autoencoder (to get decoder)...
Epoch 1: Loss = 0.0137
Epoch 2: Loss = 0.0151
Epoch 3: Loss = 0.0115
Epoch 4: Loss = 0.0103
Epoch 5: Loss = 0.0089
Estimating p(x) via naive Monte Carlo...
Estimated log p(x) ≈ -58.12
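To see how unreliable blind prior sampling is, one can repeat the estimator with different random seeds. Below is a minimal sketch (not part of the original script) that reuses `decoder`, `test_img`, and `monte_carlo_marginal_likelihood` from above; because the estimate is dominated by the few "lucky" prior samples that land near the target digit, it typically fluctuates noticeably between repetitions even at 10,000 samples.

```python
# Hypothetical follow-up: repeat the naive estimator with different seeds to
# observe its run-to-run variability. Assumes decoder, test_img, and
# monte_carlo_marginal_likelihood are already defined as in the script above.
for seed in range(3):
    torch.manual_seed(seed)  # controls the z ~ N(0, I) draws inside the estimator
    est = monte_carlo_marginal_likelihood(decoder, test_img, num_samples=10000)
    print(f"seed={seed}: proxy log p(x) ≈ {est:.2f}")
```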