Commit 16c4e17b authored by Devanshu Agrawal
Browse files

Add delta loss given same sample draw.

parent ac9d99ba
Loading
Loading
Loading
Loading
+31 −17
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import argparse
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from modules import BottleneckNNGP
from tempmodules import BottleneckNNGP
import datasets

# set torch seed
@@ -43,38 +43,51 @@ X, Y = next(iter(dataloader))
# get model
model = BottleneckNNGP(depths=args.depths, widths=args.widths, v_b=args.vb, v_w=args.vw, v_n=args.vn, device=device)

# NOTE(review): this listing is a rendered diff whose +/- markers were stripped, so
# some statements appear twice in a row (old version, then new version), e.g. the
# two `results = ...` initializations below. Presumably only the later one of each
# pair exists in the committed file -- confirm against the repository.

# get optimizer
# define stochastic loss function
# Negated model.log_likelihood on the batch (X, Y), divided by the number of rows
# in X. `manual_samples` is passed straight through to the model.
loss_fn = lambda num_samples, manual_samples: -model.log_likelihood(X, Y, num_samples=num_samples, manual_samples=manual_samples)/X.shape[0]

# get optimizer and LR scheduler
optimizer = optim.Adam(model.parameters(), lr=args.lr)
#scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda iter: 1/(1+iter//5))
# MultiplicativeLR scales the current learning rate by the value lr_lambda returns
# each time scheduler.step() is called; this lambda ignores its argument and always
# returns 0.9. (The parameter name `iter` shadows the builtin inside the lambda only.)
scheduler = optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda iter: 0.9)

# initialize results dict
results = {"train": {"loss": []}}
train_results_keys = ["loss", "v_b", "v_w", "v_n"]
for (key, value) in zip(train_results_keys[1:], [model.v_b(), model.v_w(), model.v_n()]):
	results["train"][key] = [value.item()]
# One list per tracked quantity. The v_b / v_w / v_n lists are seeded with the
# model's current values, so they start one entry ahead of the loss lists.
results_train_keys = ["v_b", "v_w", "v_n", "loss", "loss_new", "delta_loss"]
results = {"train": {key: [] for key in results_train_keys}}
for (key, value) in zip(results_train_keys[:3], [model.v_b(), model.v_w(), model.v_n()]):
	results["train"][key].append( value.item() )

# training loop
# NOTE(review): this listing is a rendered diff whose +/- markers were stripped;
# back-to-back duplicates (the two `loss = ...` lines, the two `for (key, value)`
# headers) are presumably the old and new versions of one statement -- confirm
# against the repository.
print("Training ...")
time_0 = time.time()
for i in range(args.iters):
	# sample standard normal
	# One float64 CPU tensor per entry of args.widths, each of shape
	# (args.train_samples, X.shape[0], width). The draw is kept so the loss can be
	# re-evaluated on exactly the same samples after the parameter update.
	samples = [torch.randn(args.train_samples, X.shape[0], width, dtype=torch.float64, device="cpu") for width in args.widths]
	# compute loss
	loss = -model.log_likelihood(X, Y, num_samples=args.train_samples)
	loss = loss_fn(args.train_samples, samples)
	# optim step
	loss.backward()
	optimizer.step()
	optimizer.zero_grad()
#	scheduler.step()
	# compute new loss given same sample draw
	# Reusing `samples` means loss_new - loss reflects the parameter update only,
	# not fresh sampling noise. no_grad: this evaluation is never backpropagated.
	with torch.no_grad():
		loss_new = loss_fn(args.train_samples, samples)
		delta_loss = loss_new-loss
	# decrease LR if loss increases
	# Each scheduler.step() here multiplies the learning rate by the scheduler's factor.
	if delta_loss.item() > 0:
		scheduler.step()
	# update results
	for (key, value) in zip(train_results_keys, [loss, model.v_b(), model.v_w(), model.v_n()]):
	for (key, value) in zip(results_train_keys, [model.v_b(), model.v_w(), model.v_n(), loss, loss_new, delta_loss]):
		results["train"][key].append( value.item() )
	# print loss
	print("iter {} loss: {:.3f}".format(i, loss.item()))
	# print loss info
	# "delta loss" is the relative change 100*(loss_new/loss - 1) in percent; the
	# moving average is the mean of the last five recorded pre-update losses.
	# NOTE(review): the "loss" list already holds i+1 entries at this point, so five
	# values are available from i == 4; the i >= 5 threshold is conservative.
	if i >= 5:
		print("iter {} delta loss {:.3f} % mov avg loss {:.3f}".format(i, 100*(loss_new/loss-1.0).item(), sum(results["train"]["loss"][-5:])/5))
	else:
		print("iter {} delta loss {:.3f} % mov avg loss ---".format(i, 100*(loss_new/loss-1.0).item()))

# compute loss of final iteration
# manual_samples=None here; presumably the model then draws its own samples -- confirm
# in BottleneckNNGP. This extra append leaves the "loss" list one entry longer than
# the "loss_new" and "delta_loss" lists.
with torch.no_grad():
	loss = -model.log_likelihood(X, Y, num_samples=args.train_samples)
	loss = loss_fn(args.train_samples, None)
results["train"]["loss"].append(loss.item())
print("iter {} loss: {:.3f}".format(args.iters, loss.item()))

# record training time
# Measured from just before the loop, so it includes the per-iteration re-evaluation,
# the printing, and the final loss computation above.
results["train"]["time"] = time.time()-time_0
@@ -82,9 +95,10 @@ results["train"]["time"] = time.time()-time_0
# get final values with higher number of samples
# NOTE(review): this listing is a rendered diff whose +/- markers were stripped; the
# duplicated `loss = ...` lines and `for (key, value)` headers are presumably the
# old and new versions of one statement -- confirm against the repository.
# The loss is computed with args.test_samples samples and no manual draw; no
# gradients are needed for this evaluation.
time_0 = time.time()
with torch.no_grad():
	loss = -model.log_likelihood(X, Y, num_samples=args.test_samples)
	loss = loss_fn(args.test_samples, None)
results["test"] = {}
for (key, value) in zip(train_results_keys, [loss, model.v_b(), model.v_w(), model.v_n()]):
# Takes the first four entries of results_train_keys; the values listed in the loop
# below must stay in that same order.
results_test_keys = results_train_keys[:4]
for (key, value) in zip(results_test_keys, [model.v_b(), model.v_w(), model.v_n(), loss]):
	results["test"][key] = value.item()
# record test time
@@ -92,7 +106,7 @@ results["test"]["time"] = time.time()-time_0

# print final values
# Prints each stored test result to three decimal places, one "key: value" line per key.
# NOTE(review): the two `for key in ...` headers below are presumably the old and new
# versions of the same line in the rendered diff -- confirm against the repository.
print("Final values:")
for key in train_results_keys:
for key in results_test_keys:
	print(key+": {:.3f}".format(results["test"][key]))

# record command-line arguments
+6 −3
Original line number Diff line number Diff line
@@ -44,12 +44,15 @@ class BottleneckNNGP(nn.Module):
		return self.forward(*args, **kwargs)


	def forward(self, x, y, num_samples=100):
	def forward(self, x, y, num_samples=100, manual_samples=None):
		gram_matrices = (x @ x.t()).unsqueeze(0)

		for (depth, width) in zip(self.depths[:-1], self.widths):
		for (i, (depth, width)) in enumerate(zip(self.depths[:-1], self.widths)):
			Ks = self.K(gram_matrices, depth, self.jitter)
			Ls = torch.cholesky(Ks)
			if manual_samples is not None:
				samples = manual_samples[i].to(dtype=self.dtype, device=self.device)
			else:
				samples = torch.randn(num_samples, x.shape[0], width, dtype=self.dtype, device=self.device)
			samples = torch.matmul(Ls, samples)
			samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() * nn.functional.relu(samples)