Skip to content
Snippets Groups Projects
Commit e7af1061 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Option for partially filled and tests.

parent a748e52d
Branches
Tags
1 merge request!4Options
Pipeline #1950 passed
...@@ -23,6 +23,7 @@ class Extrapolator: ...@@ -23,6 +23,7 @@ class Extrapolator:
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
"fitting": "leastsquare", "fitting": "leastsquare",
"allow_partially_filled": True,
} }
self.nelectrons = nelectrons self.nelectrons = nelectrons
...@@ -95,7 +96,10 @@ class Extrapolator: ...@@ -95,7 +96,10 @@ class Extrapolator:
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density to be used as a guess."""
if self.options["allow_partially_filled"]:
n = min(self.options["nsteps"], self.descriptors.count) n = min(self.options["nsteps"], self.descriptors.count)
else:
n = self.options["nsteps"]
prev_descriptors = self.descriptors.get(n) prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
......
...@@ -41,3 +41,35 @@ def test_extrapolation(datafile): ...@@ -41,3 +41,35 @@ def test_extrapolation(datafile):
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \ assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD /np.linalg.norm(density, ord=np.inf) < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_partial_extrapolation(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
m = 5
assert n < nframes
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
# load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:m],
data["coefficients"][:m], data["overlaps"][:m]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index n
guessed_density = extrapolator.guess(data["trajectory"][m], data["overlaps"][m])
coeff = data["coefficients"][m][:, :nelectrons//2]
density = coeff @ coeff.T
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment