Commit 0df876e0 authored by Devanshu Agrawal's avatar Devanshu Agrawal
Browse files

Add iris and rings datasets.

parent 72579984
Loading
Loading
Loading
Loading
+61 −0
Original line number Diff line number Diff line
@@ -25,3 +25,64 @@ class Boston(Dataset):

	def __getitem__(self, idx):
		return self.data[idx], self.labels[idx]


class Iris(Dataset):

	def __init__(self, dtype=torch.float64, device="cpu", standardize=False):
		self.dtype = dtype
		self.device = device
		self.load_data(standardize=standardize)

	def __len__(self):
		return len(self.data)

	def load_data(self, standardize):
		self.data, self.labels = datasets.load_iris(return_X_y=True)
		self.data = torch.from_numpy(self.data).to(dtype=self.dtype, device=self.device)
		self.labels = torch.from_numpy(self.labels).to(dtype=torch.long, device=self.device).view(len(self.labels), 1)
		self.labels = torch.zeros(len(self.labels), 3, dtype=self.dtype, device=self.device).scatter_(1, self.labels, 1)

		if standardize:
			self.data = (self.data - torch.mean(self.data, dim=0, keepdim=True))/torch.std(self.data, dim=0, keepdim=True, unbiased=False)
			self.labels = (self.labels - torch.mean(self.labels, dim=0, keepdim=True))/torch.std(self.labels, dim=0, keepdim=True, unbiased=False)


	def __getitem__(self, idx):
		return self.data[idx], self.labels[idx]


class Rings(Dataset):

	def __init__(self, dtype=torch.float64, device="cpu", standardize=False):
		self.dtype = dtype
		self.device = device
		self.load_data(standardize=standardize)

	def __len__(self):
		return len(self.data)

	def load_data(self, standardize):
		import itertools
		import numpy as np
		thetas = np.pi/6*np.arange(12)
		zs = 0.25*np.arange(5)-0.5

		X_1 = np.array([[np.cos(theta), np.sin(theta), z] for (theta, z) in itertools.product(thetas, zs)])
		X_2 = X_1.dot(np.array([[0, 0, -1], [0, 1, 0], [1, 0, 0]]))
		X_2[:,1] += 1
		X = np.concatenate((X_1, X_2), axis=0)
		Y = np.zeros((X.shape[0]))
		Y[X.shape[0]//2:] = 1

		self.data, self.labels = X, Y
		self.data = torch.from_numpy(self.data).to(dtype=self.dtype, device=self.device)
		self.labels = torch.from_numpy(self.labels).to(dtype=self.dtype, device=self.device).view(len(self.labels), 1)

		if standardize:
			self.data = (self.data - torch.mean(self.data, dim=0, keepdim=True))/torch.std(self.data, dim=0, keepdim=True, unbiased=False)
			self.labels = (self.labels - torch.mean(self.labels, dim=0, keepdim=True))/torch.std(self.labels, dim=0, keepdim=True, unbiased=False)


	def __getitem__(self, idx):
		return self.data[idx], self.labels[idx]