Commit d4365283 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Use del and in-place ops during no_grad.

parent 544426e7
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -6,12 +6,15 @@ class NNGPReLU(Function):

	@staticmethod
	def forward(ctx, input_vb, input_vw, input_K, is_grad_enabled, to_dtype, to_device):
		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
		R = input_K/A

		if not is_grad_enabled:
			A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
			input_K /= A
			R = input_K
			return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)

		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
		R = input_K/A

		F_1 = torch.sqrt(1 - R**2)/math.pi
		F_2 = 1 - torch.acos(R)/math.pi
		F = A * (F_1 + F_2*R)
+16 −7
Original line number Diff line number Diff line
@@ -37,16 +37,20 @@ class BottleneckNNGP(nn.Module):
		self.v_n = PositiveParameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True)


	def K(self, Ks, depth, noise):
		"""Apply `depth` NNGP-ReLU kernel recursions to `Ks` and add diagonal noise.

		Parameters
		----------
		Ks : torch.Tensor
			Batch of square kernel/Gram matrices; the einsum pattern implies
			shape (batch, n, n). NOTE: mutated in place when gradients are
			disabled (memory optimization from this commit) — callers must not
			rely on `Ks` afterwards in no_grad mode.
		depth : int
			Number of ReLU kernel layers to apply. With depth == 0 the
			affine-transformed `Ks` is returned directly.
		noise : torch.Tensor or float
			Observation-noise variance added to the diagonal.

		Returns
		-------
		torch.Tensor
			The depth-layer NNGP kernel with noise on the diagonal.
		"""
		if torch.is_grad_enabled():
			# Out-of-place: autograd needs the input tensor left intact.
			Ks = self.v_b() + self.v_w() * Ks
		else:
			# In-place ops avoid allocating a second kernel-sized tensor.
			Ks *= self.v_w()
			Ks += self.v_b()
		for _ in range(depth):
			if self.manual_grad:
				# Custom autograd Function implementing the ReLU kernel recursion.
				# Fix: assign back to `Ks` (commit wrote `KS`), so each layer feeds
				# the next and depth == 0 does not hit an unbound name at return.
				Ks = nngp_relu(self.v_b(), self.v_w(), Ks, torch.is_grad_enabled(), self.to_dtype, self.to_device)
			else:
				# Arc-cosine (ReLU) kernel recursion; clamp keeps R away from the
				# +/-1 singularities of acos/sqrt (jitter-sized margin).
				A = torch.sqrt(torch.einsum("ijj, ikk->ijk", Ks, Ks))
				R = torch.clamp(Ks/A, -1+self.jitter, 1-self.jitter)
				Ks = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
		return Ks + noise * torch.eye(Ks.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)


	def log_likelihood(self, *args, **kwargs):
@@ -67,6 +71,11 @@ class BottleneckNNGP(nn.Module):
			samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() * nn.functional.relu(samples)
			gram_matrices = torch.matmul(samples, torch.transpose(samples, 1, 2)) / width

		if not torch.is_grad_enabled():
			del Ks
			del Ls
			del samples

		Ks = self.K(gram_matrices, self.depths[-1], self.v_n())
		Ls = torch.cholesky(Ks)
		LYs = torch.triangular_solve(y, Ls, upper=False).solution