Loading torch_experiments/modules.py +1 −1 Original line number Diff line number Diff line Loading @@ -41,7 +41,7 @@ class BottleneckNNGP(nn.Module): K_XX = self.v_b() + self.v_w() * gram_matrices for _ in range(depth): if self.manual_grad: K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, self.training, self.to_dtype, self.to_device) K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX)) R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter) Loading Loading
torch_experiments/modules.py +1 −1 Original line number Diff line number Diff line Loading @@ -41,7 +41,7 @@ class BottleneckNNGP(nn.Module): K_XX = self.v_b() + self.v_w() * gram_matrices for _ in range(depth): if self.manual_grad: K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, self.training, self.to_dtype, self.to_device) K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX)) R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter) Loading