Commit 69a0ebc5 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Update bottleneck noise in prior_distro.

parent 0f9c7f01
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ def batch_matrix_sqrt(A):

class bottleneck_nngp_prior(object):

	def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False):
	def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False, bottleneck_noise=1e-9):
		self.n_inputs = X.shape[0]
		self.X = X
		self.v_b = v_b
@@ -31,8 +31,9 @@ class bottleneck_nngp_prior(object):
		self.depths = depths
		self.widths = widths
		self.bottleneck_activation = bottleneck_activation
		self.bottleneck_noise = bottleneck_noise

	def K(self, Gram, depth):
	def K(self, Gram, depth, noise=0.0):
		Gram = np.atleast_2d(Gram)
		if Gram.ndim == 2:
			Gram = np.expand_dims(Gram, 0)
@@ -41,7 +42,7 @@ class bottleneck_nngp_prior(object):
			A = np.sqrt(np.einsum("ijj,ikk->ijk", K_XX, K_XX))
			R = np.maximum(np.minimum(K_XX/A, 1.0), -1.0)
			K_XX = self.v_b + self.v_w*A/np.pi*(np.sqrt(1-R**2) + (np.pi-np.arccos(R))*R)
		return K_XX + self.v_n*np.expand_dims(np.eye(self.n_inputs), 0)
		return K_XX + noise*np.expand_dims(np.eye(self.n_inputs), 0)

	def sample(self, n_samples=1):
		grams = self.X.dot(self.X.T)
@@ -71,14 +72,14 @@ class bottleneck_nngp_prior(object):
	def log_likelihood(self, Y, n_samples=1):
		grams = np.expand_dims(self.X.dot(self.X.T), 0)
		for (depth, width) in zip(self.depths[:-1], self.widths):
			Ks = self.K(grams, depth)
			Ks = self.K(grams, depth, noise=self.bottleneck_noise)
			Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0)
			samples = np.random.normal(0, 1, size=(n_samples, self.n_inputs, width))
			samples = np.einsum("sik,skj->sij", Ls, samples)
			if self.bottleneck_activation:
				samples = np.sqrt(2)*np.maximum(0, samples)
			grams = np.einsum("sik,sjk->sij", samples, samples)/width
		Ks = self.K(grams, self.depths[-1])
		Ks = self.K(grams, self.depths[-1], noise=self.v_n)
		Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0)
		LYs = np.stack([sp_linalg.solve_triangular(L, Y, lower=True) for L in Ls], 0)
		logdets = 2*np.sum(np.log(np.diagonal(Ls, axis1=1, axis2=2)), 1)