From edd1566c021720fcd0508cee01fb8f67e5e34302 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Fri, 3 Nov 2023 15:21:25 +0100 Subject: [PATCH] Added a new test. --- tests/test_extrapolation.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index ff3befa..eed97dc 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -73,3 +73,36 @@ def test_partial_extrapolation(datafile): assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) \ /np.linalg.norm(density, ord=np.inf) < THRESHOLD + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +def test_coefficient_extrapolation(datafile): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # amount of data we want to use for fitting + n = 9 + assert n < nframes + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) + + # load data in the extrapolator up to index n - 1 + for (coords, coeff, overlap) in zip(data["trajectory"][:n], + data["coefficients"][:n], data["overlaps"][:n]): + extrapolator.load_data(coords, coeff, overlap) + + # check an extrapolation at index n + guessed_coefficients = extrapolator.guess_coefficients( + data["trajectory"][n], data["overlaps"][n]) + coeff = data["coefficients"][n][:, :nelectrons//2] + density = coeff @ coeff.T + guessed_density = guessed_coefficients @ guessed_coefficients.T + + assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD + assert np.linalg.norm(guessed_density - density, ord=np.inf) \ + /np.linalg.norm(density, ord=np.inf) < THRESHOLD -- GitLab