import torch
import torch.nn as nn
import torch.distributions as D
# ----------------------------
# Utilities
# ----------------------------
def rank_uniform(x, eps=1e-3):
    """
    Probability integral transform via empirical CDF (pseudo-observations).
    Returns U in (0,1); eps shrinks away from 0/1 for stability.
    """
    n = x.shape[0]
    ranks = torch.argsort(torch.argsort(x, dim=0), dim=0).float() + 1.0
    u = ranks / (n + 1.0)
    return u.clamp(eps, 1 - eps)
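# Worked example (values computed by hand, not executed at import):
#   rank_uniform(torch.tensor([[3.0], [1.0], [2.0]]))
#   -> tensor([[0.75], [0.25], [0.50]]), i.e. ranks 3, 1, 2 divided by n + 1 = 4.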
def inv_std_normal(u):
    """Phi^{-1}(u) using torch.erfinv."""
    return torch.sqrt(torch.tensor(2.0, device=u.device)) * torch.erfinv(2 * u - 1)
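# Sanity anchors: Phi^{-1}(0.5) = 0 and Phi^{-1}(0.975) ~= 1.96, the usual standard
# normal quantiles; handy when eyeballing the transform during debugging.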
def tril_param_to_corr(L_unconstrained, jitter=1e-5):
    """
    Map unconstrained lower-triangular params to a valid correlation matrix.
    We build a covariance via L @ L^T, then scale to correlation.
    The diagonal of L goes through exp() to keep it positive (hence Sigma PD);
    the result is symmetric, PD, with unit diagonal (up to jitter).
    """
    d = L_unconstrained.size(-1)
    # strictly-lower triangle as-is, exp() on the diagonal for positivity
    L = torch.tril(L_unconstrained, diagonal=-1)
    diag = torch.exp(torch.diagonal(L_unconstrained, dim1=-2, dim2=-1))
    L = L + torch.diag_embed(diag)
    Sigma = L @ L.transpose(-1, -2)
    # scale covariance to correlation: R = D^{-1/2} Sigma D^{-1/2}
    std = torch.sqrt(torch.diagonal(Sigma, dim1=-2, dim2=-1) + jitter)
    Dinv = torch.diag_embed(1.0 / std)
    R = Dinv @ Sigma @ Dinv
    # enforce symmetry numerically
    R = 0.5 * (R + R.transpose(-1, -2))
    # add tiny diagonal jitter for stability
    R = R + torch.eye(d, device=R.device) * jitter
    return R
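# Quick property check (a sketch; torch.randn(4, 4) is just an arbitrary test input):
#   R = tril_param_to_corr(torch.randn(4, 4))
#   torch.allclose(torch.diagonal(R), torch.ones(4), atol=1e-3)  # unit diagonal
#   torch.linalg.cholesky(R)                                     # succeeds iff R is PD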
# ----------------------------
# Parametric marginals (Normals by default)
# ----------------------------
class NormalMarginals(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(d))
        self.log_sigma = nn.Parameter(torch.zeros(d))
    def forward(self, x):
        """Return log f_i(x_i) and u_i = F_i(x_i)."""
        sigma = torch.exp(self.log_sigma)
        dist = D.Normal(self.mu, sigma)
        logpdf = dist.log_prob(x)              # (n, d)
        u = dist.cdf(x).clamp(1e-6, 1 - 1e-6)  # (n, d)
        return logpdf, u
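# A minimal sketch of an alternative marginal family (uses torch.distributions.Laplace,
# which implements both log_prob and cdf). Any module returning elementwise
# (log f_i(x_i), F_i(x_i)) can be passed as `marginals` to CopulaLikelihood below.
class LaplaceMarginals(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.loc = nn.Parameter(torch.zeros(d))
        self.log_scale = nn.Parameter(torch.zeros(d))
    def forward(self, x):
        dist = D.Laplace(self.loc, torch.exp(self.log_scale))
        return dist.log_prob(x), dist.cdf(x).clamp(1e-6, 1 - 1e-6)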
# ----------------------------
# Gaussian Copula Module
# ----------------------------
class GaussianCopula(nn.Module):
    def __init__(self, d):
        super().__init__()
        # Unconstrained lower-triangular parameter used to build the correlation matrix.
        # Initialize the diagonal to -1 so exp(-1) ~ 0.37, a modest scale for stability.
        L = -torch.eye(d)
        self.L_unconstrained = nn.Parameter(L)
    def corr(self):
        return tril_param_to_corr(self.L_unconstrained)
    def log_copula_density(self, u):
        """
        u: (n, d) with entries in (0,1).
        c(u) = |R|^{-1/2} * exp(-0.5 * z^T (R^{-1} - I) z), where z = Phi^{-1}(u).
        """
        z = inv_std_normal(u)  # (n, d)
        R = self.corr()        # (d, d)
        # Cholesky factor for the log-determinant and the linear solve
        L = torch.linalg.cholesky(R)  # R = L L^T
        logdetR = 2.0 * torch.log(torch.diagonal(L)).sum()
        # y = R^{-1} z via the Cholesky factor
        y = torch.cholesky_solve(z.unsqueeze(-1), L).squeeze(-1)  # (n, d)
        quad = (z * y).sum(-1)    # z^T R^{-1} z
        quad_I = (z * z).sum(-1)  # z^T I z
        logc = -0.5 * logdetR - 0.5 * (quad - quad_I)
        return logc  # (n,)
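    # Note: with R = I both quadratic forms coincide and logdetR = 0, so the log copula
    # density is identically zero -- the independence copula, as expected.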
# ----------------------------
# Full model: marginals + copula
# ----------------------------
class CopulaLikelihood(nn.Module):
    def __init__(self, d, marginals=None, use_pseudo=False):
        super().__init__()
        self.d = d
        self.copula = GaussianCopula(d)
        self.marginals = marginals if marginals is not None else NormalMarginals(d)
        self.use_pseudo = use_pseudo  # if True, use empirical CDFs for u and ignore the marginal logpdf
    def forward(self, x):
        """
        x: (n, d)
        Returns the total log-likelihood over the n samples, plus per-sample terms.
        """
        if self.use_pseudo:
            u = rank_uniform(x)  # empirical PIT (pseudo-observations)
            logc = self.copula.log_copula_density(u)
            logm = torch.zeros_like(logc)  # marginal term ignored in pseudo-MLE
        else:
            logm, u = self.marginals(x)  # (n, d), (n, d)
            logc = self.copula.log_copula_density(u)
            logm = logm.sum(-1)  # sum over dimensions
        ll = logc + logm  # (n,)
        return ll.sum(), {"logc": logc, "logm": logm, "u": u}
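    # By Sklar's theorem the joint density factorizes as
    #   log f(x) = sum_i log f_i(x_i) + log c(F_1(x_1), ..., F_d(x_d)),
    # which is exactly logm + logc above; the pseudo-MLE drops the first term and
    # plugs the empirical CDF into the second.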
# ----------------------------
# Example usage
# ----------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cpu"
    # Synthetic data (n, d)
    n, d = 2000, 3
    # Generate from a Gaussian copula with Normal marginals
    true_R = torch.tensor([[1.0, 0.7, 0.3], [0.7, 1.0, 0.2], [0.3, 0.2, 1.0]])
    L = torch.linalg.cholesky(true_R)
    z = torch.randn(n, d)
    z = z @ L.T  # latent Gaussians with correlation true_R
    u = D.Normal(0, 1).cdf(z)  # Gaussian copula samples in (0, 1)^d
    # Choose true marginals (Normal with distinct mean/std)
    mu_true = torch.tensor([0.0, 2.0, -1.0])
    sig_true = torch.tensor([1.0, 0.5, 2.0])
    x = D.Normal(mu_true, sig_true).icdf(u)  # data coupled via the Gaussian copula
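    # Optional sanity check (a sketch, assumes torch.corrcoef from PyTorch >= 1.10):
    # the empirical correlation of the latent Gaussians should be close to true_R.
    print("empirical corr of z:\n", torch.corrcoef(z.T))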
    # ---- Train FULL MLE (marginals + copula)
    model = CopulaLikelihood(d, marginals=NormalMarginals(d), use_pseudo=False).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-2)
    for t in range(400):
        opt.zero_grad()
        ll, _ = model(x)
        loss = -ll / n
        loss.backward()
        opt.step()
        if (t + 1) % 100 == 0:
            print(f"iter {t + 1:03d} nll: {loss.item():.4f}")
    with torch.no_grad():
        print("Estimated mu:", model.marginals.mu)
        print("Estimated sigma:", torch.exp(model.marginals.log_sigma))
        print("Estimated R:\n", model.copula.corr())
    # ---- PSEUDO-MLE (rank uniforms for u; fit copula only)
    model_p = CopulaLikelihood(d, marginals=NormalMarginals(d), use_pseudo=True).to(device)
    # (marginal params won't matter; only the copula is trained)
    opt_p = torch.optim.Adam(model_p.copula.parameters(), lr=5e-2)
    for t in range(300):
        opt_p.zero_grad()
        ll_p, _ = model_p(x)
        loss_p = -ll_p / n
        loss_p.backward()
        opt_p.step()
        if (t + 1) % 100 == 0:
            print(f"[pseudo] iter {t + 1:03d} nll: {loss_p.item():.4f}")
    with torch.no_grad():
        print("Estimated R (pseudo-MLE):\n", model_p.copula.corr())