Commit d8148d87 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Implement efficient manual NNGP gradient.

parent 77709b37
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
import math
import torch
from torch.autograd import Function

class NNGPReLU(Function):
	"""Custom autograd Function computing the NNGP ReLU (arc-cosine) kernel map
	with a hand-written, memory-efficient backward pass.

	Given a batch of Gram matrices ``input`` (shape (batch, n, n); assumes
	symmetric PSD with positive diagonal — TODO confirm against callers),
	forward returns ``A/pi * (sqrt(1 - R^2) + (pi - acos(R)) * R)`` where
	``A[b,j,k] = sqrt(K_jj * K_kk)`` and ``R = K / A``.

	The two factors needed by backward are saved after casting to
	``to_dtype`` / ``to_device`` so they can be held in lower precision or on
	a different device between forward and backward.
	"""

	@staticmethod
	def forward(ctx, input, is_grad_enabled, to_dtype, to_device):
		"""Compute the kernel map; stash backward factors when grad is needed.

		Args:
			input: (batch, n, n) batch of Gram matrices.
			is_grad_enabled: whether to save tensors for backward (callers pass
				``self.training`` / grad mode).
			to_dtype, to_device: dtype/device the saved tensors are cast to.
		"""
		# A[b, j, k] = sqrt(K_jj * K_kk): outer product of diagonal scales.
		A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input, input))
		R = input/A  # correlation-like matrix; no clamping here (cf. the
		             # autograd path elsewhere, which clamps R into (-1, 1))

		F_1 = torch.sqrt(1 - R**2)/math.pi
		F_2 = 1 - torch.acos(R)/math.pi
		F = A * (F_1 + F_2*R)  # must be computed BEFORE F_1 is mutated below

		if is_grad_enabled:
			# Fold the row-diagonal factor K_ii/A into F_1 so backward reduces
			# to an elementwise multiply plus a column sum.
			F_1 *= torch.diagonal(input, dim1=-2, dim2=-1).unsqueeze(2)/A
			F_1 = F_1.to(dtype=to_dtype, device=to_device)
			# BUG FIX: previously `f_2 = F_2.to(...)` assigned the converted
			# tensor to a dead lowercase local, then saved the UNconverted F_2,
			# defeating the to_dtype/to_device memory saving.
			F_2 = F_2.to(dtype=to_dtype, device=to_device)
			ctx.save_for_backward(F_1, F_2)

		return F

	@staticmethod
	def backward(ctx, grad_output):
		"""Manual gradient w.r.t. ``input``; None for the three non-tensor args.

		Off-diagonal: dL/dK_jk = g_jk * F_2_jk.  Diagonal: add the column sum
		of g * F_1 (F_1 already carries the K_ii/A factor from forward).
		NOTE(review): the diagonal correction matches the analytic gradient
		when grad_output is symmetric (true for symmetric losses over Gram
		matrices) — confirm if an asymmetric upstream gradient is possible.
		"""
		F_1, F_2 = ctx.saved_tensors
		# Cast back to the working dtype/device of the incoming gradient.
		F_1 = F_1.to(dtype=grad_output.dtype, device=grad_output.device)
		F_2 = F_2.to(dtype=grad_output.dtype, device=grad_output.device)
		dim = F_2.size(1)

		grad_input = grad_output*F_2  # fresh tensor; in-place edit below is safe
		grad_input[:, range(dim), range(dim)] += torch.sum(grad_output*F_1, 1)
		return grad_input, None, None, None


nngp_relu = NNGPReLU.apply
+3 −1
Original line number Diff line number Diff line
@@ -28,9 +28,11 @@ parser.add_argument("--test_samples", "-t", default=100, type=int, help="Number
# optimization
parser.add_argument('--lr','-l', default=1e-3, type=float, help="Learning rate.")
parser.add_argument("--iters", "-i", default=10, type=int, help="Number of training iterations.")
parser.add_argument("--manual_grad", "-mg", action="store_true", help="Use manual NNGP-ReLU gradient.")
# GPU
parser.add_argument("--gpu", "-g", default=-1, type=int, help="Which GPU to use. If negative, CPU is used.")


args = parser.parse_args()
device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

@@ -41,7 +43,7 @@ dataloader = DataLoader(dataset, batch_size=len(dataset))
X, Y = next(iter(dataloader))

# get model
model = BottleneckNNGP(depths=args.depths, widths=args.widths, v_b=args.vb, v_w=args.vw, v_n=args.vn, device=device)
model = BottleneckNNGP(depths=args.depths, widths=args.widths, v_b=args.vb, v_w=args.vw, v_n=args.vn, device=device, manual_grad=args.manual_grad, to_dtype=torch.float32, to_device="cpu")

# define stochastic loss function
loss_fn = lambda num_samples, manual_samples: -model.log_likelihood(X, Y, num_samples=num_samples, manual_samples=manual_samples)/X.shape[0]
+13 −4
Original line number Diff line number Diff line
import math
import torch
import torch.nn as nn
from autograds import nngp_relu


class PositiveParameter(nn.Module):

@@ -18,7 +20,7 @@ class PositiveParameter(nn.Module):

class BottleneckNNGP(nn.Module):

	def __init__(self, depths=[0], widths=[], v_b=1.0, v_w=1.0, v_n=1.0, jitter=1e-10, dtype=torch.float64, device="cpu"):
	def __init__(self, depths=[0], widths=[], v_b=1.0, v_w=1.0, v_n=1.0, jitter=1e-10, dtype=torch.float64, device="cpu", manual_grad=False, to_dtype=None, to_device=None):
		super().__init__()
		self.depths = depths
		self.widths = widths
@@ -26,6 +28,10 @@ class BottleneckNNGP(nn.Module):
		self.dtype = dtype
		self.device = device

		self.manual_grad = manual_grad
		self.to_dtype = dtype if to_dtype is None else to_dtype
		self.to_device = device if to_device is None else to_device

		self.v_b = PositiveParameter(data=torch.tensor([v_b], dtype=self.dtype, device=self.device), requires_grad=True)
		self.v_w = PositiveParameter(data=torch.tensor([v_w], dtype=self.dtype, device=self.device), requires_grad=True)
		self.v_n = PositiveParameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True)
@@ -34,6 +40,9 @@ class BottleneckNNGP(nn.Module):
	def K(self, gram_matrices, depth, noise):
		K_XX = self.v_b() + self.v_w() * gram_matrices
		for _ in range(depth):
			if self.manual_grad:
				K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, self.training, self.to_dtype, self.to_device)
			else:
				A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
				R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
				K_XX = self.v_b() + self.v_w() * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)