diff --git a/gext/descriptors.py b/gext/descriptors.py index f1f6657e3111ef600331fcb59244d12dc8ec6527..e7813e40626fb3e6dffc7dc05e2a666a68f4b93e 100644 --- a/gext/descriptors.py +++ b/gext/descriptors.py @@ -7,6 +7,8 @@ class Distance: """Distance matrix descriptors.""" + supported_options = {} + def __init__(self, **kwargs): self.set_options(**kwargs) @@ -24,6 +26,8 @@ class Coulomb(Distance): """Coulomb matrix descriptors.""" + supported_options = {} + def compute(self, coords: np.ndarray) -> np.ndarray: """Compute the Coulomb matrix as a descriptor.""" return 1.0/super().compute(coords) diff --git a/gext/fitting.py b/gext/fitting.py index 1d9fe935b628237c7fdda23bf13c6cc0ac043f61..811434ea1ea6720829fa553cdbd1dddd6c7a23cb 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -8,8 +8,9 @@ class AbstractFitting(abc.ABC): """Base class for fitting schemes.""" + supported_options = {} + def __init__(self, **kwargs): - self.supported_options = {} self.set_options(**kwargs) @abc.abstractmethod @@ -26,7 +27,6 @@ class AbstractFitting(abc.ABC): if option not in self.options: self.options[option] = default_value - @abc.abstractmethod def fit(self, vectors: List[np.ndarray], target:np.ndarray): """Base method for computing new fitting coefficients.""" @@ -45,10 +45,11 @@ class LeastSquare(AbstractFitting): """Simple least square minimization fitting.""" + supported_options = { + "regularization": 0.0, + } + def __init__(self, **kwargs): - self.supported_options = { - "regularization": 0.0, - } super().__init__(**kwargs) def set_options(self, **kwargs): @@ -70,10 +71,11 @@ class QuasiTimeReversible(AbstractFitting): """Quasi time reversible fitting scheme. Not yet implemented.""" + supported_options = { + "regularization": 0.0, + } + def __init__(self, **kwargs): - self.supported_options = { - "regularization": 0.0, - } super().__init__(**kwargs) def set_options(self, **kwargs): diff --git a/gext/main.py b/gext/main.py index 76af6a80b510b9f702b7290a71bc3526ae7473aa..46d7e7f61e55699ec99f6b36ca8c10c537de4dea 100644 --- a/gext/main.py +++ b/gext/main.py @@ -14,16 +14,16 @@ class Extrapolator: it requires the number of electrons, the number of basis functions and the number of atoms of the molecule.""" - def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): + supported_options = { + "verbose": False, + "nsteps": 6, + "descriptor": "distance", + "fitting": "leastsquare", + "allow_partially_filled": True, + "store_overlap": True, + } - self.supported_options = { - "verbose": False, - "nsteps": 6, - "descriptor": "distance", - "fitting": "leastsquare", - "allow_partially_filled": True, - "store_overlap": True, - } + def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): self.nelectrons = nelectrons self.nbasis = nbasis