From 0f0a45925205772d6e4cfe45f6574282ae7eb263 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Mon, 6 Nov 2023 15:19:48 +0100
Subject: [PATCH] Implemented regularization.

---
 gext/fitting.py                  | 19 +++++++++++++++++--
 tests/test_descriptor_fitting.py |  8 +++++---
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/gext/fitting.py b/gext/fitting.py
index 6b0e704..53c1f8e 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -60,8 +60,23 @@ class LeastSquare(AbstractFitting):
     def fit(self, vectors: List[np.ndarray], target: np.ndarray):
         """Given a set of vectors and a target return the fitting
         coefficients."""
-        matrix = np.vstack(vectors).T
-        coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
+        matrix = np.array(vectors).T
+        regularization = self.options["regularization"]
+        if regularization > 0.0:
+            # Tikhonov (ridge) fit through the regularized normal
+            # equations (M^T M + lambda*I) c = M^T t. The positive shift
+            # makes the Gram matrix positive definite, so a direct solve
+            # is well defined.
+            gram = matrix.T @ matrix
+            rhs = matrix.T @ target
+            gram = gram + regularization*np.identity(len(rhs))
+            coefficients = np.linalg.solve(gram, rhs)
+        else:
+            # Without regularization the Gram matrix may be singular
+            # (linearly dependent vectors) and its condition number is the
+            # square of that of `matrix`: keep the plain least-squares
+            # solver, which handles rank deficiency.
+            coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
         return np.array(coefficients, dtype=np.float64)
 
 class QuasiTimeReversible(AbstractFitting):
diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py
index 94e1ca9..d457c1e 100644
--- a/tests/test_descriptor_fitting.py
+++ b/tests/test_descriptor_fitting.py
@@ -10,10 +10,11 @@ import gext.fitting
 import gext.grassmann
 import utils
 
-SMALL = 1e-10
+SMALL = 1e-8
 
 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
-def test_descriptor_fitting(datafile):
+@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1])
+def test_descriptor_fitting(datafile, regularization):
 
     # load test data from json file
     data = utils.load_json(f"tests/{datafile}")
@@ -23,7 +24,8 @@ def test_descriptor_fitting(datafile):
     nframes = data["trajectory"].shape[0]
 
     # initialize an extrapolator
-    extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes)
+    extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
+        nsteps=nframes, fitting_regularization=regularization)
 
     # load data in the extrapolator
     for (coords, coeff, overlap) in zip(data["trajectory"],
-- 
GitLab