From faf3508ce26c8b61bd2b84d02e89a15f4dd5bdff Mon Sep 17 00:00:00 2001
From: Massimiliano Lupo Pasini <massimiliano.lupo.pasini@gmail.com>
Date: Thu, 11 Feb 2021 16:34:01 -0500
Subject: [PATCH] bug fix for call to bfgs

---
 modules/optimizers.py | 2 +-
 utils/dataloaders.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/optimizers.py b/modules/optimizers.py
index e5eebf4..a097c7d 100644
--- a/modules/optimizers.py
+++ b/modules/optimizers.py
@@ -195,7 +195,7 @@ class FixedPointIteration(object):
             self.optimizer_specified = True
         elif optimizer_string.lower() == 'lbfgs':
             self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10,
-                                               max_iter=20, line_search_fn=True, batch_mode=True)
+                                               max_iter=20, line_search_fn='strong_wolfe')
             self.optimizer_str = optimizer_string.lower()
             self.optimizer_specified = True
         else:
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 1cb450f..3aacee0 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -209,7 +209,7 @@ def generate_dataloaders(dataset_name, subsample_factor, batch_size):
     if dataset_name == 'graduate_admission':
         input_dim, output_dim, dataset = graduate_admission_data()
     if dataset_name == 'nonlinear':
-        input_dim, output_dim, dataset = nonlinear_data(1.0, 1.0, 1000)
+        input_dim, output_dim, dataset = nonlinear_data(1000)
     dataset_size = len(dataset)
     indices = list(range(dataset_size))
     validation_split = 0.2
--
GitLab