diff --git a/gext/main.py b/gext/main.py index 61124b7debd4b6d02398532563f6e20920083aa6..066d6a39548e00814c3a36164de951717a24616a 100644 --- a/gext/main.py +++ b/gext/main.py @@ -102,13 +102,16 @@ class Extrapolator: # pairs, then apply S^1/2 coeff = self._crop_coeff(coeff) coeff = self._normalize(coeff, overlap) + + # if it is the first time we load data, set the tangent point + if self.tangent is None and self.options["tangent"] != "one_before_last": + self._set_tangent(coeff) + if self.options["tangent"]=="one_before_last": self.coeffs.push(coeff) else: self.gammas.push(self._grassmann_log(coeff)) - # if it is the first time we load data, set the tangent point - if self.tangent is None and self.options["tangent"] is not "one_before_last": - self._set_tangent(coeff) + # push the new data to the corresponding vectors self.descriptors.push(self._compute_descriptor(descriptor_input)) @@ -140,7 +143,7 @@ class Extrapolator: # use the descriptors to find the fitting coefficients prev_descriptors= self.descriptors.get(n) - descriptor = self._compute_descriptor(coords) + descriptor = self._compute_descriptor(descriptor_input) fit_coefficients = self._fit(prev_descriptors, descriptor) print(fit_coefficients) @@ -224,9 +227,9 @@ class Extrapolator: q, s, vt = np.linalg.svd(overlap, full_matrices=False) return q @ np.diag(1.0/np.sqrt(s)) @ vt - def _compute_descriptor(self, coords) -> np.ndarray: - """Given a set of coordinates compute the corresponding descriptor.""" - return self.descriptor_calculator.compute(coords) + def _compute_descriptor(self, descriptor_input) -> np.ndarray: + """Given an input compute the corresponding descriptor.""" + return self.descriptor_calculator.compute(descriptor_input) def _fit(self, prev_descriptors, descriptor) -> np.ndarray: """Fit the current descriptor using previous descriptors and diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py index 69786b275375912718b09782840b9dc1e232ae33..663220e1619dfc1a45e497f904445ee7efd52223 100644 --- a/tests/test_grassmann.py +++ b/tests/test_grassmann.py @@ -21,7 +21,8 @@ def test_grassmann(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes, + tangent="one") # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"],