Skip to content
Snippets Groups Projects
Commit 2ae7f6e6 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Fixed normal tangent case.

parent 63489817
No related branches found
No related tags found
No related merge requests found
Pipeline #2066 passed
...@@ -102,13 +102,16 @@ class Extrapolator: ...@@ -102,13 +102,16 @@ class Extrapolator:
# pairs, then apply S^1/2 # pairs, then apply S^1/2
coeff = self._crop_coeff(coeff) coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap) coeff = self._normalize(coeff, overlap)
# if it is the first time we load data, set the tangent point
if self.tangent is None and self.options["tangent"] != "one_before_last":
self._set_tangent(coeff)
if self.options["tangent"]=="one_before_last": if self.options["tangent"]=="one_before_last":
self.coeffs.push(coeff) self.coeffs.push(coeff)
else: else:
self.gammas.push(self._grassmann_log(coeff)) self.gammas.push(self._grassmann_log(coeff))
# if it is the first time we load data, set the tangent point
if self.tangent is None and self.options["tangent"] is not "one_before_last":
self._set_tangent(coeff)
# push the new data to the corresponding vectors # push the new data to the corresponding vectors
self.descriptors.push(self._compute_descriptor(descriptor_input)) self.descriptors.push(self._compute_descriptor(descriptor_input))
...@@ -140,7 +143,7 @@ class Extrapolator: ...@@ -140,7 +143,7 @@ class Extrapolator:
# use the descriptors to find the fitting coefficients # use the descriptors to find the fitting coefficients
prev_descriptors= self.descriptors.get(n) prev_descriptors= self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(descriptor_input)
fit_coefficients = self._fit(prev_descriptors, descriptor) fit_coefficients = self._fit(prev_descriptors, descriptor)
print(fit_coefficients) print(fit_coefficients)
...@@ -224,9 +227,9 @@ class Extrapolator: ...@@ -224,9 +227,9 @@ class Extrapolator:
q, s, vt = np.linalg.svd(overlap, full_matrices=False) q, s, vt = np.linalg.svd(overlap, full_matrices=False)
return q @ np.diag(1.0/np.sqrt(s)) @ vt return q @ np.diag(1.0/np.sqrt(s)) @ vt
def _compute_descriptor(self, coords) -> np.ndarray: def _compute_descriptor(self, descriptor_input) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor.""" """Given an input compute the corresponding descriptor."""
return self.descriptor_calculator.compute(coords) return self.descriptor_calculator.compute(descriptor_input)
def _fit(self, prev_descriptors, descriptor) -> np.ndarray: def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
"""Fit the current descriptor using previous descriptors and """Fit the current descriptor using previous descriptors and
......
...@@ -21,7 +21,8 @@ def test_grassmann(datafile): ...@@ -21,7 +21,8 @@ def test_grassmann(datafile):
nframes = data["trajectory"].shape[0] nframes = data["trajectory"].shape[0]
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes,
tangent="one")
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment