Skip to content
Snippets Groups Projects
Commit bfb73bcb authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Updated errors.

parent 4e708ccc
No related branches found
No related tags found
1 merge request!6QTR
......@@ -16,8 +16,9 @@ class AbstractFitting(abc.ABC):
"""Base method for setting options."""
@abc.abstractmethod
def compute(self, vectors: List[np.ndarray], target:np.ndarray):
def fit(self, vectors: List[np.ndarray], target:np.ndarray):
"""Base method for computing new fitting coefficients."""
return np.zeros(0)
def linear_combination(self, vectors: List[np.ndarray],
coefficients: np. ndarray) -> np.ndarray:
......@@ -53,7 +54,7 @@ class LeastSquare(AbstractFitting):
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def compute(self, vectors: List[np.ndarray], target: np.ndarray):
def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting
coefficients."""
matrix = np.vstack(vectors).T
......@@ -67,5 +68,6 @@ class QuasiTimeReversible(AbstractFitting):
def set_options(self, **kwargs):
"""Set options for quasi time reversible fitting"""
def compute(self, vectors: List[np.ndarray], target: np.ndarray):
def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Time reversible least square minimization fitting."""
return np.zeros(0)
......@@ -61,7 +61,7 @@ class Extrapolator:
if not option in self.options:
self.options[option] = default_value
if self.options["nsteps"] <= 1 or self.options["nsteps"] >= 100:
if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100:
raise ValueError("Unsupported nsteps")
if self.options["descriptor"] == "distance":
......@@ -101,25 +101,35 @@ class Extrapolator:
def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new coefficient matrix to be used as a guess."""
# check if we have enough data points to perform an extrapolation
count = self.descriptors.count
if self.options["allow_partially_filled"]:
n = min(self.options["nsteps"], self.descriptors.count)
if count == 0:
raise ValueError("Not enough data loaded in the extrapolator")
n = min(self.options["nsteps"], count)
else:
n = self.options["nsteps"]
if count < n:
raise ValueError("Not enough data loaded in the extrapolator")
# use the descriptors to find the fitting coefficients
prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords)
fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor)
fit_coefficients = self._fit(prev_descriptors, descriptor)
# use the fitting coefficients and the previous gammas to
# extrapolate a new gamma
gammas = self.gammas.get(n)
gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
fit_descriptor = self.fitting_calculator.linear_combination(
prev_descriptors, fit_coefficients)
if self.options["verbose"]:
fit_descriptor = self.fitting_calculator.linear_combination(
prev_descriptors, fit_coefficients)
print("error on descriptor:", \
np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
# if the overlap is not given, use the coefficients to fit
# a new overlap
if overlap is None:
overlaps = self.overlaps.get(n)
overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
......@@ -127,10 +137,9 @@ class Extrapolator:
else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
# use the overlap and gamma to find a new set of coefficients
c_guess = self._grassmann_exp(gamma)
c_guess = inverse_sqrt_overlap @ c_guess
return c_guess
return inverse_sqrt_overlap @ c_guess
def _get_tangent(self) -> np.ndarray:
"""Get the tangent point."""
......@@ -176,3 +185,8 @@ class Extrapolator:
def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor."""
return self.descriptor_calculator.compute(coords)
def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
"""Fit the current descriptor using previous descriptors and
the specified fitting scheme."""
return self.fitting_calculator.fit(prev_descriptors, descriptor)
......@@ -39,7 +39,7 @@ def test_descriptor_fitting(datafile):
for start in range(0, 9):
vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.compute(vectors, target)
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
......@@ -49,7 +49,7 @@ def test_descriptor_fitting(datafile):
# used for the fitting
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = fitting_calculator.compute(vectors, target)
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
......@@ -106,3 +106,37 @@ def test_coefficient_extrapolation(datafile):
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_errors(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
assert n < nframes
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
with pytest.raises(ValueError):
extrapolator.guess(data["trajectory"][0])
# initialize a new extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n,
allow_partially_filled=False)
# load data in the extrapolator up to index m - 1
m = 4
for (coords, coeff, overlap) in zip(data["trajectory"][:m],
data["coefficients"][:m], data["overlaps"][:m]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index m
with pytest.raises(ValueError):
extrapolator.guess(data["trajectory"][m])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment