From e7af1061a03cbbca14f87bea7e344255b964f367 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Fri, 3 Nov 2023 14:53:23 +0100 Subject: [PATCH] Option for partially filled and tests. --- gext/main.py | 6 +++++- tests/test_extrapolation.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/gext/main.py b/gext/main.py index c7aac97..b4f4288 100644 --- a/gext/main.py +++ b/gext/main.py @@ -23,6 +23,7 @@ class Extrapolator: "nsteps": 6, "descriptor": "distance", "fitting": "leastsquare", + "allow_partially_filled": True, } self.nelectrons = nelectrons @@ -95,7 +96,10 @@ class Extrapolator: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: """Get a new electronic density to be used as a guess.""" - n = min(self.options["nsteps"], self.descriptors.count) + if self.options["allow_partially_filled"]: + n = min(self.options["nsteps"], self.descriptors.count) + else: + n = self.options["nsteps"] prev_descriptors = self.descriptors.get(n) descriptor = self._compute_descriptor(coords) diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index 4adc67a..ff3befa 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -41,3 +41,35 @@ def test_extrapolation(datafile): assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) \ /np.linalg.norm(density, ord=np.inf) < THRESHOLD + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +def test_partial_extrapolation(datafile): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # amount of data we want to use for fitting + n = 9 + m = 5 + assert n < nframes + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) + + # load data in the extrapolator up to index n - 1 + for (coords, coeff, overlap) in zip(data["trajectory"][:m], + data["coefficients"][:m], data["overlaps"][:m]): + extrapolator.load_data(coords, coeff, overlap) + + # check an extrapolation at index n + guessed_density = extrapolator.guess(data["trajectory"][m], data["overlaps"][m]) + coeff = data["coefficients"][m][:, :nelectrons//2] + density = coeff @ coeff.T + + assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD + assert np.linalg.norm(guessed_density - density, ord=np.inf) \ + /np.linalg.norm(density, ord=np.inf) < THRESHOLD -- GitLab