Estimating Marginal Likelihood with Naive Monte Carlo Sampling

The goal of this section is to illustrate why estimating the marginal likelihood $p(x)$ via naive Monte Carlo sampling is computationally difficult in high-dimensional spaces like images, especially when using a decoder-only approach (e.g., a VAE's decoder) without proper posterior inference.

The key steps are as follows:

  1. We train only a simple decoder neural network (from latent $z$ to image $x$) to mimic $p(x \mid z)$.

  2. Sample $z \sim \mathcal{N}(0, I)$, decode it to $x$, and compute the approximate marginal likelihood estimate using many $z$ samples.

This works very poorly without smart latent sampling (e.g., posterior inference via an encoder, as in a VAE). When we sample blindly from the prior $\mathcal{N}(0, I)$, very few $z$ will decode to high-quality digits; most yield near-random noise, resulting in negligible $p(x \mid z)$. The estimator being approximated is written out below.
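Concretely, the naive Monte Carlo estimator replaces the integral over the latent space with an average over prior samples:

$$
p(x) = \int p(x \mid z)\, p(z)\, dz \;\approx\; \frac{1}{m} \sum_{i=1}^{m} p(x \mid z_i), \qquad z_i \sim \mathcal{N}(0, I).
$$

Because $p(x \mid z_i)$ is tiny for almost every prior sample, the average is dominated by a handful of lucky draws, which is exactly what makes the estimate high-variance.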

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import numpy as np

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
mnist = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
data_loader = DataLoader(mnist, batch_size=64, shuffle=True)

# A simple decoder network (like VAE's decoder)
class Decoder(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid()  # image pixels between 0 and 1
        )
    
    def forward(self, z):
        x = self.model(z)
        return x.view(-1, 1, 28, 28)
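
As a quick interface check (purely illustrative: this decoder is untrained, and the batch size of 8 is an arbitrary choice), latents drawn from the prior can be decoded like so:

# Illustrative only: decode a few prior samples with an untrained decoder
demo_decoder = Decoder(latent_dim=20).to(device)
z = torch.randn(8, 20, device=device)   # z ~ N(0, I)
samples = demo_decoder(z)               # shape: (8, 1, 28, 28)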

# Training a dummy decoder (from encoder features)
class Autoencoder(nn.Module):
    def __init__(self, latent_dim=20):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)
        )
        self.decoder = Decoder(latent_dim)
    
    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

# Train autoencoder to use decoder later
def train_autoencoder():
    model = Autoencoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(5):
        for imgs, _ in data_loader:
            imgs = imgs.to(device)
            output = model(imgs)
            loss = criterion(output, imgs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}: Loss = {loss.item():.4f}")
    
    return model.decoder

# Estimate p(x) ≈ 1/m ∑ p(x | z_i) using L2 distance as proxy
def monte_carlo_marginal_likelihood(decoder, x_target, num_samples=10000, latent_dim=20):
    # Draw latents from the prior and decode them; no gradients are needed for estimation
    with torch.no_grad():
        z_samples = torch.randn(num_samples, latent_dim, device=device)
        x_decoded = decoder(z_samples).view(num_samples, -1)

    # Flatten the target image
    x_target_flat = x_target.view(1, -1).to(device)

    # Using negative squared error as a proxy log likelihood (not true p(x|z))
    mse = torch.sum((x_decoded - x_target_flat) ** 2, dim=1)
    log_likelihood_proxy = -mse

    # LogSumExp trick for numerical stability:
    # log( (1/m) * sum_i exp(l_i) ) = max(l) + log( (1/m) * sum_i exp(l_i - max(l)) )
    max_ll = torch.max(log_likelihood_proxy)
    p_x_est = max_ll + torch.log(torch.mean(torch.exp(log_likelihood_proxy - max_ll)))

    return p_x_est.item()
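
The negative squared error above is only a stand-in for the true likelihood. Under an assumed isotropic Gaussian observation model $p(x \mid z) = \mathcal{N}(x;\ \mathrm{decoder}(z),\ \sigma^2 I)$, where $\sigma$ is a free modelling choice rather than something fitted by the code above, the per-sample log-likelihood would look like this sketch:

# Sketch (assumed Gaussian observation model): log N(x; decoder(z), sigma^2 I)
def gaussian_log_likelihood(x_decoded, x_target_flat, sigma=0.1):
    d = x_target_flat.shape[1]                                   # 784 for flattened MNIST
    sq_err = torch.sum((x_decoded - x_target_flat) ** 2, dim=1)  # per-sample squared error
    return -0.5 * sq_err / sigma ** 2 - 0.5 * d * np.log(2 * np.pi * sigma ** 2)

Substituting this for log_likelihood_proxy would make the LogSumExp average an actual (if still extremely high-variance) estimate of $\log p(x)$, up to the choice of $\sigma$.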

# Main
if __name__ == "__main__":
    print("Training autoencoder (to get decoder)...")
    decoder = train_autoencoder()

    # Pick one test image to evaluate
    test_img, _ = mnist[0]
    plt.imshow(test_img.squeeze(), cmap='gray')
    plt.title("Target Image for p(x) Estimation")
    plt.axis('off')
    plt.show()

    print("Estimating p(x) via naive Monte Carlo...")
    px_log_est = monte_carlo_marginal_likelihood(decoder, test_img, num_samples=10000)
    print(f"Estimated log p(x) ≈ {px_log_est:.2f}")
Sample output:

Training autoencoder (to get decoder)...
Epoch 1: Loss = 0.0137
Epoch 2: Loss = 0.0151
Epoch 3: Loss = 0.0115
Epoch 4: Loss = 0.0103
Epoch 5: Loss = 0.0089

Estimating p(x) via naive Monte Carlo...
Estimated log p(x) ≈ -58.12
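
For contrast, the "smart latent sampling" mentioned earlier amounts to importance sampling with an encoder-defined proposal $q(z \mid x)$, which is essentially the estimator a VAE enables:

$$
p(x) = \int \frac{p(x \mid z)\, p(z)}{q(z \mid x)}\, q(z \mid x)\, dz \;\approx\; \frac{1}{m} \sum_{i=1}^{m} \frac{p(x \mid z_i)\, p(z_i)}{q(z_i \mid x)}, \qquad z_i \sim q(z \mid x).
$$

Because $q(z \mid x)$ concentrates samples where the decoder actually reconstructs $x$ well, far fewer draws are wasted, avoiding the failure mode of the naive estimate above.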