From 952b9d3f480fa630b2d239bd979973f6b9ff87ca Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Thu, 19 Oct 2023 18:08:01 +0200
Subject: [PATCH] Now a simple extrapolation is working.

---
 grext/fitting.py            |  6 ++++--
 grext/main.py               | 23 +++++++++++++-------
 tests/test_extrapolation.py | 43 +++++++++++++++++++++++++++++++++++++
 3 files changed, 62 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_extrapolation.py

diff --git a/grext/fitting.py b/grext/fitting.py
index 9e4d26d..431581c 100644
--- a/grext/fitting.py
+++ b/grext/fitting.py
@@ -13,5 +13,7 @@ def quasi_time_reversible():
     """Time reversible least square minimization fitting."""
 
 def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
-    A = np.vstack(vectors).T
-    return A @ coefficients
+    result = np.zeros(vectors[0].shape, dtype=np.float64)
+    for coeff, vector in zip(coefficients, vectors):
+        result += vector*coeff
+    return result
diff --git a/grext/main.py b/grext/main.py
index e17766d..f12718a 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -17,20 +17,22 @@ class Extrapolator:
     value of 10."""
 
     def __init__(self, nelectrons: int, nbasis: int, natoms: int,
-            nsteps: int = 10):
+            nsteps: int = 10, **kwargs):
 
         self.nelectrons = nelectrons
         self.nbasis = nbasis
         self.natoms = natoms
         self.nsteps = nsteps
 
-        self.gammas = CircularBuffer(self.nsteps, (self.nelectrons, self.nbasis))
+        self.gammas = CircularBuffer(self.nsteps, (self.nelectrons//2, self.nbasis))
         self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
         self.descriptors = CircularBuffer(self.nsteps,
             ((self.natoms - 1)*self.natoms//2, ))
 
         self.tangent: Optional[np.ndarray] = None
 
+        self.options = kwargs
+
     def load_data(self, coords: np.ndarray, coeff: np.ndarray,
             overlap: np.ndarray):
         """Load a new data point in the extrapolator."""
@@ -44,19 +46,24 @@ class Extrapolator:
         self.descriptors.push(self._compute_descriptor(coords))
         self.overlaps.push(overlap)
 
-    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray:
+    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray] = None) -> np.ndarray:
         """Get a new electronic density to be used as a guess."""
 
         prev_descriptors = self.descriptors.get(self.nsteps)
-        gammas = self.gammas.get(self.nsteps)
         descriptor = self._compute_descriptor(coords)
-        coefficients = fitting.linear(prev_descriptors, descriptor)
+        fit_coefficients = fitting.linear(prev_descriptors, descriptor)
+
+        gammas = self.gammas.get(self.nsteps)
+        gamma = fitting.linear_combination(gammas, fit_coefficients)
+
+        fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
 
-        gamma = fitting.linear_combination(gammas, coefficients)
+        if self.options["verbose"]:
+            print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
 
         if overlap is None:
             overlaps = self.overlaps.get(self.nsteps)
-            overlap = fitting.linear_combination(overlaps, coefficients)
+            overlap = fitting.linear_combination(overlaps, fit_coefficients)
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
         else:
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
@@ -74,7 +81,7 @@ class Extrapolator:
 
     def _crop_coeff(self, coeff) -> np.ndarray:
         """Crop the coefficient matrix to remove the virtual orbitals."""
-        return coeff[:, :self.nelectrons]
+        return coeff[:, :self.nelectrons//2]
 
     def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
         """Normalize the coefficients such that C.T * C = 1, D * D = D."""
diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py
new file mode 100644
index 0000000..5fcb992
--- /dev/null
+++ b/tests/test_extrapolation.py
@@ -0,0 +1,43 @@
+import pytest
+import os
+import sys
+import numpy as np
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import grext
+import grext.grassmann
+import utils
+
+SMALL = 1e-10
+THRESHOLD = 1e-2
+
+@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
+def test_extrapolation(datafile):
+
+    # load test data from json file
+    data = utils.load_json(f"tests/{datafile}")
+    nelectrons = data["nelectrons"]
+    natoms = data["trajectory"].shape[1]
+    nbasis = data["overlaps"].shape[1]
+    nframes = data["trajectory"].shape[0]
+
+    # amount of data we want to use for fitting
+    n = 9
+    assert n < nframes
+
+    # initialize an extrapolator
+    extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, n)
+
+    # load data in the extrapolator up to index n - 1
+    for (coords, coeff, overlap) in zip(data["trajectory"][:n],
+            data["coefficients"][:n], data["overlaps"][:n]):
+        extrapolator.load_data(coords, coeff, overlap)
+
+    # check an extrapolation at index n
+    guessed_density = extrapolator.guess(data["trajectory"][n], data["overlaps"][n])
+    coeff = data["coefficients"][n][:, :nelectrons//2]
+    density = coeff @ coeff.T
+
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) \
+          /np.linalg.norm(density, ord=np.inf) < THRESHOLD
-- 
GitLab