From 7dfaccc03864fdabf63fb7c27659739fea6af408 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Tue, 7 Nov 2023 14:30:54 +0100 Subject: [PATCH] Time reversible working. --- gext/fitting.py | 29 +++++++++++++++- gext/main.py | 2 +- tests/test_descriptor_fitting.py | 57 +++++++++++++++++++++++++------- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/gext/fitting.py b/gext/fitting.py index 53c1f8e..9422b15 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -86,4 +86,31 @@ class QuasiTimeReversible(AbstractFitting): def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Time reversible least square minimization fitting.""" - return np.zeros(0) + + past_target = vectors[0] + matrix = np.array(vectors[1:]).T + + q = matrix.shape[1] + if q == 1: + time_reversible_matrix = matrix + elif q%2 == 0: + time_reversible_matrix = matrix[:, :q//2] + matrix[:, :q//2-1:-1] + else: + time_reversible_matrix = matrix[:, :q//2+1] + matrix[:, :q//2-1:-1] + + A = time_reversible_matrix.T @ time_reversible_matrix + b = time_reversible_matrix.T @ (target + past_target) + + if self.options["regularization"] > 0.0: + A += np.identity(len(b))*self.options["regularization"] + coefficients = np.linalg.solve(A, b) + + if q == 1: + full_coefficients = np.concatenate(([-1.0], coefficients)) + elif q%2 == 0: + full_coefficients = np.concatenate(([-1.0], coefficients, + coefficients[::-1])) + else: + full_coefficients = np.concatenate(([-1.0], coefficients[:-1], + 2.0*coefficients[-1:], coefficients[-2::-1])) + return np.array(full_coefficients, dtype=np.float64) diff --git a/gext/main.py b/gext/main.py index 46d7e7f..18716ea 100644 --- a/gext/main.py +++ b/gext/main.py @@ -77,7 +77,7 @@ class Extrapolator: elif self.options["fitting"] == "qtr": self.fitting_calculator = QuasiTimeReversible() else: - raise ValueError("Unsupported descriptor") + raise ValueError("Unsupported fitting") self.fitting_calculator.set_options(**fitting_options) def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap): diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index d457c1e..164ae8b 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -11,10 +11,11 @@ import gext.grassmann import utils SMALL = 1e-8 +THRESHOLD = 5e-2 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) -@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1]) -def test_descriptor_fitting(datafile, regularization): +@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05]) +def test_least_square(datafile, regularization): # load test data from json file data = utils.load_json(f"tests/{datafile}") @@ -25,33 +26,65 @@ def test_descriptor_fitting(datafile, regularization): # initialize an extrapolator extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, - nsteps=nframes, fitting_regularization=regularization) + nsteps=nframes, fitting_regularization=regularization, + fitting="leastsquare") # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], data["coefficients"], data["overlaps"]): extrapolator.load_data(coords, coeff, overlap) - # we check if the error goes down with a larger data set - errors = [] descriptors = extrapolator.descriptors.get(10) target = descriptors[-1] - fitting_calculator = gext.fitting.LeastSquare() + fitting_calculator = extrapolator.fitting_calculator + # check if things are reasonable for start in range(0, 9): vectors = descriptors[start:-1] fit_coefficients = fitting_calculator.fit(vectors, target) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) - errors.append(np.linalg.norm(target - fitted_target, ord=np.inf)) + error = np.linalg.norm(target - fitted_target, ord=np.inf) + assert error < THRESHOLD - assert errors[0] < errors[-1] - - # we check that we get a zero error if we put the target in the vectors - # used for the fitting + # if we put the target in the vectors used for the fitting, + # check that we get an error smaller than the regularization vectors = descriptors[:-1] vectors[0] = target fit_coefficients = fitting_calculator.fit(vectors, target) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) - assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL + assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization) + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05]) +def test_quasi_time_reversible(datafile, regularization): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, + nsteps=nframes, fitting="qtr", fitting_regularization=regularization) + + # load data in the extrapolator + for (coords, coeff, overlap) in zip(data["trajectory"], + data["coefficients"], data["overlaps"]): + extrapolator.load_data(coords, coeff, overlap) + + descriptors = extrapolator.descriptors.get(10) + target = descriptors[-1] + + fitting_calculator = extrapolator.fitting_calculator + + # check if things are reasonable + for start in range(0, 8): + vectors = descriptors[start:-1] + fit_coefficients = fitting_calculator.fit(vectors, target) + fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) + error = np.linalg.norm(target - fitted_target, ord=np.inf) + assert error < THRESHOLD -- GitLab