Commit eec0dd9a authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Add depth print-outs in model class.

parent b2d28448
Loading
Loading
Loading
Loading
+43 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ def batch_matrix_sqrt(A):

class bottleneck_nngp_prior(object):

	def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False, bottleneck_noise=1e-9):
	def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False, bottleneck_noise=1e-10):
		self.n_inputs = X.shape[0]
		self.X = X
		self.v_b = v_b
@@ -86,3 +86,45 @@ class bottleneck_nngp_prior(object):
		logliks = -1/2*np.sum(np.sum(LYs**2, 1), 1) - Y.shape[1]/2*logdets - Y.shape[1]*self.n_inputs/2*np.log(2*np.pi)
		loglik = sp_special.logsumexp(logliks) - np.log(n_samples)
		return loglik


	def K_multidepth(self, Gram, depth, noise=0.0, collect_depths=[]):
		Gram = np.atleast_2d(Gram)
		if Gram.ndim == 2:
			Gram = np.expand_dims(Gram, 0)
		K_XX = self.v_b + self.v_w*Gram
		Ks_over_depth = []
		if 0 in collect_depths:
			Ks_over_depth.append(K_XX + noise*np.expand_dims(np.eye(self.n_inputs), 0))
		for d in range(depth):
			A = np.sqrt(np.einsum("ijj,ikk->ijk", K_XX, K_XX))
			R = np.maximum(np.minimum(K_XX/A, 1.0), -1.0)
			K_XX = self.v_b + self.v_w*A/np.pi*(np.sqrt(1-R**2) + (np.pi-np.arccos(R))*R)
			if d+1 in collect_depths:
				print("depth:", d+1)
				Ks_over_depth.append(K_XX + noise*np.expand_dims(np.eye(self.n_inputs), 0))
		return Ks_over_depth

	def log_likelihood_multidepth(self, Y, n_samples=1, collect_depths=[]):
		grams = np.expand_dims(self.X.dot(self.X.T), 0)
		for (depth, width) in zip(self.depths[:-1], self.widths):
			Ks = self.K(grams, depth, noise=self.bottleneck_noise)
			Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0)
			samples = np.random.normal(0, 1, size=(n_samples, self.n_inputs, width))
			samples = np.einsum("sik,skj->sij", Ls, samples)
			if self.bottleneck_activation:
				samples = np.sqrt(2)*np.maximum(0, samples)
			grams = np.einsum("sik,sjk->sij", samples, samples)/width
		print("computing kernel . . . ")
		Ks_over_depth = self.K_multidepth(grams, self.depths[-1], noise=self.v_n, collect_depths=collect_depths)
		lls_over_depth = []
		print("computing likelihoods . . . ")
		for (cd, Ks) in zip(collect_depths, Ks_over_depth):
			print("depth:", cd)
			Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0)
			LYs = np.stack([sp_linalg.solve_triangular(L, Y, lower=True) for L in Ls], 0)
			logdets = 2*np.sum(np.log(np.diagonal(Ls, axis1=1, axis2=2)), 1)
			logliks = -1/2*np.sum(np.sum(LYs**2, 1), 1) - Y.shape[1]/2*logdets - Y.shape[1]*self.n_inputs/2*np.log(2*np.pi)
			loglik = sp_special.logsumexp(logliks) - np.log(n_samples)
			lls_over_depth.append(loglik)
		return lls_over_depth