diff --git a/grext/fitting.py b/grext/fitting.py
index 99c607b517054c36b41089a270f5b3a97efde754..9e4d26dc6279a2a34b40deab21abebe9afab110f 100644
--- a/grext/fitting.py
+++ b/grext/fitting.py
@@ -1,7 +1,28 @@
 """Module that defines fitting functions."""
 
-def linear():
-    """Simple least square minimization fitting."""
+from typing import List
+import numpy as np
+
+def linear(vectors: List[np.ndarray], target: np.ndarray) -> np.ndarray:
+    """Simple least square minimization fitting.
+
+    Build A = np.vstack(vectors).T (for 1D vectors, one column per vector)
+    and solve min_c ||A @ c - target||_2 with np.linalg.lstsq.
+
+    Returns the fit coefficients as a float64 array.
+    """
+    A = np.vstack(vectors).T
+    coefficients, _, _, _ = np.linalg.lstsq(A, target, rcond=None)
+    return np.array(coefficients, dtype=np.float64)
 
 def quasi_time_reversible():
     """Time reversible least square minimization fitting."""
+
+def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
+    """Linear combination of the vectors with the given coefficients.
+
+    Computes np.vstack(vectors).T @ coefficients, using the same matrix
+    layout as `linear`, so coefficients fitted there can be applied here.
+    """
+    A = np.vstack(vectors).T
+    return A @ coefficients
diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py
new file mode 100644
index 0000000000000000000000000000000000000000..c673bceb5d970f3e5ef45d0d07f2ad11cf5e62ac
--- /dev/null
+++ b/tests/test_descriptor_fitting.py
@@ -0,0 +1,56 @@
+import pytest
+import os
+import sys
+import numpy as np
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import grext
+import grext.descriptors
+import grext.fitting
+import grext.grassmann
+import utils
+
+SMALL = 1e-10
+
+@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
+def test_descriptor_fitting(datafile):
+    """Fit the last descriptor with the previous ones and check the error."""
+
+    # load test data from json file
+    data = utils.load_json(f"tests/{datafile}")
+    nelectrons = data["nelectrons"]
+    natoms = data["trajectory"].shape[1]
+    nbasis = data["overlaps"].shape[1]
+    nframes = data["trajectory"].shape[0]
+
+    # initialize an extrapolator
+    extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, nframes)
+
+    # load data in the extrapolator
+    for (coords, coeff, overlap) in zip(data["trajectory"],
+                                        data["coefficients"], data["overlaps"]):
+        extrapolator.load_data(coords, coeff, overlap)
+
+    # we check if the error goes down with a larger data set
+    errors = []
+    descriptors = extrapolator.descriptors.get(10)
+    target = descriptors[-1]
+
+    # a smaller start means more vectors are used in the fit
+    for start in range(0, 9):
+        vectors = descriptors[start:-1]
+        fit_coefficients = grext.fitting.linear(vectors, target)
+        fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
+        errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
+
+    assert errors[0] < errors[-1]
+
+    # we check that we get a zero error if we put the target in the vectors
+    # used for the fitting; list() makes a copy so that replacing an entry
+    # never writes into the descriptors returned by the extrapolator
+    vectors = list(descriptors[:-1])
+    vectors[0] = target
+    fit_coefficients = grext.fitting.linear(vectors, target)
+    fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
+
+    assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL