From 0f0a45925205772d6e4cfe45f6574282ae7eb263 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Mon, 6 Nov 2023 15:19:48 +0100 Subject: [PATCH] Implemented regularization. --- gext/fitting.py | 8 ++++++-- tests/test_descriptor_fitting.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/gext/fitting.py b/gext/fitting.py index 6b0e704..53c1f8e 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -60,8 +60,12 @@ class LeastSquare(AbstractFitting): def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Given a set of vectors and a target return the fitting coefficients.""" - matrix = np.vstack(vectors).T - coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None) + matrix = np.array(vectors).T + A = matrix.T @ matrix + b = matrix.T @ target + if self.options["regularization"] > 0.0: + A += np.identity(len(b))*self.options["regularization"] + coefficients = np.linalg.solve(A, b) return np.array(coefficients, dtype=np.float64) class QuasiTimeReversible(AbstractFitting): diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index 94e1ca9..d457c1e 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -10,10 +10,11 @@ import gext.fitting import gext.grassmann import utils -SMALL = 1e-10 +SMALL = 1e-8 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) -def test_descriptor_fitting(datafile): +@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1]) +def test_descriptor_fitting(datafile, regularization): # load test data from json file data = utils.load_json(f"tests/{datafile}") @@ -23,7 +24,8 @@ def test_descriptor_fitting(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting_regularization=regularization) # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], -- GitLab