diff --git a/gext/main.py b/gext/main.py index 7b4e3e569ab6049c4f03622f8deb3b5ba8e2191c..1c9e692e22035a252444827029898617eb22232f 100644 --- a/gext/main.py +++ b/gext/main.py @@ -94,7 +94,12 @@ class Extrapolator: self.overlaps.push(overlap) def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: - """Get a new electronic density to be used as a guess.""" + """Get a new electronic density matrix to be used as a guess.""" + c_guess = self.guess_coefficients(coords, overlap) + return c_guess @ c_guess.T + + def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray: + """Get a new coefficient matrix to be used as a guess.""" if self.options["allow_partially_filled"]: n = min(self.options["nsteps"], self.descriptors.count) @@ -125,7 +130,7 @@ class Extrapolator: c_guess = self._grassmann_exp(gamma) c_guess = inverse_sqrt_overlap @ c_guess - return c_guess @ c_guess.T + return c_guess def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index ff3befa64684e27361cc6c51e5765895b5b7ab65..eed97dcd10f31c7bfaa8735cf1752e40c8ed1d44 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -73,3 +73,36 @@ def test_partial_extrapolation(datafile): assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) \ /np.linalg.norm(density, ord=np.inf) < THRESHOLD + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +def test_coefficient_extrapolation(datafile): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # amount of data we want to use for fitting + n = 9 + assert n < nframes + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) + + # load data in the extrapolator up to index n - 1 + for (coords, coeff, overlap) in zip(data["trajectory"][:n], + data["coefficients"][:n], data["overlaps"][:n]): + extrapolator.load_data(coords, coeff, overlap) + + # check an extrapolation at index n + guessed_coefficients = extrapolator.guess_coefficients( + data["trajectory"][n], data["overlaps"][n]) + coeff = data["coefficients"][n][:, :nelectrons//2] + density = coeff @ coeff.T + guessed_density = guessed_coefficients @ guessed_coefficients.T + + assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD + assert np.linalg.norm(guessed_density - density, ord=np.inf) \ + /np.linalg.norm(density, ord=np.inf) < THRESHOLD