Skip to content
Snippets Groups Projects
Commit 2c689449 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Supported options are now a class attribute.

parent e2b4ca59
Branches
Tags
1 merge request!6QTR
...@@ -7,6 +7,8 @@ class Distance: ...@@ -7,6 +7,8 @@ class Distance:
"""Distance matrix descriptors.""" """Distance matrix descriptors."""
supported_options = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
...@@ -24,6 +26,8 @@ class Coulomb(Distance): ...@@ -24,6 +26,8 @@ class Coulomb(Distance):
"""Coulomb matrix descriptors.""" """Coulomb matrix descriptors."""
supported_options = {}
def compute(self, coords: np.ndarray) -> np.ndarray: def compute(self, coords: np.ndarray) -> np.ndarray:
"""Compute the Coulomb matrix as a descriptor.""" """Compute the Coulomb matrix as a descriptor."""
return 1.0/super().compute(coords) return 1.0/super().compute(coords)
...@@ -8,8 +8,9 @@ class AbstractFitting(abc.ABC): ...@@ -8,8 +8,9 @@ class AbstractFitting(abc.ABC):
"""Base class for fitting schemes.""" """Base class for fitting schemes."""
supported_options = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.supported_options = {}
self.set_options(**kwargs) self.set_options(**kwargs)
@abc.abstractmethod @abc.abstractmethod
...@@ -26,7 +27,6 @@ class AbstractFitting(abc.ABC): ...@@ -26,7 +27,6 @@ class AbstractFitting(abc.ABC):
if option not in self.options: if option not in self.options:
self.options[option] = default_value self.options[option] = default_value
@abc.abstractmethod @abc.abstractmethod
def fit(self, vectors: List[np.ndarray], target:np.ndarray): def fit(self, vectors: List[np.ndarray], target:np.ndarray):
"""Base method for computing new fitting coefficients.""" """Base method for computing new fitting coefficients."""
...@@ -45,10 +45,11 @@ class LeastSquare(AbstractFitting): ...@@ -45,10 +45,11 @@ class LeastSquare(AbstractFitting):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
def __init__(self, **kwargs): supported_options = {
self.supported_options = {
"regularization": 0.0, "regularization": 0.0,
} }
def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def set_options(self, **kwargs): def set_options(self, **kwargs):
...@@ -70,10 +71,11 @@ class QuasiTimeReversible(AbstractFitting): ...@@ -70,10 +71,11 @@ class QuasiTimeReversible(AbstractFitting):
"""Quasi time reversible fitting scheme. Not yet implemented.""" """Quasi time reversible fitting scheme. Not yet implemented."""
def __init__(self, **kwargs): supported_options = {
self.supported_options = {
"regularization": 0.0, "regularization": 0.0,
} }
def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
def set_options(self, **kwargs): def set_options(self, **kwargs):
......
...@@ -14,9 +14,7 @@ class Extrapolator: ...@@ -14,9 +14,7 @@ class Extrapolator:
it requires the number of electrons, the number of basis functions it requires the number of electrons, the number of basis functions
and the number of atoms of the molecule.""" and the number of atoms of the molecule."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): supported_options = {
self.supported_options = {
"verbose": False, "verbose": False,
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
...@@ -25,6 +23,8 @@ class Extrapolator: ...@@ -25,6 +23,8 @@ class Extrapolator:
"store_overlap": True, "store_overlap": True,
} }
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
self.nelectrons = nelectrons self.nelectrons = nelectrons
self.nbasis = nbasis self.nbasis = nbasis
self.natoms = natoms self.natoms = natoms
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment