From 84b04f6803cccf1133dea8755d4f44eec00b9348 Mon Sep 17 00:00:00 2001 From: Zahra Askarpour <Zahra.Askarpour@mathematik.uni-stuttgart.de> Date: Thu, 15 Feb 2024 12:06:40 +0100 Subject: [PATCH] test_fitting works --- test_diff_fitting.py | 54 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 test_diff_fitting.py diff --git a/test_diff_fitting.py b/test_diff_fitting.py new file mode 100644 index 0000000..0f25591 --- /dev/null +++ b/test_diff_fitting.py @@ -0,0 +1,54 @@ +import os +import sys +import numpy as np + +import gext +import gext.descriptors +import gext.fitting +import gext.grassmann +from tests import utils + +SMALL = 1e-8 +THRESHOLD = 5e-2 + +regularization = 0.0 + +# load test data from json file +data = utils.load_json(f"tests/urea.json") +nelectrons = data["nelectrons"] +natoms = data["trajectory"].shape[1] +nbasis = data["overlaps"].shape[1] +nframes = data["trajectory"].shape[0] + +# initialize an extrapolator +extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting_regularization=regularization, + fitting="diff") + +# load data in the extrapolator +for (coords, coeff, overlap) in zip(data["trajectory"], + data["coefficients"], data["overlaps"]): + extrapolator.load_data(coords, coeff, overlap) + +descriptors = extrapolator.descriptors.get(10) +target = descriptors[-1] + +fitting_calculator = extrapolator.fitting_calculator + +# check if things are reasonable +for start in range(0, 8): + vectors = descriptors[start:-1] + fit_coefficients = fitting_calculator.fit(vectors, target) + fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) + error = np.linalg.norm(target - fitted_target, ord=np.inf) + assert error < THRESHOLD + +# if we put the target in the vectors used for the fitting, +# check that we get an error smaller than the regularization +vectors = descriptors[:-1] +vectors[0] = target +fit_coefficients = fitting_calculator.fit(vectors, target) +fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) + +print(np.linalg.norm(target - fitted_target, ord=np.inf), max(SMALL, regularization)) +assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization) -- GitLab