Loading torch_experiments/modules.py +3 −3 Original line number Diff line number Diff line Loading @@ -45,12 +45,12 @@ class BottleneckNNGP(nn.Module): Ks += self.v_b() for _ in range(depth): if self.manual_grad: KS = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) Ks = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks)) R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter) KS = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return KS + noise * torch.eye(KS.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) Ks = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return Ks + noise * torch.eye(Ks.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) def log_likelihood(self, *args, **kwargs): Loading Loading
torch_experiments/modules.py +3 −3 Original line number Diff line number Diff line Loading @@ -45,12 +45,12 @@ class BottleneckNNGP(nn.Module): Ks += self.v_b() for _ in range(depth): if self.manual_grad: KS = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) Ks = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device) else: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks)) R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter) KS = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return KS + noise * torch.eye(KS.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) Ks = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) return Ks + noise * torch.eye(Ks.size(1), dtype=self.dtype, device=self.device).unsqueeze(0) def log_likelihood(self, *args, **kwargs): Loading