Loading torch_experiments/autograds.py +2 −1 Original line number Diff line number Diff line Loading @@ -9,11 +9,12 @@ class NNGPReLU(Function): if not is_grad_enabled: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) input_K /= A input_K.clamp_(-1, 1) R = input_K return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A R = torch.clamp(input_K/A, -1, 1) F_1 = torch.sqrt(1 - R**2)/math.pi F_2 = 1 - torch.acos(R)/math.pi Loading Loading
torch_experiments/autograds.py +2 −1 Original line number Diff line number Diff line Loading @@ -9,11 +9,12 @@ class NNGPReLU(Function): if not is_grad_enabled: A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) input_K /= A input_K.clamp_(-1, 1) R = input_K return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A R = torch.clamp(input_K/A, -1, 1) F_1 = torch.sqrt(1 - R**2)/math.pi F_2 = 1 - torch.acos(R)/math.pi Loading