diff --git a/grext/main.py b/grext/main.py index 1bf9adc5c0c0476a17beba304d2bdb3ddbc5ade7..e17766dbf8787fb08a96bc8a3f65f92b85a76555 100644 --- a/grext/main.py +++ b/grext/main.py @@ -44,10 +44,27 @@ class Extrapolator: self.descriptors.push(self._compute_descriptor(coords)) self.overlaps.push(overlap) - def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]): + def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray: """Get a new electronic density to be used as a guess.""" + + prev_descriptors = self.descriptors.get(self.nsteps) + gammas = self.gammas.get(self.nsteps) descriptor = self._compute_descriptor(coords) - coefficients = fitting.linear() + coefficients = fitting.linear(prev_descriptors, descriptor) + + gamma = fitting.linear_combination(gammas, coefficients) + + if overlap is None: + overlaps = self.overlaps.get(self.nsteps) + overlap = fitting.linear_combination(overlaps, coefficients) + inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) + else: + inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) + + c_guess = self._grassmann_exp(gamma) + c_guess = inverse_sqrt_overlap @ c_guess + + return c_guess @ c_guess.T def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" @@ -85,6 +102,11 @@ class Extrapolator: q, s, vt = np.linalg.svd(overlap, full_matrices=False) return q @ np.diag(np.sqrt(s)) @ vt + def _inverse_sqrt_overlap(self, overlap) -> np.ndarray: + """Compute the square root of the overlap matrix.""" + q, s, vt = np.linalg.svd(overlap, full_matrices=False) + return q @ np.diag(1.0/np.sqrt(s)) @ vt + def _compute_descriptor(self, coords) -> np.ndarray: """Given a set of coordinates compute the corresponding descriptor.""" return descriptors.distance(coords)