From 2c68944980c4aae783cf57c9f5e64696e47224f9 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Mon, 6 Nov 2023 10:55:12 +0100
Subject: [PATCH] Supported options are now a class attribute.

---
 gext/descriptors.py |  4 ++++
 gext/fitting.py     | 18 ++++++++++--------
 gext/main.py        | 18 +++++++++---------
 3 files changed, 23 insertions(+), 17 deletions(-)

diff --git a/gext/descriptors.py b/gext/descriptors.py
index f1f6657..e7813e4 100644
--- a/gext/descriptors.py
+++ b/gext/descriptors.py
@@ -7,6 +7,8 @@ class Distance:
 
     """Distance matrix descriptors."""
 
+    supported_options = {}
+
     def __init__(self, **kwargs):
         self.set_options(**kwargs)
 
@@ -24,6 +26,8 @@ class Coulomb(Distance):
 
     """Coulomb matrix descriptors."""
 
+    supported_options = {}
+
     def compute(self, coords: np.ndarray) -> np.ndarray:
         """Compute the Coulomb matrix as a descriptor."""
         return 1.0/super().compute(coords)
diff --git a/gext/fitting.py b/gext/fitting.py
index 1d9fe93..811434e 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -8,8 +8,9 @@ class AbstractFitting(abc.ABC):
 
     """Base class for fitting schemes."""
 
+    supported_options = {}
+
     def __init__(self, **kwargs):
-        self.supported_options = {}
         self.set_options(**kwargs)
 
     @abc.abstractmethod
@@ -26,7 +27,6 @@ class AbstractFitting(abc.ABC):
             if option not in self.options:
                 self.options[option] = default_value
 
-
     @abc.abstractmethod
     def fit(self, vectors: List[np.ndarray], target:np.ndarray):
         """Base method for computing new fitting coefficients."""
@@ -45,10 +45,11 @@ class LeastSquare(AbstractFitting):
 
     """Simple least square minimization fitting."""
 
+    supported_options = {
+        "regularization": 0.0,
+    }
+
     def __init__(self, **kwargs):
-        self.supported_options = {
-            "regularization": 0.0,
-        }
         super().__init__(**kwargs)
 
     def set_options(self, **kwargs):
@@ -70,10 +71,11 @@ class QuasiTimeReversible(AbstractFitting):
 
     """Quasi time reversible fitting scheme. Not yet implemented."""
 
+    supported_options = {
+        "regularization": 0.0,
+    }
+
     def __init__(self, **kwargs):
-        self.supported_options = {
-            "regularization": 0.0,
-        }
         super().__init__(**kwargs)
 
     def set_options(self, **kwargs):
diff --git a/gext/main.py b/gext/main.py
index 76af6a8..46d7e7f 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -14,16 +14,16 @@ class Extrapolator:
     it requires the number of electrons, the number of basis functions
     and the number of atoms of the molecule."""
 
-    def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
+    supported_options = {
+        "verbose": False,
+        "nsteps": 6,
+        "descriptor": "distance",
+        "fitting": "leastsquare",
+        "allow_partially_filled": True,
+        "store_overlap": True,
+    }
 
-        self.supported_options = {
-            "verbose": False,
-            "nsteps": 6,
-            "descriptor": "distance",
-            "fitting": "leastsquare",
-            "allow_partially_filled": True,
-            "store_overlap": True,
-        }
+    def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
 
         self.nelectrons = nelectrons
         self.nbasis = nbasis
-- 
GitLab