Commit aad27cff authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Change self.training to torch.is_grad_enabled.

parent d8148d87
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -41,7 +41,7 @@ class BottleneckNNGP(nn.Module):
		K_XX = self.v_b() + self.v_w() * gram_matrices
		for _ in range(depth):
			if self.manual_grad:
				K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, self.training, self.to_dtype, self.to_device)
				K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device)
			else:
				A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
				R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)