From b1c6167f2d2ef70f65cb17ed38638ad0bb375a3c Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Mon, 4 Mar 2024 14:13:39 +0100
Subject: [PATCH] More general tangent + check tangent.

---
 gext/main.py            | 58 ++++++++++++++++++-----------------------
 tests/test_grassmann.py |  7 ++---
 tests/test_tangent.py   | 54 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 83 insertions(+), 36 deletions(-)
 create mode 100644 tests/test_tangent.py

diff --git a/gext/main.py b/gext/main.py
index 066d6a3..4a16d51 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -21,7 +21,7 @@ class Extrapolator:
         "fitting": "diff",
         "allow_partially_filled": True,
         "store_overlap": True,
-        "tangent": "one_before_last",
+        "tangent": "fixed",
     }
 
     def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
@@ -35,12 +35,10 @@ class Extrapolator:
         self.natoms = natoms
         self.set_options(**kwargs)
 
-        self.gammas = CircularBuffer(self.options["nsteps"])
         self.descriptors = CircularBuffer(self.options["nsteps"])
         if self.options["store_overlap"]:
             self.overlaps = CircularBuffer(self.options["nsteps"])
-        if self.options["tangent"]=="one_before_last":
-            self.coeffs = CircularBuffer(self.options["nsteps"])
+        self.coeffs = CircularBuffer(self.options["nsteps"])
         self.tangent: Optional[np.ndarray] = None
         self.H_cores = CircularBuffer(self.options["nsteps"])
 
@@ -74,6 +72,14 @@ class Extrapolator:
         if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100:
             raise ValueError("Unsupported nsteps")
 
+        if isinstance(self.options["tangent"], int):
+            if self.options["tangent"] < -self.options["nsteps"] or \
+                    self.options["tangent"] >= self.options["nsteps"]:
+                raise ValueError("Unsupported tangent")
+        else:
+            if self.options["tangent"] != "fixed":
+                raise ValueError("Unsupported tangent")
+
         if self.options["descriptor"] == "distance":
             self.descriptor_calculator = Distance()
         elif self.options["descriptor"] == "coulomb":
@@ -102,16 +108,12 @@ class Extrapolator:
         # pairs, then apply S^1/2
         coeff = self._crop_coeff(coeff)
         coeff = self._normalize(coeff, overlap)
+        self.coeffs.push(coeff)
 
-        # if it is the first time we load data, set the tangent point
-        if self.tangent is None and self.options["tangent"] != "one_before_last":
+        # If working with a fixed tangent, set it the first time we load data
+        if self.options["tangent"] == "fixed" and self.tangent is None:
             self._set_tangent(coeff)
 
-        if self.options["tangent"]=="one_before_last":
-            self.coeffs.push(coeff)
-        else:
-            self.gammas.push(self._grassmann_log(coeff))
-
         # push the new data to the corresponding vectors
         self.descriptors.push(self._compute_descriptor(descriptor_input))
 
@@ -149,19 +151,11 @@ class Extrapolator:
 
         # use the fitting coefficients and the previous gammas to
         # extrapolate a new gamma
+        coeffs=self.coeffs.get(n)
+        gammas=[]
+        for i in range(len(coeffs)):
+            gammas.append(self._grassmann_log(coeffs[i]))
 
-        if self.options["tangent"]=="one_before_last":
-            coeffs=self.coeffs.get(n)
-
-            self._set_tangent(coeffs[-1])
-            gammas=[]
-            for i in range(len(coeffs)):
-                gammas.append(self._grassmann_log(coeffs[i]))
-            print('maxgamma_last', np.max(gammas[-1]))
-        else:
-            gammas = self.gammas.get(n)
-
-            print('maxgamm', np.max(gammas[0]))
         gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
 
         if self.options["verbose"]:
@@ -185,9 +179,13 @@ class Extrapolator:
 
     def _get_tangent(self) -> np.ndarray:
         """Get the tangent point."""
-        if self.tangent is not None:
-            return self.tangent
-        raise ValueError("Tangent point is not set.")
+        if self.options["tangent"] == "fixed":
+            if self.tangent is not None:
+                return self.tangent
+            raise ValueError("Tangent point is not set.")
+        n = min(self.options["nsteps"], self.coeffs.count)
+        coefficients = self.coeffs.get(n)
+        return coefficients[self.options["tangent"]]
 
     def _crop_coeff(self, coeff) -> np.ndarray:
         """Crop the coefficient matrix to remove the virtual orbitals."""
@@ -199,13 +197,7 @@ class Extrapolator:
 
     def _set_tangent(self, c: np.ndarray):
         """Set the tangent point."""
-        if self.options["tangent"]=="one_before_last":
-            self.tangent = c
-        else:
-            if self.tangent is None:
-                self.tangent = c
-            else:
-                raise ValueError("Resetting the tangent.")
+        self.tangent = c
 
     def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
         """Map from the manifold to the tangent plane."""
diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py
index 663220e..a2a6e83 100644
--- a/tests/test_grassmann.py
+++ b/tests/test_grassmann.py
@@ -22,7 +22,7 @@ def test_grassmann(datafile):
 
     # initialize an extrapolator
     extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes,
-        tangent="one")
+        tangent="fixed")
 
     # load data in the extrapolator
     for (coords, coeff, overlap) in zip(data["trajectory"],
@@ -31,8 +31,7 @@ def test_grassmann(datafile):
 
     # check the Grassmann projections
     c0 = extrapolator._get_tangent()
-    for (coeff, gamma, overlap) in zip(data["coefficients"],
-            extrapolator.gammas.get(nframes), extrapolator.overlaps.get(nframes)):
+    for (coeff, overlap) in zip(data["coefficients"], extrapolator.overlaps.get(nframes)):
 
         sqrt_overlap = extrapolator._sqrt_overlap(overlap)
 
@@ -44,6 +43,8 @@ def test_grassmann(datafile):
         assert np.linalg.norm(d - d @ d, ord=np.inf) < SMALL
         assert np.trace(d) - nelectrons < SMALL
 
+        gamma = gext.grassmann.log(coeff, c0)
+
         # compute the density from the inverse Grassmann map: Exp(Log(D))
         coeff_1 = gext.grassmann.exp(gamma, c0)
         d_1 = coeff_1 @ coeff_1.T
diff --git a/tests/test_tangent.py b/tests/test_tangent.py
new file mode 100644
index 0000000..393267a
--- /dev/null
+++ b/tests/test_tangent.py
@@ -0,0 +1,54 @@
+import pytest
+import os
+import sys
+import numpy as np
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import gext
+import gext.grassmann
+import utils
+
+SMALL = 1e-10
+THRESHOLD = 5e-2
+
+@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
+@pytest.mark.parametrize("tangent", ["fixed", 0, 2, 4, 6, 8, -1, -2])
+def test_tangent(datafile, tangent):
+    """Check the extrapolated density for "fixed" and index-based tangent options."""
+    # load the test data and read the system sizes from the array shapes
+    data = utils.load_json(f"tests/{datafile}")
+    nelectrons = data["nelectrons"]
+    natoms = data["trajectory"].shape[1]
+    nbasis = data["overlaps"].shape[1]
+    nframes = data["trajectory"].shape[0]
+
+    # frames used for fitting (also passed as nsteps); frame n must exist for the check
+    n = 9
+    assert n < nframes
+
+    # initialize an extrapolator with the tangent option under test
+    extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
+        nsteps=n, fitting="leastsquare", fitting_regularization=0.0,
+        descriptor="distance", tangent=tangent)
+
+    # load the first n frames (indices 0 .. n - 1) in the extrapolator
+    for (coords, coeff, overlap) in zip(data["trajectory"][:n],
+            data["coefficients"][:n], data["overlaps"][:n]):
+        extrapolator.load_data(coords, coeff, overlap)
+
+    # guess frame n; the reference uses the first nelectrons//2 coefficient columns
+    guessed_density = extrapolator.guess(data["trajectory"][n], data["overlaps"][n])
+    coeff = data["coefficients"][n][:, :nelectrons//2]
+    density = coeff @ coeff.T
+
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) \
+          /np.linalg.norm(density, ord=np.inf) < THRESHOLD
+    # apply the extrapolator's Grassmann log to every stored coefficient matrix
+    coeffs = extrapolator.coeffs.get(n)
+    gammas = []
+    for i in range(len(coeffs)):
+        gammas.append(extrapolator._grassmann_log(coeffs[i]))
+    # integer tangent: the log at that same index is expected to be numerically zero
+    if tangent != "fixed":
+        assert np.linalg.norm(gammas[tangent]) < SMALL
-- 
GitLab