import torch
import torch.nn as nn
import torch.distributions as D


# ----------------------------
# Utilities
# ----------------------------
def rank_uniform(x, eps=1e-3):
    """
    Probability integral transform via empirical CDF (pseudo-observations).
    Returns U in (0,1); eps shrinks away from 0/1 for stability.
    """
    n = x.shape[0]
    ranks = torch.argsort(torch.argsort(x, dim=0), dim=0).float() + 1.0
    u = ranks / (n + 1.0)
    return u.clamp(eps, 1 - eps)


def inv_std_normal(u):
    """Phi^{-1}(u) using torch.erfinv."""
    return torch.sqrt(torch.tensor(2.0, device=u.device)) * torch.erfinv(2 * u - 1)


def tril_param_to_corr(L_unconstrained, jitter=1e-5):
    """
    Map unconstrained lower-triangular params to a valid correlation matrix.
    We build a covariance via Sigma = L @ L^T, then rescale it to a correlation matrix.
    The diagonal of L is exponentiated so it stays positive (keeping Sigma PD);
    the result is symmetric, PD, and has (approximately) unit diagonal.
    """
    d = L_unconstrained.size(-1)
    # build lower-tri with exp(diag) for positivity
    L = torch.tril(L_unconstrained)
    diag_idx = torch.arange(d, device=L.device)
    L[..., diag_idx, diag_idx] = torch.exp(L[..., diag_idx, diag_idx])
    Sigma = L @ L.transpose(-1, -2)
    # scale to correlation
    std = torch.sqrt(torch.diagonal(Sigma, dim1=-2, dim2=-1) + jitter)
    Dinv = torch.diag_embed(1.0 / std)
    R = Dinv @ Sigma @ Dinv
    # enforce symmetry numerically
    R = 0.5 * (R + R.transpose(-1, -2))
    # add tiny jitter for stability
    R = R + torch.eye(d, device=R.device) * jitter
    return R


# ----------------------------
# Parametric marginals (Normals by default)
# ----------------------------
class NormalMarginals(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.mu = nn.Parameter(torch.zeros(d))
        self.log_sigma = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Return log f_i(x_i) and u_i = F_i(x_i)."""
        sigma = torch.exp(self.log_sigma)
        dist = D.Normal(self.mu, sigma)
        logpdf = dist.log_prob(x)  # (n, d)
        u = dist.cdf(x).clamp(1e-6, 1 - 1e-6)  # (n, d)
        return logpdf, u


# ----------------------------
# Gaussian Copula Module
# ----------------------------
class GaussianCopula(nn.Module):
    def __init__(self, d):
        super().__init__()
        # Unconstrained lower-triangular for covariance construction
        # initialize the diagonal to -1 so exp(-1) ~ 0.368 after the transform (initial R ~ identity)
        L = torch.eye(d) * (-1.0)
        self.L_unconstrained = nn.Parameter(L)

    def corr(self):
        return tril_param_to_corr(self.L_unconstrained)

    def log_copula_density(self, u):
        """
        u: (n, d) with entries in (0,1).
        c(u) = |R|^{-1/2} * exp(-0.5 z^T (R^{-1} - I) z), where z = Phi^{-1}(u)
        """
        z = inv_std_normal(u)  # (n, d)
        R = self.corr()  # (d, d)
        d = R.size(0)

        # Cholesky for logdet and solves
        L = torch.linalg.cholesky(R)  # R = L L^T
        logdetR = 2.0 * torch.log(torch.diagonal(L)).sum()

        # Compute y = R^{-1} z; cholesky_solve performs the two triangular solves
        y = torch.cholesky_solve(z.unsqueeze(-1), L).squeeze(-1)  # (n, d)
        quad = (z * y).sum(-1)  # z^T R^{-1} z
        quad_I = (z * z).sum(-1)

        logc = -0.5 * (logdetR + quad - quad_I)
        return logc  # (n,)


# ----------------------------
# Full model: marginals + copula
# ----------------------------
class CopulaLikelihood(nn.Module):
    def __init__(self, d, marginals=None, use_pseudo=False):
        super().__init__()
        self.d = d
        self.copula = GaussianCopula(d)
        self.marginals = marginals if marginals is not None else NormalMarginals(d)
        self.use_pseudo = use_pseudo  # if True, use empirical CDFs for u and ignore marginal logpdf

    def forward(self, x):
        """
        x: (n, d)
        returns: total log-likelihood over n samples, per-sample terms as well
        """
        if self.use_pseudo:
            u = rank_uniform(x)  # empirical PIT
            logc = self.copula.log_copula_density(u)
            logm = torch.zeros_like(logc)  # ignored in pseudo-MLE
        else:
            logm, u = self.marginals(x)  # (n,d), (n,d)
            logc = self.copula.log_copula_density(u)
            logm = logm.sum(-1)  # sum over dimensions

        ll = logc + logm  # (n,)
        return ll.sum(), {"logc": logc, "logm": logm, "u": u}


# ----------------------------
# Example usage
# ----------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    device = "cpu"

    # Synthetic data (n,d)
    n, d = 2000, 3
    # Generate from a Gaussian copula with Normal marginals
    true_R = torch.tensor([[1.0, 0.7, 0.3], [0.7, 1.0, 0.2], [0.3, 0.2, 1.0]])
    L = torch.linalg.cholesky(true_R)
    z = torch.randn(n, d)
    z = z @ L.T
    u = D.Normal(0, 1).cdf(z)
    # Choose true marginals (Normal with distinct mean/std)
    mu_true = torch.tensor([0.0, 2.0, -1.0])
    sig_true = torch.tensor([1.0, 0.5, 2.0])
    x = D.Normal(mu_true, sig_true).icdf(u)  # data ~ coupled via Gaussian copula
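
    # Optional sanity check (illustrative; not part of the example output below): the normal
    # scores of the rank-transformed data should have a sample correlation close to true_R.
    # torch.corrcoef requires a reasonably recent PyTorch (>= 1.10).
    with torch.no_grad():
        z_check = inv_std_normal(rank_uniform(x))
        print("Sample correlation of normal scores:\n", torch.corrcoef(z_check.T))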

    # ---- Train FULL MLE (marginals + copula)
    model = CopulaLikelihood(d, marginals=NormalMarginals(d), use_pseudo=False).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-2)
    for t in range(400):
        opt.zero_grad()
        ll, _ = model(x)
        loss = -ll / n
        loss.backward()
        opt.step()
        if (t + 1) % 100 == 0:
            print(f"iter {t + 1:03d}  nll: {loss.item():.4f}")
    with torch.no_grad():
        print("Estimated mu:", model.marginals.mu)
        print("Estimated sigma:", torch.exp(model.marginals.log_sigma))
        print("Estimated R:\n", model.copula.corr())

    # ---- PSEUDO-MLE (rank uniforms for u; fit copula only)
    model_p = CopulaLikelihood(d, marginals=NormalMarginals(d), use_pseudo=True).to(device)
    # (marginal params won’t matter; only copula trained)
    opt_p = torch.optim.Adam(model_p.copula.parameters(), lr=5e-2)
    for t in range(300):
        opt_p.zero_grad()
        ll_p, _ = model_p(x)
        loss_p = -ll_p / n
        loss_p.backward()
        opt_p.step()
        if (t + 1) % 100 == 0:
            print(f"[pseudo] iter {t + 1:03d}  nll: {loss_p.item():.4f}")
    with torch.no_grad():
        print("Estimated R (pseudo-MLE):\n", model_p.copula.corr())

# ----------------------------
# Example output
# ----------------------------
# iter 100  nll: 3.8697
# iter 200  nll: 3.8684
# iter 300  nll: 3.8684
# iter 400  nll: 3.8684
# Estimated mu: Parameter containing:
# tensor([-0.0113,  1.9867, -1.0181], requires_grad=True)
# Estimated sigma: tensor([0.9932, 0.4943, 2.0086])
# Estimated R:
#  tensor([[1.0000, 0.6938, 0.2965],
#         [0.6938, 1.0000, 0.1956],
#         [0.2965, 0.1956, 1.0000]])
# [pseudo] iter 100  nll: -0.3734
# [pseudo] iter 200  nll: -0.3734
# [pseudo] iter 300  nll: -0.3734
# Estimated R (pseudo-MLE):
#  tensor([[0.9965, 0.6936, 0.2963],
#         [0.6936, 1.0000, 0.1962],
#         [0.2963, 0.1962, 1.0000]])