Commit 471e0986 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Clamp R in custom function to avoid nan.

parent 790b6206
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -9,11 +9,12 @@ class NNGPReLU(Function):
		if not is_grad_enabled:
			A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
			input_K /= A
			input_K.clamp_(-1, 1)
			R = input_K
			return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)

		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
		R = input_K/A
		R = torch.clamp(input_K/A, -1, 1)

		F_1 = torch.sqrt(1 - R**2)/math.pi
		F_2 = 1 - torch.acos(R)/math.pi