diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index 164ae8bca6eb317bb26132b2a731f501d10b3c19..f0ae573404705b11cfa65cb87d50facdfae97282 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -88,3 +88,44 @@ def test_quasi_time_reversible(datafile, regularization): fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) error = np.linalg.norm(target - fitted_target, ord=np.inf) assert error < THRESHOLD + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +def test_time_reversibility(datafile): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting="qtr") + + # load data in the extrapolator + for (coords, coeff, overlap) in zip(data["trajectory"], + data["coefficients"], data["overlaps"]): + extrapolator.load_data(coords, coeff, overlap) + + descriptors = extrapolator.descriptors.get(10) + + # we symmetrize the future and past targets (remember that it is + # quasi time reversible, not exactly time reversible) + target = descriptors[0] + descriptors[-1] + descriptors[0] = target + descriptors[-1] = target + + fitting_calculator = extrapolator.fitting_calculator + + # fit the future target + fit_coefficients = fitting_calculator.fit(descriptors[:-1], descriptors[-1]) + fitted_target = fitting_calculator.linear_combination(descriptors[:-1], fit_coefficients) + + # fit the past target + reversed_descriptors = list(reversed(descriptors)) + fit_coefficients = fitting_calculator.fit(reversed_descriptors[:-1], reversed_descriptors[-1]) + fitted_target_reverse = fitting_calculator.linear_combination(reversed_descriptors[:-1], fit_coefficients) + + # check the time reversibility + assert np.linalg.norm(fitted_target - fitted_target_reverse, ord=np.inf) < SMALL