From 2c68944980c4aae783cf57c9f5e64696e47224f9 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Mon, 6 Nov 2023 10:55:12 +0100 Subject: [PATCH] Supported options are now a class attribute. --- gext/descriptors.py | 4 ++++ gext/fitting.py | 18 ++++++++++-------- gext/main.py | 18 +++++++++--------- 3 files changed, 23 insertions(+), 17 deletions(-) diff --git a/gext/descriptors.py b/gext/descriptors.py index f1f6657..e7813e4 100644 --- a/gext/descriptors.py +++ b/gext/descriptors.py @@ -7,6 +7,8 @@ class Distance: """Distance matrix descriptors.""" + supported_options = {} + def __init__(self, **kwargs): self.set_options(**kwargs) @@ -24,6 +26,8 @@ class Coulomb(Distance): """Coulomb matrix descriptors.""" + supported_options = {} + def compute(self, coords: np.ndarray) -> np.ndarray: """Compute the Coulomb matrix as a descriptor.""" return 1.0/super().compute(coords) diff --git a/gext/fitting.py b/gext/fitting.py index 1d9fe93..811434e 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -8,8 +8,9 @@ class AbstractFitting(abc.ABC): """Base class for fitting schemes.""" + supported_options = {} + def __init__(self, **kwargs): - self.supported_options = {} self.set_options(**kwargs) @abc.abstractmethod @@ -26,7 +27,6 @@ class AbstractFitting(abc.ABC): if option not in self.options: self.options[option] = default_value - @abc.abstractmethod def fit(self, vectors: List[np.ndarray], target:np.ndarray): """Base method for computing new fitting coefficients.""" @@ -45,10 +45,11 @@ class LeastSquare(AbstractFitting): """Simple least square minimization fitting.""" + supported_options = { + "regularization": 0.0, + } + def __init__(self, **kwargs): - self.supported_options = { - "regularization": 0.0, - } super().__init__(**kwargs) def set_options(self, **kwargs): @@ -70,10 +71,11 @@ class QuasiTimeReversible(AbstractFitting): """Quasi time reversible fitting scheme. Not yet implemented.""" + supported_options = { + "regularization": 0.0, + } + def __init__(self, **kwargs): - self.supported_options = { - "regularization": 0.0, - } super().__init__(**kwargs) def set_options(self, **kwargs): diff --git a/gext/main.py b/gext/main.py index 76af6a8..46d7e7f 100644 --- a/gext/main.py +++ b/gext/main.py @@ -14,16 +14,16 @@ class Extrapolator: it requires the number of electrons, the number of basis functions and the number of atoms of the molecule.""" - def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): + supported_options = { + "verbose": False, + "nsteps": 6, + "descriptor": "distance", + "fitting": "leastsquare", + "allow_partially_filled": True, + "store_overlap": True, + } - self.supported_options = { - "verbose": False, - "nsteps": 6, - "descriptor": "distance", - "fitting": "leastsquare", - "allow_partially_filled": True, - "store_overlap": True, - } + def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): self.nelectrons = nelectrons self.nbasis = nbasis -- GitLab