Commit 9309142a authored by Theodore Papamarkou's avatar Theodore Papamarkou
Browse files

Retained mll bottleneck nngp dependent only on nn.Module

parent 238e377d
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader

from eeyore.api import indexify

from bottleneck_nngp_pytorch_eeyore import BottleneckNNGP, Hyperparameters
from bottleneck_nngp_pytorch import BottleneckNNGP, Hyperparameters
from boston_dataset import Boston

# %%
+0 −65
Original line number Diff line number Diff line
# %%

import torch
from torch.utils.data import DataLoader

from eeyore.api import indexify

from bottleneck_nngp_pytorch_no_eeyore import BottleneckNNGP, Hyperparameters
from boston_dataset import Boston

# %%

# Standardized Boston housing data, served as a single full batch of 506 rows.
dataset = indexify(Boston)(standardize=True)
loader = DataLoader(dataset, batch_size=506)

# %%

hyperparams = Hyperparameters(depths=[3, 3], widths=[64], activation=True, noise=1e-10, num_mc=100)

# %%

# Kernel variances: bias, weight and observation noise (earlier values kept in comments).
bias_var = 1.2  # 1.0
weight_var = 1.5  # 1.0
noise_var = 0.25  # 1e-4

# %%

net = BottleneckNNGP(hparams=hyperparams)

# %%

net.set_params(torch.tensor([bias_var, weight_var, noise_var]))

# %%

features, targets, idx = next(iter(loader))

# %%

torch_log_lik = net.log_lik(features, targets)
print("Log-likelihood using pytorch:", torch_log_lik.item())

# %%

import numpy as np
from sklearn import datasets

import bottleneck_nngp_numpy

# %%

# Same data prepared independently via sklearn/numpy to cross-check the result.
inputs_np, outputs_np = datasets.load_boston(return_X_y=True)
outputs_np = outputs_np.reshape((-1, 1))

inputs_np = (inputs_np - np.mean(inputs_np, axis=0, keepdims=True)) / np.std(inputs_np, axis=0, keepdims=True)
outputs_np = (outputs_np - np.mean(outputs_np, axis=0, keepdims=True)) / np.std(outputs_np, axis=0, keepdims=True)

# %%

# Numpy reference model, configured from the torch model's fitted variances.
reference_model = bottleneck_nngp_numpy.bottleneck_nngp_prior(inputs_np, v_b=net.v_b.data.item(), v_w=net.v_w.data.item(), v_n=net.v_n.data.item(), depths=net.hp.depths, widths=net.hp.widths, bottleneck_activation=True)

# %%

# Reuse the torch model's Monte Carlo samples so both estimates are comparable.
numpy_log_lik = reference_model.log_likelihood(outputs_np, samples=net.samples.clone().detach().numpy(), n_samples=net.hp.num_mc)
print("Log-likelihood using numpy:", numpy_log_lik)
+0 −91
Original line number Diff line number Diff line
import math

import torch
from torch.distributions import Normal
import torch.nn as nn

from eeyore.api import BayesianModel

class Hyperparameters:
    """Hyperparameters of a bottleneck NNGP.

    Attributes:
        depths: number of layers in each NNGP segment; must have exactly
            one more entry than ``widths``.
        widths: width of each finite bottleneck layer between segments.
        activation: whether an activation is applied at the bottleneck.
        noise: jitter added at the bottleneck covariance.
        num_mc: number of Monte Carlo samples.

    Raises:
        ValueError: if ``len(depths) != len(widths) + 1``.
    """
    def __init__(self, depths=None, widths=None, activation=True, noise=1e-10, num_mc=1000):
        # None sentinels instead of mutable default arguments ([0] and []),
        # which were shared across every instance using the defaults.
        self.depths = [0] if depths is None else depths
        self.widths = [] if widths is None else widths
        self.activation = activation
        self.noise = noise  # bottleneck noise
        self.num_mc = num_mc  # number of Monte Carlo samples

        if len(self.depths) != len(self.widths) + 1:
            raise ValueError(
                "len(depths) must equal len(widths) + 1, got "
                f"{len(self.depths)} depths and {len(self.widths)} widths"
            )

class BottleneckNNGP(BayesianModel):
    """Bottleneck NNGP: NNGP segments joined by finite-width bottleneck
    layers that are integrated out by Monte Carlo sampling.

    The trainable parameters are three variances: ``v_b`` (bias), ``v_w``
    (weight) and ``v_n`` (observation noise on the final kernel diagonal).
    """

    def __init__(self, constraint=None, bounds=None, temperature=None, prior=None, hparams=None,
    samples=None, savefile=None, dtype=torch.float64, device='cpu'):
        """Set up parameters, prior and hyperparameters.

        ``bounds`` defaults to ``[None, None]`` and ``hparams`` to a fresh
        ``Hyperparameters()``. These were previously mutable default
        arguments, so every instance built with the defaults shared one
        list / one Hyperparameters object.
        """
        super().__init__(
            loss=None, constraint=constraint,
            bounds=[None, None] if bounds is None else bounds,
            temperature=temperature, dtype=dtype, device=device
        )
        # Fresh hyperparameters per instance (fixes the shared mutable default).
        self.hp = hparams if hparams is not None else Hyperparameters()

        # Kernel variances as trainable scalar parameters (uninitialized;
        # values are expected to be set via set_params / an optimizer).
        self.v_b = nn.Parameter(data=torch.empty([1], dtype=self.dtype, device=self.device), requires_grad=True)
        self.v_w = nn.Parameter(data=torch.empty([1], dtype=self.dtype, device=self.device), requires_grad=True)
        self.v_n = nn.Parameter(data=torch.empty([1], dtype=self.dtype, device=self.device), requires_grad=True)

        self.prior = prior or self.default_prior()

        # Most recently drawn Monte Carlo samples of bottleneck activations.
        self.samples = samples

    def default_prior(self):
        """Standard normal prior over the flattened parameter vector."""
        return Normal(
            torch.zeros(self.num_params(), dtype=self.dtype, device=self.device),
            torch.ones(self.num_params(), dtype=self.dtype, device=self.device)
        )

    def K(self, gram, depth, n, noise):
        """Arc-cosine NNGP kernel recursion over ``depth`` layers.

        gram: batch of Gram matrices, shape (S, n, n).
        depth: number of recursion steps for this segment.
        n: number of data points (side length of each Gram matrix).
        noise: value added to the diagonal of the returned kernels.
        Returns a batch of kernel matrices of the same shape as ``gram``.
        """
        K_XX = self.v_b + self.v_w * gram
        for _ in range(depth):
            # Outer products of the diagonal standard deviations.
            A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
            # Correlations, clamped to [-1, 1] so acos/sqrt stay well-defined.
            R = torch.max(torch.min(K_XX / A, torch.ones_like(A)), -torch.ones_like(A))
            K_XX = self.v_b + self.v_w * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
        return K_XX + noise * torch.eye(n, dtype=self.dtype, device=self.device).unsqueeze(0)

    def log_lik(self, x, y, samples=None):
        """Monte Carlo estimate of the marginal log-likelihood of ``y``.

        x: input matrix, one row per observation.
        y: target matrix with x.shape[0] rows.
        samples: optional pre-drawn bottleneck samples; when None, fresh
            samples are drawn and stored in ``self.samples``.
        """
        gram = (x @ x.t()).unsqueeze(0)
        # When there is no bottleneck (len(depths) == 1, the Hyperparameters
        # default) the loop below never runs; fall back to the input Gram
        # matrix. Previously `grams` was unbound in that case (NameError).
        grams = gram

        for (depth, width) in zip(self.hp.depths[:-1], self.hp.widths):
            # NOTE(review): every segment is seeded from the *input* Gram
            # matrix `gram`; with more than one bottleneck one might expect
            # the previous segment's `grams` here — confirm intent.
            Ks = self.K(gram, depth, x.shape[0], self.hp.noise)
            Ls = torch.stack([torch.cholesky(K) for K in Ks], 0)

            if samples is None:
                # Draw num_mc standard-normal matrices and correlate them
                # with the segment's Cholesky factors.
                self.samples = torch.randn(self.hp.num_mc, x.shape[0], width, dtype=self.dtype, device=self.device)
                self.samples = torch.cat(
                    [
                        torch.einsum("sik, skj->sij", Ls, self.samples[i, :, :].unsqueeze(0))
                        for i in range(self.hp.num_mc)
                    ]
                )
                if self.hp.activation:
                    # Activation: sqrt(2) * max(0, .) applied elementwise.
                    self.samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() \
                        * torch.max(self.samples, torch.zeros_like(self.samples))
            else:
                self.samples = samples

            # Empirical Gram matrices of the sampled bottleneck activations.
            grams = torch.einsum("sik, sjk->sij", self.samples, self.samples) / width

        # Final segment: learned observation noise v_n on the diagonal.
        Ks = self.K(grams, self.hp.depths[-1], x.shape[0], noise=self.v_n)
        Ls = torch.stack([torch.cholesky(K) for K in Ks], 0)
        LYs = torch.stack([torch.triangular_solve(y, L, upper=False).solution for L in Ls], 0)
        # log-determinant of each kernel from its Cholesky diagonal.
        logdets = 2 * torch.sum(torch.log(torch.diagonal(Ls, dim1=1, dim2=2)), 1)
        # Average the per-sample Gaussian log-likelihoods in log space:
        # logsumexp over the sample axis minus log(num_mc).
        log_lik_val = torch.logsumexp(
            -0.5 * (
                torch.sum(torch.sum(LYs ** 2, 1), 1)
                + y.shape[1] * (
                    logdets + x.shape[0] * torch.tensor([2 * math.pi], dtype=self.dtype, device=self.device).log()
                )
            ),
            0
        ) - torch.tensor([self.hp.num_mc], dtype=self.dtype, device=self.device).log()

        if self.temperature is not None:
            # Tempered likelihood, e.g. for annealed samplers.
            log_lik_val = self.temperature * log_lik_val

        return log_lik_val