Commit 348d67ab authored by Devanshu Agrawal

Add GPU support.

parent 46cf6a17
+4 −3
@@ -5,8 +5,9 @@ from sklearn import datasets

 class Boston(Dataset):

-	def __init__(self, dtype=torch.float64, standardize=False):
+	def __init__(self, dtype=torch.float64, device="cpu", standardize=False):
 		self.dtype = dtype
+		self.device = device
 		self.load_data(standardize=standardize)

 	def __len__(self):
@@ -14,8 +15,8 @@ class Boston(Dataset):

 	def load_data(self, standardize):
 		self.data, self.labels = datasets.load_boston(return_X_y=True)
-		self.data = torch.from_numpy(self.data).to(dtype=self.dtype)
-		self.labels = torch.from_numpy(self.labels).to(dtype=self.dtype).view(len(self.labels), 1)
+		self.data = torch.from_numpy(self.data).to(dtype=self.dtype, device=self.device)
+		self.labels = torch.from_numpy(self.labels).to(dtype=self.dtype, device=self.device).view(len(self.labels), 1)

 		if standardize:
 			self.data = (self.data - torch.mean(self.data, dim=0, keepdim=True))/torch.std(self.data, dim=0, keepdim=True, unbiased=False)
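
For context, a minimal sketch of how the updated constructor is meant to be called (the `datasets` module name is assumed from the training script below; the CUDA index is illustrative):

	import torch
	import datasets  # project module defining Boston (name assumed from the script below)

	device = "cuda:0" if torch.cuda.is_available() else "cpu"

	# Tensors are materialized directly on `device` inside load_data,
	# so batches come out of the DataLoader already GPU-resident.
	boston = datasets.Boston(dtype=torch.float64, device=device, standardize=True)
	print(boston.data.device, boston.labels.shape)  # e.g. cuda:0 torch.Size([506, 1])

Note that `sklearn.datasets.load_boston` was deprecated in scikit-learn 1.0 and removed in 1.2, so this dataset class requires an older scikit-learn pin.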
+5 −2
@@ -27,17 +27,20 @@ parser.add_argument("--train_samples", "-s", default=100, type=int, help="Number
 parser.add_argument("--test_samples", "-t", default=100, type=int, help="Number of MC samples from bottlenecks at test time.")
 # optimization
 parser.add_argument("--iters", "-i", default=10, type=int, help="Number of training iterations.")
+# GPU
+parser.add_argument("--gpu", "-g", default=-1, type=int, help="Which GPU to use. If negative, CPU is used.")

 args = parser.parse_args()
+device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)


 # load dataset
-dataset = getattr(datasets, args.dataset.capitalize())(standardize=True)
+dataset = getattr(datasets, args.dataset.capitalize())(device=device, standardize=True)
 dataloader = DataLoader(dataset, batch_size=len(dataset))
 X, Y = next(iter(dataloader))

 # get model
-model = BottleneckNNGP(depths=args.depths, widths=args.widths, v_b=args.vb, v_w=args.vw, v_n=args.vn)
+model = BottleneckNNGP(depths=args.depths, widths=args.widths, v_b=args.vb, v_w=args.vw, v_n=args.vn, device=device)

 # get optimizer
 optimizer = optim.Adam(model.parameters(), lr=1e-1)
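
The device string is built purely from the flag, so passing `--gpu 0` on a CPU-only machine fails at the first tensor allocation. A hedged sketch of a defensive variant (the `select_device` helper is hypothetical, not part of this commit; the flag semantics are unchanged):

	import torch

	def select_device(gpu_index: int) -> str:
	    # Hypothetical helper: same mapping as the script's one-liner,
	    # but falls back to CPU when CUDA is unavailable instead of
	    # raising at the first tensor allocation.
	    if gpu_index < 0 or not torch.cuda.is_available():
	        return "cpu"
	    return "cuda:{}".format(gpu_index)

	print(select_device(-1))  # "cpu"
	print(select_device(0))   # "cuda:0" with a working CUDA install, else "cpu"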
+10 −9
@@ -4,16 +4,17 @@ import torch.nn as nn

 class BottleneckNNGP(nn.Module):

-	def __init__(self, depths=[0], widths=[], v_b=1.0, v_w=1.0, v_n=1.0, jitter=1e-10, dtype=torch.float64):
+	def __init__(self, depths=[0], widths=[], v_b=1.0, v_w=1.0, v_n=1.0, jitter=1e-10, dtype=torch.float64, device="cpu"):
 		super().__init__()
 		self.depths = depths
 		self.widths = widths
 		self.jitter = jitter
 		self.dtype = dtype
+		self.device = device

-		self.v_b = nn.Parameter(data=torch.tensor([v_b], dtype=self.dtype), requires_grad=True)
-		self.v_w = nn.Parameter(data=torch.tensor([v_w], dtype=self.dtype), requires_grad=True)
-		self.v_n = nn.Parameter(data=torch.tensor([v_n], dtype=self.dtype), requires_grad=True)
+		self.v_b = nn.Parameter(data=torch.tensor([v_b], dtype=self.dtype, device=self.device), requires_grad=True)
+		self.v_w = nn.Parameter(data=torch.tensor([v_w], dtype=self.dtype, device=self.device), requires_grad=True)
+		self.v_n = nn.Parameter(data=torch.tensor([v_n], dtype=self.dtype, device=self.device), requires_grad=True)


 	def forward(self, x, num_samples=100):
@@ -22,9 +23,9 @@ class BottleneckNNGP(nn.Module):
 		for (depth, width) in zip(self.depths[:-1], self.widths):
 			Ks = self.K(gram_matrices, depth, self.jitter)
 			Ls = torch.cholesky(Ks)
-			samples = torch.randn(num_samples, x.shape[0], width, dtype=self.dtype)
+			samples = torch.randn(num_samples, x.shape[0], width, dtype=self.dtype, device=self.device)
 			samples = torch.matmul(Ls, samples)
-			samples = torch.tensor([2.], dtype=self.dtype).sqrt() * nn.functional.relu(samples)
+			samples = torch.tensor([2.], dtype=self.dtype, device=self.device).sqrt() * nn.functional.relu(samples)
 			gram_matrices = torch.matmul(samples, torch.transpose(samples, 1, 2)) / width

 		Ks = self.K(gram_matrices, self.depths[-1], self.v_n)
@@ -37,7 +38,7 @@ class BottleneckNNGP(nn.Module):
 			A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
 			R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
 			K_XX = self.v_b + self.v_w * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R)
-		return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype).unsqueeze(0)
+		return K_XX + noise * torch.eye(K_XX.size(1), dtype=self.dtype, device=self.device).unsqueeze(0)


 	def log_likelihood(self, Ks, y):
@@ -48,10 +49,10 @@ class BottleneckNNGP(nn.Module):
 			-0.5 * (
 				torch.sum(torch.sum(LYs**2, 1), 1)
 				+ y.shape[1] * (
-					logdets + y.shape[0] * torch.tensor([2 * math.pi], dtype=self.dtype).log()
+					logdets + y.shape[0] * torch.tensor([2 * math.pi], dtype=self.dtype, device=self.device).log()
 				)
 			),
 			0
-		) - torch.tensor([Ks.size(0)], dtype=self.dtype).log()
+		) - torch.tensor([Ks.size(0)], dtype=self.dtype, device=self.device).log()

 		return ll
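
Taken together, the commit threads one `device` string from the `--gpu` flag through the dataset and the model, so the registered parameters and the tensors created on the fly inside `forward` (the Cholesky samples, the sqrt(2) constant, the identity jitter) all live on the same device. As a design choice, calling `model.to(device)` alone would move the `nn.Parameter`s but not these ad-hoc tensors, which is presumably why the device is stored and passed explicitly. A minimal end-to-end sketch under stated assumptions (the `models`/`datasets` module paths, and the use of `forward`'s output as the kernels consumed by `log_likelihood`, are inferred from the hunks above rather than confirmed by the commit):

	import torch
	import datasets                    # assumed project module defining Boston
	from models import BottleneckNNGP  # assumed module path for the model class

	device = "cuda:0" if torch.cuda.is_available() else "cpu"

	dataset = datasets.Boston(device=device, standardize=True)
	X, Y = dataset.data, dataset.labels  # both already resident on `device`

	model = BottleneckNNGP(depths=[1, 1], widths=[5], v_b=1.0, v_w=1.0, v_n=1.0, device=device)
	Ks = model(X, num_samples=10)              # assumed to return per-sample kernel matrices
	print(model.log_likelihood(Ks, Y).item())  # marginal log-likelihood estimate

One portability caveat: `torch.cholesky`, used in `forward`, has been deprecated in favor of `torch.linalg.cholesky` since PyTorch 1.8, so the code as committed targets older PyTorch releases.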