From 44b657239b8f662b39ba8f11457eb6b8c813586a Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Thu, 19 Oct 2023 17:34:33 +0200
Subject: [PATCH] Added the guess method.

---
 grext/main.py | 26 ++++++++++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/grext/main.py b/grext/main.py
index 1bf9adc..e17766d 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -44,10 +44,27 @@ class Extrapolator:
         self.descriptors.push(self._compute_descriptor(coords))
         self.overlaps.push(overlap)
 
-    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]):
+    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray:
         """Get a new electronic density to be used as a guess."""
+
+        prev_descriptors = self.descriptors.get(self.nsteps)
+        gammas = self.gammas.get(self.nsteps)
         descriptor = self._compute_descriptor(coords)
-        coefficients = fitting.linear()
+        coefficients = fitting.linear(prev_descriptors, descriptor)
+
+        gamma = fitting.linear_combination(gammas, coefficients)
+
+        if overlap is None:
+            overlaps = self.overlaps.get(self.nsteps)
+            overlap = fitting.linear_combination(overlaps, coefficients)
+            inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
+        else:
+            inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
+
+        c_guess = self._grassmann_exp(gamma)
+        c_guess = inverse_sqrt_overlap @ c_guess
+
+        return c_guess @ c_guess.T
 
     def _get_tangent(self) -> np.ndarray:
         """Get the tangent point."""
@@ -85,6 +102,11 @@ class Extrapolator:
         q, s, vt = np.linalg.svd(overlap, full_matrices=False)
         return q @ np.diag(np.sqrt(s)) @ vt
 
+    def _inverse_sqrt_overlap(self, overlap) -> np.ndarray:
+        """Compute the square root of the overlap matrix."""
+        q, s, vt = np.linalg.svd(overlap, full_matrices=False)
+        return q @ np.diag(1.0/np.sqrt(s)) @ vt
+
     def _compute_descriptor(self, coords) -> np.ndarray:
         """Given a set of coordinates compute the corresponding descriptor."""
         return descriptors.distance(coords)
-- 
GitLab