Commit 72579984 authored by Devanshu Agrawal

Combine forward and LL functions.

parent 21d7b7bf
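
In short, the kernel sampling that the caller previously did by hand now happens inside the model: forward takes (x, y) and returns the log-likelihood directly, and log_likelihood becomes an alias for forward. A minimal before/after sketch of the call site, using the training-script names from the first diff below:

# before: caller draws kernel samples, then scores them
Ks = model(X, num_samples=args.train_samples)
loss = -model.log_likelihood(Ks, Y)

# after: one call; the model draws its own kernel samples internally
loss = -model.log_likelihood(X, Y, num_samples=args.train_samples)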
+3 −6
@@ -56,8 +56,7 @@ print("Training ...")
 time_0 = time.time()
 for i in range(args.iters):
 	# compute loss
-	Ks = model(X, num_samples=args.train_samples)
-	loss = -model.log_likelihood(Ks, Y)
+	loss = -model.log_likelihood(X, Y, num_samples=args.train_samples)
 	# optim step
 	loss.backward()
 	optimizer.step()
@@ -70,8 +69,7 @@ for i in range(args.iters):
 
 # compute loss of final iteration
 with torch.no_grad():
-	Ks = model(X, num_samples=args.train_samples)
-	loss = -model.log_likelihood(Ks, Y)
+	loss = -model.log_likelihood(X, Y, num_samples=args.train_samples)
 results["train"]["loss"].append(loss.item())
 print("iter {} loss: {:.3f}".format(args.iters, loss.item()))

@@ -81,8 +79,7 @@ results["train"]["time"] = time.time()-time_0
 # get final values with higher number of samples
 time_0 = time.time()
 with torch.no_grad():
-	Ks = model(X, num_samples=args.test_samples)
-	loss = -model.log_likelihood(Ks, Y)
+	loss = -model.log_likelihood(X, Y, num_samples=args.test_samples)
 results["test"] = {}
 for (key, value) in zip(train_results_keys, [loss, model.v_b, model.v_w, model.v_n]):
 	results["test"][key] = value.item()
+14 −14
@@ -17,7 +17,20 @@ class BottleneckNNGP(nn.Module):
 		self.v_n = nn.Parameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True)
 
 
-	def forward(self, x, num_samples=100):
+	def K(self, gram_matrices, depth, noise):
+		K_XX = self.v_b + self.v_w * gram_matrices
+		for _ in range(depth):
+			A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
+			R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
+			K_XX = self.v_b + self.v_w * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
+		return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)
+
+
+	def log_likelihood(self, *args, **kwargs):
+		return self.forward(*args, **kwargs)
+
+
+	def forward(self, x, y, num_samples=100):
 		gram_matrices = (x @ x.t()).unsqueeze(0)
 
 		for (depth, width) in zip(self.depths[:-1], self.widths):
@@ -29,19 +42,6 @@ class BottleneckNNGP(nn.Module):
 			gram_matrices = torch.matmul(samples, torch.transpose(samples, 1, 2)) / width
 
 		Ks = self.K(gram_matrices, self.depths[-1], self.v_n)
-		return Ks
-
-
-	def K(self, gram_matrices, depth, noise):
-		K_XX = self.v_b + self.v_w * gram_matrices
-		for _ in range(depth):
-			A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
-			R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
-			K_XX = self.v_b + self.v_w * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
-		return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)
-
-
-	def log_likelihood(self, Ks, y):
 		Ls = torch.cholesky(Ks)
 		LYs = torch.triangular_solve(y, Ls, upper=False).solution
 		logdets = 2 * torch.sum(torch.log(torch.diagonal(Ls, dim1=1, dim2=2)), 1)
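
The hunk ends at its trailing context, so the tail of the (now inlined) log-likelihood computation is not shown. For orientation: K iterates what appears to be the arc-cosine (ReLU) NNGP kernel recursion, and the visible lines start a batched Gaussian log marginal likelihood via Cholesky factors. A self-contained sketch of how that computation is typically completed; the quadratic term and, in particular, the log-space averaging over kernel samples are assumptions, not code from this commit:

import math
import torch

def gaussian_log_likelihood(Ks, y):
	# Ks: (num_samples, N, N) batch of kernel matrices; y: (N, 1) targets
	Ls = torch.cholesky(Ks)  # batched lower-triangular factors, as in the hunk above
	LYs = torch.triangular_solve(y, Ls, upper=False).solution  # whitened targets L^{-1} y
	logdets = 2 * torch.sum(torch.log(torch.diagonal(Ls, dim1=1, dim2=2)), 1)
	quad = torch.sum(LYs ** 2, dim=(1, 2))  # y^T K^{-1} y per kernel sample
	n = y.size(0)
	lls = -0.5 * (logdets + quad + n * math.log(2 * math.pi))  # per-sample Gaussian log-density
	# assumption: average the likelihood over kernel samples in log space
	return torch.logsumexp(lls, 0) - math.log(Ks.size(0))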