Commit 0ad431ac authored by Theodore Papamarkou's avatar Theodore Papamarkou
Browse files

Cleaning up and a fix in pytorch marginal log-lik of bottleneck NNGP

parent 9309142a
Loading
Loading
Loading
Loading
+2 −11
Original line number Diff line number Diff line
import math

import torch
from torch.distributions import Normal
import torch.nn as nn

class Hyperparameters:
@@ -28,16 +27,8 @@ class BottleneckNNGP(nn.Module):
        self.v_w = nn.Parameter(data=torch.empty([1], dtype=self.dtype, device=self.device), requires_grad=True)
        self.v_n = nn.Parameter(data=torch.empty([1], dtype=self.dtype, device=self.device), requires_grad=True)

        self.prior = self.default_prior()

        self.samples = samples

    def default_prior(self):
        return Normal(
            torch.zeros(self.num_params(), dtype=self.dtype, device=self.device),
            torch.ones(self.num_params(), dtype=self.dtype, device=self.device)
        )

    def num_params(self):
        """ Get the number of model parameters. """
        return sum(p.numel() for p in self.parameters())
@@ -65,10 +56,10 @@ class BottleneckNNGP(nn.Module):

    def log_lik(self, x, y, samples=None):
        """ Log-likelihood """
        gram = (x @ x.t()).unsqueeze(0)
        grams = (x @ x.t()).unsqueeze(0)

        for (depth, width) in zip(self.hp.depths[:-1], self.hp.widths):
            Ks = self.K(gram, depth, x.shape[0], self.hp.noise)
            Ks = self.K(grams, depth, x.shape[0], self.hp.noise)
            Ls = torch.stack([torch.cholesky(K) for K in Ks], 0)

            if samples is None: