Skip to content
Snippets Groups Projects
Commit 0f0a4592 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Implemented regularization.

parent e338089a
Branches
Tags
1 merge request!6QTR
......@@ -60,8 +60,12 @@ class LeastSquare(AbstractFitting):
def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting
coefficients."""
matrix = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
matrix = np.array(vectors).T
A = matrix.T @ matrix
b = matrix.T @ target
if self.options["regularization"] > 0.0:
A += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(A, b)
return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting):
......
......@@ -10,10 +10,11 @@ import gext.fitting
import gext.grassmann
import utils
SMALL = 1e-10
SMALL = 1e-8
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_descriptor_fitting(datafile):
@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1])
def test_descriptor_fitting(datafile, regularization):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
......@@ -23,7 +24,8 @@ def test_descriptor_fitting(datafile):
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes)
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting_regularization=regularization)
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment