Skip to content
Snippets Groups Projects
Commit 2ae7f6e6 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Fixed normal tangent case.

parent 63489817
No related branches found
No related tags found
No related merge requests found
Pipeline #2066 passed
......@@ -102,13 +102,16 @@ class Extrapolator:
# pairs, then apply S^1/2
coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap)
# if it is the first time we load data, set the tangent point
if self.tangent is None and self.options["tangent"] != "one_before_last":
self._set_tangent(coeff)
if self.options["tangent"]=="one_before_last":
self.coeffs.push(coeff)
else:
self.gammas.push(self._grassmann_log(coeff))
# if it is the first time we load data, set the tangent point
if self.tangent is None and self.options["tangent"] is not "one_before_last":
self._set_tangent(coeff)
# push the new data to the corresponding vectors
self.descriptors.push(self._compute_descriptor(descriptor_input))
......@@ -140,7 +143,7 @@ class Extrapolator:
# use the descriptors to find the fitting coefficients
prev_descriptors= self.descriptors.get(n)
descriptor = self._compute_descriptor(coords)
descriptor = self._compute_descriptor(descriptor_input)
fit_coefficients = self._fit(prev_descriptors, descriptor)
print(fit_coefficients)
......@@ -224,9 +227,9 @@ class Extrapolator:
q, s, vt = np.linalg.svd(overlap, full_matrices=False)
return q @ np.diag(1.0/np.sqrt(s)) @ vt
def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor."""
return self.descriptor_calculator.compute(coords)
def _compute_descriptor(self, descriptor_input) -> np.ndarray:
"""Given an input compute the corresponding descriptor."""
return self.descriptor_calculator.compute(descriptor_input)
def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
"""Fit the current descriptor using previous descriptors and
......
......@@ -21,7 +21,8 @@ def test_grassmann(datafile):
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes)
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes,
tangent="one")
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment