diff --git a/test_diff_fitting.py b/test_diff_fitting.py new file mode 100644 index 0000000000000000000000000000000000000000..0f25591a3f902c4bf978a82a99a26f6426741bdd --- /dev/null +++ b/test_diff_fitting.py @@ -0,0 +1,54 @@ +import os +import sys +import numpy as np + +import gext +import gext.descriptors +import gext.fitting +import gext.grassmann +from tests import utils + +SMALL = 1e-8 +THRESHOLD = 5e-2 + +regularization = 0.0 + +# load test data from json file +data = utils.load_json(f"tests/urea.json") +nelectrons = data["nelectrons"] +natoms = data["trajectory"].shape[1] +nbasis = data["overlaps"].shape[1] +nframes = data["trajectory"].shape[0] + +# initialize an extrapolator +extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting_regularization=regularization, + fitting="diff") + +# load data in the extrapolator +for (coords, coeff, overlap) in zip(data["trajectory"], + data["coefficients"], data["overlaps"]): + extrapolator.load_data(coords, coeff, overlap) + +descriptors = extrapolator.descriptors.get(10) +target = descriptors[-1] + +fitting_calculator = extrapolator.fitting_calculator + +# check if things are reasonable +for start in range(0, 8): + vectors = descriptors[start:-1] + fit_coefficients = fitting_calculator.fit(vectors, target) + fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) + error = np.linalg.norm(target - fitted_target, ord=np.inf) + assert error < THRESHOLD + +# if we put the target in the vectors used for the fitting, +# check that we get an error smaller than the regularization +vectors = descriptors[:-1] +vectors[0] = target +fit_coefficients = fitting_calculator.fit(vectors, target) +fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) + +print(np.linalg.norm(target - fitted_target, ord=np.inf), max(SMALL, regularization)) +assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization)