From 2ae7f6e65792abb4718ff328035624f311732cad Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Mon, 4 Mar 2024 13:26:49 +0100 Subject: [PATCH] Fixed normal tangent case. --- gext/main.py | 17 ++++++++++------- tests/test_grassmann.py | 3 ++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/gext/main.py b/gext/main.py index 61124b7..066d6a3 100644 --- a/gext/main.py +++ b/gext/main.py @@ -102,13 +102,16 @@ class Extrapolator: # pairs, then apply S^1/2 coeff = self._crop_coeff(coeff) coeff = self._normalize(coeff, overlap) + + # if it is the first time we load data, set the tangent point + if self.tangent is None and self.options["tangent"] != "one_before_last": + self._set_tangent(coeff) + if self.options["tangent"]=="one_before_last": self.coeffs.push(coeff) else: self.gammas.push(self._grassmann_log(coeff)) - # if it is the first time we load data, set the tangent point - if self.tangent is None and self.options["tangent"] is not "one_before_last": - self._set_tangent(coeff) + # push the new data to the corresponding vectors self.descriptors.push(self._compute_descriptor(descriptor_input)) @@ -140,7 +143,7 @@ class Extrapolator: # use the descriptors to find the fitting coefficients prev_descriptors= self.descriptors.get(n) - descriptor = self._compute_descriptor(coords) + descriptor = self._compute_descriptor(descriptor_input) fit_coefficients = self._fit(prev_descriptors, descriptor) print(fit_coefficients) @@ -224,9 +227,9 @@ class Extrapolator: q, s, vt = np.linalg.svd(overlap, full_matrices=False) return q @ np.diag(1.0/np.sqrt(s)) @ vt - def _compute_descriptor(self, coords) -> np.ndarray: - """Given a set of coordinates compute the corresponding descriptor.""" - return self.descriptor_calculator.compute(coords) + def _compute_descriptor(self, descriptor_input) -> np.ndarray: + """Given an input compute the corresponding descriptor.""" + return self.descriptor_calculator.compute(descriptor_input) def _fit(self, prev_descriptors, descriptor) -> np.ndarray: """Fit the current descriptor using previous descriptors and diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py index 69786b2..663220e 100644 --- a/tests/test_grassmann.py +++ b/tests/test_grassmann.py @@ -21,7 +21,8 @@ def test_grassmann(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes, + tangent="one") # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], -- GitLab