Skip to content
Snippets Groups Projects
Commit acb93543 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Merge branch 'coefficients' into 'main'

Guess for coefficients instead of density

See merge request !5
parents b5cbeecb edd1566c
Branches
Tags
1 merge request!5Guess for coefficients instead of density
Pipeline #1954 passed
...@@ -94,7 +94,12 @@ class Extrapolator: ...@@ -94,7 +94,12 @@ class Extrapolator:
self.overlaps.push(overlap) self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density matrix to be used as a guess."""
c_guess = self.guess_coefficients(coords, overlap)
return c_guess @ c_guess.T
def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new coefficient matrix to be used as a guess."""
if self.options["allow_partially_filled"]: if self.options["allow_partially_filled"]:
n = min(self.options["nsteps"], self.descriptors.count) n = min(self.options["nsteps"], self.descriptors.count)
...@@ -125,7 +130,7 @@ class Extrapolator: ...@@ -125,7 +130,7 @@ class Extrapolator:
c_guess = self._grassmann_exp(gamma) c_guess = self._grassmann_exp(gamma)
c_guess = inverse_sqrt_overlap @ c_guess c_guess = inverse_sqrt_overlap @ c_guess
return c_guess @ c_guess.T return c_guess
def _get_tangent(self) -> np.ndarray: def _get_tangent(self) -> np.ndarray:
"""Get the tangent point.""" """Get the tangent point."""
......
...@@ -73,3 +73,36 @@ def test_partial_extrapolation(datafile): ...@@ -73,3 +73,36 @@ def test_partial_extrapolation(datafile):
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \ assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD /np.linalg.norm(density, ord=np.inf) < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_coefficient_extrapolation(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
assert n < nframes
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
# load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:n],
data["coefficients"][:n], data["overlaps"][:n]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index n
guessed_coefficients = extrapolator.guess_coefficients(
data["trajectory"][n], data["overlaps"][n])
coeff = data["coefficients"][n][:, :nelectrons//2]
density = coeff @ coeff.T
guessed_density = guessed_coefficients @ guessed_coefficients.T
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment