Loading torch_experiments/autograds.py +6 −3 Original line number Diff line number Diff line Loading @@ -6,12 +6,15 @@ class NNGPReLU(Function): @staticmethod def forward(ctx, input_vb, input_vw, input_K, is_grad_enabled, to_dtype, to_device): A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A if not is_grad_enabled: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) input_K /= A R = input_K return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A F_1 = torch.sqrt(1 - R**2)/math.pi F_2 = 1 - torch.acos(R)/math.pi F = A * (F_1 + F_2*R) Loading torch_experiments/modules.py +16 −7 Original line number Diff line number Diff line Loading @@ -37,16 +37,20 @@ class BottleneckNNGP(nn.Module): self.v_n = PositiveParameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True) def K(self, gram_matrices, depth, noise): K_XX = self.v_b() + self.v_w() * gram_matrices def K(self, Ks, depth, noise): if torch.is_grad_enabled(): Ks = self.v_b() + self.v_w() * Ks else: Ks *= self.v_w() Ks += self.v_b() for _ in range(depth): if self.manual_grad: K_XX = nngp_relu(self.v_b(), self.v_w(), K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device) KS = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX)) R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter) K_XX = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks)) R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter) KS = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return KS + noise * torch.eye(KS.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) def log_likelihood(self, *args, **kwargs): Loading @@ -67,6 +71,11 @@ class BottleneckNNGP(nn.Module): samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() * nn.functional.relu(samples) gram_matrices = torch.matmul(samples, torch.transpose(samples, 1, 2)) / width if not torch.is_grad_enabled(): del Ks del Ls del samples Ks = self.K(gram_matrices, self.depths[-1], self.v_n()) Ls = torch.cholesky(Ks) LYs = torch.triangular_solve(y, Ls, upper=False).solution Loading Loading
torch_experiments/autograds.py +6 −3 Original line number Diff line number Diff line Loading @@ -6,12 +6,15 @@ class NNGPReLU(Function): @staticmethod def forward(ctx, input_vb, input_vw, input_K, is_grad_enabled, to_dtype, to_device): A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A if not is_grad_enabled: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) input_K /= A R = input_K return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A F_1 = torch.sqrt(1 - R**2)/math.pi F_2 = 1 - torch.acos(R)/math.pi F = A * (F_1 + F_2*R) Loading
torch_experiments/modules.py +16 −7 Original line number Diff line number Diff line Loading @@ -37,16 +37,20 @@ class BottleneckNNGP(nn.Module): self.v_n = PositiveParameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True) def K(self, gram_matrices, depth, noise): K_XX = self.v_b() + self.v_w() * gram_matrices def K(self, Ks, depth, noise): if torch.is_grad_enabled(): Ks = self.v_b() + self.v_w() * Ks else: Ks *= self.v_w() Ks += self.v_b() for _ in range(depth): if self.manual_grad: K_XX = nngp_relu(self.v_b(), self.v_w(), K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device) KS = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX)) R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter) K_XX = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks)) R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter) KS = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return KS + noise * torch.eye(KS.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) def log_likelihood(self, *args, **kwargs): Loading @@ -67,6 +71,11 @@ class BottleneckNNGP(nn.Module): samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() * nn.functional.relu(samples) gram_matrices = torch.matmul(samples, torch.transpose(samples, 1, 2)) / width if not torch.is_grad_enabled(): del Ks del Ls del samples Ks = self.K(gram_matrices, self.depths[-1], self.v_n()) Ls = torch.cholesky(Ks) LYs = torch.triangular_solve(y, Ls, upper=False).solution Loading