diff --git a/gext/fitting.py b/gext/fitting.py index 94f4e871c9fd221fdbae3009c1accb3b5562200e..1d9fe935b628237c7fdda23bf13c6cc0ac043f61 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -9,11 +9,23 @@ class AbstractFitting(abc.ABC): """Base class for fitting schemes.""" def __init__(self, **kwargs): + self.supported_options = {} self.set_options(**kwargs) @abc.abstractmethod def set_options(self, **kwargs): """Base method for setting options.""" + self.options = {} + for key, value in kwargs.items(): + if key in self.supported_options: + self.options[key] = value + else: + raise ValueError(f"Unsupported option: {key}") + + for option, default_value in self.supported_options.items(): + if option not in self.options: + self.options[option] = default_value + @abc.abstractmethod def fit(self, vectors: List[np.ndarray], target:np.ndarray): @@ -33,22 +45,15 @@ class LeastSquare(AbstractFitting): """Simple least square minimization fitting.""" - supported_options = { - "regularization": 0.0, - } + def __init__(self, **kwargs): + self.supported_options = { + "regularization": 0.0, + } + super().__init__(**kwargs) def set_options(self, **kwargs): """Set options for least square minimization""" - self.options = {} - for key, value in kwargs.items(): - if key in self.supported_options: - self.options[key] = value - else: - raise ValueError(f"Unsupported option: {key}") - - for option, default_value in self.supported_options.items(): - if option not in self.options: - self.options[option] = default_value + super().set_options(**kwargs) if self.options["regularization"] < 0 \ or self.options["regularization"] > 100: @@ -65,8 +70,19 @@ class QuasiTimeReversible(AbstractFitting): """Quasi time reversible fitting scheme. Not yet implemented.""" + def __init__(self, **kwargs): + self.supported_options = { + "regularization": 0.0, + } + super().__init__(**kwargs) + def set_options(self, **kwargs): """Set options for quasi time reversible fitting""" + super().set_options(**kwargs) + + if self.options["regularization"] < 0 \ + or self.options["regularization"] > 100: + raise ValueError("Unsupported value for regularization") def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Time reversible least square minimization fitting.""" diff --git a/gext/main.py b/gext/main.py index e87e37ef8e123c7eb25516a4f2c46afb6fa5ddb0..b48594a0962b3360d8a7b3c2725ea7d6bd504833 100644 --- a/gext/main.py +++ b/gext/main.py @@ -12,9 +12,7 @@ class Extrapolator: """Class for performing Grassmann extrapolations. On initialization it requires the number of electrons, the number of basis functions - and the number of atoms of the molecule. The number of previous - steps used by the extrapolator is an optional argument with default - value of 6.""" + and the number of atoms of the molecule.""" def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): @@ -83,12 +81,17 @@ class Extrapolator: def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray): """Load a new data point in the extrapolator.""" + + # Crop the coefficient matrix up to the number of electron + # pairs, then apply S^1/2 coeff = self._crop_coeff(coeff) coeff = self._normalize(coeff, overlap) + # if it is the first time we load data, set the tangent point if self.tangent is None: self._set_tangent(coeff) + # push the new data to the corresponding vectors self.gammas.push(self._grassmann_log(coeff)) self.descriptors.push(self._compute_descriptor(coords)) self.overlaps.push(overlap)