Loading prior_distro/models.py +6 −5 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ def batch_matrix_sqrt(A): class bottleneck_nngp_prior(object): def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False): def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False, bottleneck_noise=1e-9): self.n_inputs = X.shape[0] self.X = X self.v_b = v_b Loading @@ -31,8 +31,9 @@ class bottleneck_nngp_prior(object): self.depths = depths self.widths = widths self.bottleneck_activation = bottleneck_activation self.bottleneck_noise = bottleneck_noise def K(self, Gram, depth): def K(self, Gram, depth, noise=0.0): Gram = np.atleast_2d(Gram) if Gram.ndim == 2: Gram = np.expand_dims(Gram, 0) Loading @@ -41,7 +42,7 @@ class bottleneck_nngp_prior(object): A = np.sqrt(np.einsum("ijj,ikk->ijk", K_XX, K_XX)) R = np.maximum(np.minimum(K_XX/A, 1.0), -1.0) K_XX = self.v_b + self.v_w*A/np.pi*(np.sqrt(1-R**2) + (np.pi-np.arccos(R))*R) return K_XX + self.v_n*np.expand_dims(np.eye(self.n_inputs), 0) return K_XX + noise*np.expand_dims(np.eye(self.n_inputs), 0) def sample(self, n_samples=1): grams = self.X.dot(self.X.T) Loading Loading @@ -71,14 +72,14 @@ class bottleneck_nngp_prior(object): def log_likelihood(self, Y, n_samples=1): grams = np.expand_dims(self.X.dot(self.X.T), 0) for (depth, width) in zip(self.depths[:-1], self.widths): Ks = self.K(grams, depth) Ks = self.K(grams, depth, noise=self.bottleneck_noise) Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0) samples = np.random.normal(0, 1, size=(n_samples, self.n_inputs, width)) samples = np.einsum("sik,skj->sij", Ls, samples) if self.bottleneck_activation: samples = np.sqrt(2)*np.maximum(0, samples) grams = np.einsum("sik,sjk->sij", samples, samples)/width Ks = self.K(grams, self.depths[-1]) Ks = self.K(grams, self.depths[-1], noise=self.v_n) Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0) LYs = np.stack([sp_linalg.solve_triangular(L, Y, lower=True) for L in Ls], 0) logdets = 2*np.sum(np.log(np.diagonal(Ls, axis1=1, axis2=2)), 1) Loading Loading
prior_distro/models.py +6 −5 Original line number Diff line number Diff line Loading @@ -21,7 +21,7 @@ def batch_matrix_sqrt(A): class bottleneck_nngp_prior(object): def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False): def __init__(self, X, v_b=1.0, v_w=1.0, v_n=1e-4, output_dims=1, depths=[0], widths=[], bottleneck_activation=False, bottleneck_noise=1e-9): self.n_inputs = X.shape[0] self.X = X self.v_b = v_b Loading @@ -31,8 +31,9 @@ class bottleneck_nngp_prior(object): self.depths = depths self.widths = widths self.bottleneck_activation = bottleneck_activation self.bottleneck_noise = bottleneck_noise def K(self, Gram, depth): def K(self, Gram, depth, noise=0.0): Gram = np.atleast_2d(Gram) if Gram.ndim == 2: Gram = np.expand_dims(Gram, 0) Loading @@ -41,7 +42,7 @@ class bottleneck_nngp_prior(object): A = np.sqrt(np.einsum("ijj,ikk->ijk", K_XX, K_XX)) R = np.maximum(np.minimum(K_XX/A, 1.0), -1.0) K_XX = self.v_b + self.v_w*A/np.pi*(np.sqrt(1-R**2) + (np.pi-np.arccos(R))*R) return K_XX + self.v_n*np.expand_dims(np.eye(self.n_inputs), 0) return K_XX + noise*np.expand_dims(np.eye(self.n_inputs), 0) def sample(self, n_samples=1): grams = self.X.dot(self.X.T) Loading Loading @@ -71,14 +72,14 @@ class bottleneck_nngp_prior(object): def log_likelihood(self, Y, n_samples=1): grams = np.expand_dims(self.X.dot(self.X.T), 0) for (depth, width) in zip(self.depths[:-1], self.widths): Ks = self.K(grams, depth) Ks = self.K(grams, depth, noise=self.bottleneck_noise) Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0) samples = np.random.normal(0, 1, size=(n_samples, self.n_inputs, width)) samples = np.einsum("sik,skj->sij", Ls, samples) if self.bottleneck_activation: samples = np.sqrt(2)*np.maximum(0, samples) grams = np.einsum("sik,sjk->sij", samples, samples)/width Ks = self.K(grams, self.depths[-1]) Ks = self.K(grams, self.depths[-1], noise=self.v_n) Ls = np.stack([np.linalg.cholesky(K) for K in Ks], 0) LYs = np.stack([sp_linalg.solve_triangular(L, Y, lower=True) for L in Ls], 0) logdets = 2*np.sum(np.log(np.diagonal(Ls, axis1=1, axis2=2)), 1) Loading