Skip to content
Snippets Groups Projects
Commit e7af1061 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Option for partially filled and tests.

parent a748e52d
No related branches found
No related tags found
1 merge request!4Options
Pipeline #1950 passed
This commit is part of merge request !4. Comments created here will be created in the context of that merge request.
...@@ -23,6 +23,7 @@ class Extrapolator: ...@@ -23,6 +23,7 @@ class Extrapolator:
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
"fitting": "leastsquare", "fitting": "leastsquare",
"allow_partially_filled": True,
} }
self.nelectrons = nelectrons self.nelectrons = nelectrons
...@@ -95,7 +96,10 @@ class Extrapolator: ...@@ -95,7 +96,10 @@ class Extrapolator:
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density to be used as a guess."""
if self.options["allow_partially_filled"]:
n = min(self.options["nsteps"], self.descriptors.count) n = min(self.options["nsteps"], self.descriptors.count)
else:
n = self.options["nsteps"]
prev_descriptors = self.descriptors.get(n) prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
... ...
......
...@@ -41,3 +41,35 @@ def test_extrapolation(datafile): ...@@ -41,3 +41,35 @@ def test_extrapolation(datafile):
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \ assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD /np.linalg.norm(density, ord=np.inf) < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_partial_extrapolation(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
m = 5
assert n < nframes
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
# load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:m],
data["coefficients"][:m], data["overlaps"][:m]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index n
guessed_density = extrapolator.guess(data["trajectory"][m], data["overlaps"][m])
coeff = data["coefficients"][m][:, :nelectrons//2]
density = coeff @ coeff.T
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment