From faf3508ce26c8b61bd2b84d02e89a15f4dd5bdff Mon Sep 17 00:00:00 2001
From: Massimiliano Lupo Pasini <massimiliano.lupo.pasini@gmail.com>
Date: Thu, 11 Feb 2021 16:34:01 -0500
Subject: [PATCH] bug fix for call to bfgs

---
 modules/optimizers.py | 2 +-
 utils/dataloaders.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/modules/optimizers.py b/modules/optimizers.py
index e5eebf4..a097c7d 100644
--- a/modules/optimizers.py
+++ b/modules/optimizers.py
@@ -195,7 +195,7 @@ class FixedPointIteration(object):
             self.optimizer_specified = True
         elif optimizer_string.lower() == 'lbfgs':
             self.optimizer = torch.optim.LBFGS(self.model.get_model().parameters(), lr=self.lr, history_size=10,
-                                              max_iter=20, line_search_fn=True, batch_mode=True)
+                                              max_iter=20, line_search_fn='strong_wolfe')
             self.optimizer_str = optimizer_string.lower()
             self.optimizer_specified = True
         else:
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 1cb450f..3aacee0 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -209,7 +209,7 @@ def generate_dataloaders(dataset_name, subsample_factor, batch_size):
         if dataset_name == 'graduate_admission':
             input_dim, output_dim, dataset = graduate_admission_data()
         if dataset_name == 'nonlinear':
-            input_dim, output_dim, dataset = nonlinear_data(1.0, 1.0, 1000)
+            input_dim, output_dim, dataset = nonlinear_data(1000)
         dataset_size = len(dataset)
         indices = list(range(dataset_size))
         validation_split = 0.2
-- 
GitLab