torch_experiments/autograds.py +27 −14 Original line number Diff line number Diff line Loading @@ -5,32 +5,45 @@ from torch.autograd import Function class NNGPReLU(Function): @staticmethod def forward(ctx, input, is_grad_enabled, to_dtype, to_device): A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input, input)) R = input/A def forward(ctx, input_vb, input_vw, input_K, is_grad_enabled, to_dtype, to_device): A = torch.sqrt(torch.einsum("ijj, ikk->ijk", input_K, input_K)) R = input_K/A if not is_grad_enabled: return input_vb + input_vw * A / math.pi * (torch.sqrt(1 - R ** 2) + (math.pi - torch.acos(R)) * R) F_1 = torch.sqrt(1 - R**2)/math.pi F_2 = 1 - torch.acos(R)/math.pi F = A * (F_1 + F_2*R) if is_grad_enabled: F_1 *= torch.diagonal(input, dim1=-2, dim2=-1).unsqueeze(2)/A F_1 *= torch.diagonal(input_K, dim1=-2, dim2=-1).unsqueeze(2)/A del A del R output = input_vb + input_vw * F F_1 = F_1.to(dtype=to_dtype, device=to_device) f_2 = F_2.to(dtype=to_dtype, device=to_device) ctx.save_for_backward(F_1, F_2) F_2 = F_2.to(dtype=to_dtype, device=to_device) F = F.to(dtype=to_dtype, device=to_device) ctx.save_for_backward(input_vb, input_vw, F_1, F_2, F) return F return output @staticmethod def backward(ctx, grad_output): F_1, F_2 = ctx.saved_tensors input_vb, input_vw, F_1, F_2, F = ctx.saved_tensors F_1 = F_1.to(dtype=grad_output.dtype, device=grad_output.device) F_2 = F_2.to(dtype=grad_output.dtype, device=grad_output.device) dim = F_2.size(1) F = F.to(dtype=grad_output.dtype, device=grad_output.device) grad_input_vb = torch.sum(grad_output).view(1) grad_input_vw = torch.sum(grad_output * F).view(1) dim = F.size(1) grad_input_K = grad_output*F_2 grad_input_K[:,range(dim),range(dim)] += torch.sum( grad_output*F_1, 1) grad_input_K *= input_vw grad_input = grad_output*F_2 grad_input[:,range(dim),range(dim)] += torch.sum( grad_output*F_1, 1) return grad_input, None, None, None return grad_input_vb, grad_input_vw, grad_input_K, None, None, None nngp_relu = NNGPReLU.apply
torch_experiments/modules.py (+1 −1)

```diff
@@ -41,7 +41,7 @@ class BottleneckNNGP(nn.Module):
         K_XX = self.v_b() + self.v_w() * gram_matrices
         for _ in range(depth):
             if self.manual_grad:
-                K_XX = self.v_b() + self.v_w() * nngp_relu(K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device)
+                K_XX = nngp_relu(self.v_b(), self.v_w(), K_XX, torch.is_grad_enabled(), self.to_dtype, self.to_device)
             else:
                 A = torch.sqrt(torch.einsum("ijj, ikk->ijk", K_XX, K_XX))
                 R = torch.clamp(K_XX/A, -1+self.jitter, 1-self.jitter)
```
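The call-site change mirrors the new signature: instead of leaving the affine map `self.v_b() + self.v_w() * (...)` outside the op for autograd to track, `v_b` and `v_w` are passed into the Function, whose `backward` now returns `grad_input_vb` and `grad_input_vw` alongside `grad_input_K`. The two forms compute the same kernel recursion; the difference is that the manual-gradient branch now supplies gradients for all three inputs itself.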