From f60da40e47c009c72944a82aeeb4cd7671fb1b38 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Mon, 6 Nov 2023 10:33:43 +0100 Subject: [PATCH] Reshape. --- gext/fitting.py | 42 +++++++++++++++++++++++++++++------------- gext/main.py | 9 ++++++--- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/gext/fitting.py b/gext/fitting.py index 94f4e87..1d9fe93 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -9,11 +9,23 @@ class AbstractFitting(abc.ABC): """Base class for fitting schemes.""" def __init__(self, **kwargs): + self.supported_options = {} self.set_options(**kwargs) @abc.abstractmethod def set_options(self, **kwargs): """Base method for setting options.""" + self.options = {} + for key, value in kwargs.items(): + if key in self.supported_options: + self.options[key] = value + else: + raise ValueError(f"Unsupported option: {key}") + + for option, default_value in self.supported_options.items(): + if option not in self.options: + self.options[option] = default_value + @abc.abstractmethod def fit(self, vectors: List[np.ndarray], target:np.ndarray): @@ -33,22 +45,15 @@ class LeastSquare(AbstractFitting): """Simple least square minimization fitting.""" - supported_options = { - "regularization": 0.0, - } + def __init__(self, **kwargs): + self.supported_options = { + "regularization": 0.0, + } + super().__init__(**kwargs) def set_options(self, **kwargs): """Set options for least square minimization""" - self.options = {} - for key, value in kwargs.items(): - if key in self.supported_options: - self.options[key] = value - else: - raise ValueError(f"Unsupported option: {key}") - - for option, default_value in self.supported_options.items(): - if option not in self.options: - self.options[option] = default_value + super().set_options(**kwargs) if self.options["regularization"] < 0 \ or self.options["regularization"] > 100: @@ -65,8 +70,19 @@ class QuasiTimeReversible(AbstractFitting): """Quasi time reversible fitting scheme. Not yet implemented.""" + def __init__(self, **kwargs): + self.supported_options = { + "regularization": 0.0, + } + super().__init__(**kwargs) + def set_options(self, **kwargs): """Set options for quasi time reversible fitting""" + super().set_options(**kwargs) + + if self.options["regularization"] < 0 \ + or self.options["regularization"] > 100: + raise ValueError("Unsupported value for regularization") def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Time reversible least square minimization fitting.""" diff --git a/gext/main.py b/gext/main.py index e87e37e..b48594a 100644 --- a/gext/main.py +++ b/gext/main.py @@ -12,9 +12,7 @@ class Extrapolator: """Class for performing Grassmann extrapolations. On initialization it requires the number of electrons, the number of basis functions - and the number of atoms of the molecule. The number of previous - steps used by the extrapolator is an optional argument with default - value of 6.""" + and the number of atoms of the molecule.""" def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): @@ -83,12 +81,17 @@ class Extrapolator: def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray): """Load a new data point in the extrapolator.""" + + # Crop the coefficient matrix up to the number of electron + # pairs, then apply S^1/2 coeff = self._crop_coeff(coeff) coeff = self._normalize(coeff, overlap) + # if it is the first time we load data, set the tangent point if self.tangent is None: self._set_tangent(coeff) + # push the new data to the corresponding vectors self.gammas.push(self._grassmann_log(coeff)) self.descriptors.push(self._compute_descriptor(coords)) self.overlaps.push(overlap) -- GitLab