Commit 544426e7 authored by Devanshu Agrawal

Fix GPU memory issues.

parent aad27cff
+27 −14
@@ -5,32 +5,45 @@ from torch.autograd import Function
 class NNGPReLU(Function):
 
 	@staticmethod
-	def forward(ctx, input, is_grad_enabled, to_dtype, to_device):
-		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input, input))
-		R = input/A
+	def forward(ctx, input_vb, input_vw, input_K, is_grad_enabled, to_dtype, to_device):
+		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K))
+		R = input_K/A
 
+		if not is_grad_enabled:
+			return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
+
 		F_1 = torch.sqrt(1 - R**2)/math.pi
 		F_2 = 1 - torch.acos(R)/math.pi
 		F = A * (F_1 + F_2*R)
 
-		if is_grad_enabled:
-			F_1 *= torch.diagonal(input, dim1=-2, dim2=-1).unsqueeze(2)/A
-			F_1 = F_1.to(dtype=to_dtype, device=to_device)
-			f_2 = F_2.to(dtype=to_dtype, device=to_device)
-			ctx.save_for_backward(F_1, F_2)
+		F_1 *= torch.diagonal(input_K, dim1=-2, dim2=-1).unsqueeze(2)/A
+		del A
+		del R
+		output = input_vb + input_vw * F
+
+		F_1 = F_1.to(dtype=to_dtype, device=to_device)
+		F_2 = F_2.to(dtype=to_dtype, device=to_device)
+		F = F.to(dtype=to_dtype, device=to_device)
+		ctx.save_for_backward(input_vb, input_vw, F_1, F_2, F)
 
-		return F
+		return output
 
 	@staticmethod
 	def backward(ctx, grad_output):
-		F_1, F_2 = ctx.saved_tensors
+		input_vb, input_vw, F_1, F_2, F = ctx.saved_tensors
 		F_1 = F_1.to(dtype=grad_output.dtype, device=grad_output.device)
 		F_2 = F_2.to(dtype=grad_output.dtype, device=grad_output.device)
-		dim = F_2.size(1)
+		F = F.to(dtype=grad_output.dtype, device=grad_output.device)
 
-		grad_input = grad_output*F_2
-		grad_input[:,range(dim),range(dim)] += torch.sum(grad_output*F_1, 1)
-		return grad_input, None, None, None
+		grad_input_vb = torch.sum(grad_output).view(1)
+		grad_input_vw = torch.sum(grad_output * F).view(1)
+
+		dim = F.size(1)
+		grad_input_K = grad_output*F_2
+		grad_input_K[:,range(dim),range(dim)] += torch.sum(grad_output*F_1, 1)
+		grad_input_K *= input_vw
+
+		return grad_input_vb, grad_input_vw, grad_input_K, None, None, None
 
 
 nngp_relu = NNGPReLU.apply
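
For reference, the closed form this Function differentiates is the one-step ReLU (arc-cosine) NNGP kernel map, read off from forward() above; the factor bookkeeping below additionally assumes K is symmetric, as it is for the Gram matrices this code consumes:

	A_{ij} = \sqrt{K_{ii} K_{jj}}, \qquad R_{ij} = K_{ij} / A_{ij}
	K'_{ij} = v_b + \frac{v_w A_{ij}}{\pi} \left( \sqrt{1 - R_{ij}^2} + (\pi - \arccos R_{ij})\, R_{ij} \right)

The saved tensors are exactly the chain-rule pieces: F_2 = 1 - \arccos(R_{ij})/\pi equals \partial K'_{ij}/\partial K_{ij} up to the factor v_w, while F_1 = (\sqrt{1 - R_{ij}^2}/\pi)(K_{ii}/A_{ij}) collects the two symmetric contributions of each diagonal entry, which is why backward() only has to add \sum_i grad_output_{ij} \cdot F_1_{ij} onto the diagonal of grad_input_K. Saving F as well lets backward() form the new v_b and v_w gradients (\partial K'/\partial v_b = 1, \partial K'/\partial v_w = F) without recomputing the kernel.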
+1 −1
@@ -41,7 +41,7 @@ class BottleneckNNGP(nn.Module):
 		K_XX = self.v_b() + self.v_w() * gram_matrices
 		for _ in range(depth):
 			if self.manual_grad:
-				K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device)
+				K_XX = nngp_relu(self.v_b(), self.v_w(), K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device)
 			else:
 				A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
 				R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
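
For illustration, a minimal sketch of how the reworked Function is driven after this commit. The import path, shapes, and parameter values are assumptions made for the example, not part of the patch; only the six-argument call mirrors the code above.

import torch

# Hypothetical import path; the patched module above exposes
# nngp_relu = NNGPReLU.apply at module level.
from nngp import nngp_relu

# 1-element variance parameters and a batch of symmetric Gram matrices of
# shape (batch, n, n); in the real model these come from v_b(), v_w() and
# the data.
v_b = torch.tensor([0.1], requires_grad=True)
v_w = torch.tensor([1.0], requires_grad=True)
X = torch.randn(8, 16, 4)
K = torch.einsum("ipk, iqk->ipq", X, X).requires_grad_()

# to_dtype/to_device say where the tensors saved for backward should live.
# Offloading them (e.g. to CPU) is what releases GPU memory between forward
# and backward; backward() moves them back to grad_output's dtype and
# device before using them.
out = nngp_relu(v_b, v_w, K,
	torch.is_grad_enabled(),   # grad state is captured outside forward()
	torch.float32,             # to_dtype for the saved tensors
	torch.device("cpu"))       # to_device for the saved tensors
out.sum().backward()
print(v_b.grad.shape, v_w.grad.shape, K.grad.shape)  # (1,), (1,), (8, 16, 16)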