Commit 790b6206 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Fix KS typo in function K.

parent d4365283
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -45,12 +45,12 @@ class BottleneckNNGP(nn.Module):
			Ks += self.v_b()
		for _ in range(depth):
			if self.manual_grad:
				KS = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device)
				Ks = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device)
			else:
				A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks))
				R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter)
				KS = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
		return KS + noise * torch.eye(KS.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)
				Ks = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
		return Ks + noise * torch.eye(Ks.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)


	def log_likelihood(self, *args, **kwargs):