Skip to content
Snippets Groups Projects
Commit 44b65723 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Added the guess method.

parent a754e65b
No related branches found
No related tags found
No related merge requests found
......@@ -44,10 +44,27 @@ class Extrapolator:
self.descriptors.push(self._compute_descriptor(coords))
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]):
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray:
"""Get a new electronic density to be used as a guess."""
prev_descriptors = self.descriptors.get(self.nsteps)
gammas = self.gammas.get(self.nsteps)
descriptor = self._compute_descriptor(coords)
coefficients = fitting.linear()
coefficients = fitting.linear(prev_descriptors, descriptor)
gamma = fitting.linear_combination(gammas, coefficients)
if overlap is None:
overlaps = self.overlaps.get(self.nsteps)
overlap = fitting.linear_combination(overlaps, coefficients)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
c_guess = self._grassmann_exp(gamma)
c_guess = inverse_sqrt_overlap @ c_guess
return c_guess @ c_guess.T
def _get_tangent(self) -> np.ndarray:
"""Get the tangent point."""
......@@ -85,6 +102,11 @@ class Extrapolator:
q, s, vt = np.linalg.svd(overlap, full_matrices=False)
return q @ np.diag(np.sqrt(s)) @ vt
def _inverse_sqrt_overlap(self, overlap) -> np.ndarray:
"""Compute the square root of the overlap matrix."""
q, s, vt = np.linalg.svd(overlap, full_matrices=False)
return q @ np.diag(1.0/np.sqrt(s)) @ vt
def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor."""
return descriptors.distance(coords)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment