From f60da40e47c009c72944a82aeeb4cd7671fb1b38 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Mon, 6 Nov 2023 10:33:43 +0100
Subject: [PATCH] Reshape.

---
 gext/fitting.py | 42 +++++++++++++++++++++++++++++-------------
 gext/main.py    |  9 ++++++---
 2 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/gext/fitting.py b/gext/fitting.py
index 94f4e87..1d9fe93 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -9,11 +9,23 @@ class AbstractFitting(abc.ABC):
     """Base class for fitting schemes."""
 
     def __init__(self, **kwargs):
+        self.supported_options = {}
         self.set_options(**kwargs)
 
     @abc.abstractmethod
     def set_options(self, **kwargs):
         """Base method for setting options."""
+        self.options = {}
+        for key, value in kwargs.items():
+            if key in self.supported_options:
+                self.options[key] = value
+            else:
+                raise ValueError(f"Unsupported option: {key}")
+
+        for option, default_value in self.supported_options.items():
+            if option not in self.options:
+                self.options[option] = default_value
+
 
     @abc.abstractmethod
     def fit(self, vectors: List[np.ndarray], target:np.ndarray):
@@ -33,22 +45,15 @@ class LeastSquare(AbstractFitting):
 
     """Simple least square minimization fitting."""
 
-    supported_options = {
-        "regularization": 0.0,
-    }
+    def __init__(self, **kwargs):
+        self.supported_options = {
+            "regularization": 0.0,
+        }
+        super().__init__(**kwargs)
 
     def set_options(self, **kwargs):
         """Set options for least square minimization"""
-        self.options = {}
-        for key, value in kwargs.items():
-            if key in self.supported_options:
-                self.options[key] = value
-            else:
-                raise ValueError(f"Unsupported option: {key}")
-
-        for option, default_value in self.supported_options.items():
-            if option not in self.options:
-                self.options[option] = default_value
+        super().set_options(**kwargs)
 
         if self.options["regularization"] < 0 \
                 or self.options["regularization"] > 100:
@@ -65,8 +70,19 @@ class QuasiTimeReversible(AbstractFitting):
 
     """Quasi time reversible fitting scheme. Not yet implemented."""
 
+    def __init__(self, **kwargs):
+        self.supported_options = {
+            "regularization": 0.0,
+        }
+        super().__init__(**kwargs)
+
     def set_options(self, **kwargs):
         """Set options for quasi time reversible fitting"""
+        super().set_options(**kwargs)
+
+        if self.options["regularization"] < 0 \
+                or self.options["regularization"] > 100:
+            raise ValueError("Unsupported value for regularization")
 
     def fit(self, vectors: List[np.ndarray], target: np.ndarray):
         """Time reversible least square minimization fitting."""
diff --git a/gext/main.py b/gext/main.py
index e87e37e..b48594a 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -12,9 +12,7 @@ class Extrapolator:
 
     """Class for performing Grassmann extrapolations. On initialization
     it requires the number of electrons, the number of basis functions
-    and the number of atoms of the molecule. The number of previous
-    steps used by the extrapolator is an optional argument with default
-    value of 6."""
+    and the number of atoms of the molecule."""
 
     def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
 
@@ -83,12 +81,17 @@ class Extrapolator:
     def load_data(self, coords: np.ndarray, coeff: np.ndarray,
             overlap: np.ndarray):
         """Load a new data point in the extrapolator."""
+
+        # Crop the coefficient matrix up to the number of electron
+        # pairs, then apply S^1/2
         coeff = self._crop_coeff(coeff)
         coeff = self._normalize(coeff, overlap)
 
+        # if it is the first time we load data, set the tangent point
         if self.tangent is None:
             self._set_tangent(coeff)
 
+        # push the new data to the corresponding vectors
         self.gammas.push(self._grassmann_log(coeff))
         self.descriptors.push(self._compute_descriptor(coords))
         self.overlaps.push(overlap)
-- 
GitLab