diff --git a/gext/fitting.py b/gext/fitting.py index 6b0e704d121143fe28d184fc9dd724fbd39e9145..53c1f8e1496485f6e75ae1a09f817d19511a6b20 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -60,8 +60,12 @@ class LeastSquare(AbstractFitting): def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Given a set of vectors and a target return the fitting coefficients.""" - matrix = np.vstack(vectors).T - coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None) + matrix = np.array(vectors).T + A = matrix.T @ matrix + b = matrix.T @ target + if self.options["regularization"] > 0.0: + A += np.identity(len(b))*self.options["regularization"] + coefficients = np.linalg.solve(A, b) return np.array(coefficients, dtype=np.float64) class QuasiTimeReversible(AbstractFitting): diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index 94e1ca9dca1f47abbcf044176fcd8321d007b3ae..d457c1ec70e98afcda8b2d3e92fb19e3e1bad0dd 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -10,10 +10,11 @@ import gext.fitting import gext.grassmann import utils -SMALL = 1e-10 +SMALL = 1e-8 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) -def test_descriptor_fitting(datafile): +@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1]) +def test_descriptor_fitting(datafile, regularization): # load test data from json file data = utils.load_json(f"tests/{datafile}") @@ -23,7 +24,8 @@ def test_descriptor_fitting(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting_regularization=regularization) # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"],