diff --git a/gext/fitting.py b/gext/fitting.py index 4d058aaa349e2e3d764bafc4b9c64b435afcdbd3..94f4e871c9fd221fdbae3009c1accb3b5562200e 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -16,8 +16,9 @@ class AbstractFitting(abc.ABC): """Base method for setting options.""" @abc.abstractmethod - def compute(self, vectors: List[np.ndarray], target:np.ndarray): + def fit(self, vectors: List[np.ndarray], target:np.ndarray): """Base method for computing new fitting coefficients.""" + return np.zeros(0) def linear_combination(self, vectors: List[np.ndarray], coefficients: np. ndarray) -> np.ndarray: @@ -53,7 +54,7 @@ class LeastSquare(AbstractFitting): or self.options["regularization"] > 100: raise ValueError("Unsupported value for regularization") - def compute(self, vectors: List[np.ndarray], target: np.ndarray): + def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Given a set of vectors and a target return the fitting coefficients.""" matrix = np.vstack(vectors).T @@ -67,5 +68,6 @@ class QuasiTimeReversible(AbstractFitting): def set_options(self, **kwargs): """Set options for quasi time reversible fitting""" - def compute(self, vectors: List[np.ndarray], target: np.ndarray): + def fit(self, vectors: List[np.ndarray], target: np.ndarray): """Time reversible least square minimization fitting.""" + return np.zeros(0) diff --git a/gext/main.py b/gext/main.py index 1c9e692e22035a252444827029898617eb22232f..e87e37ef8e123c7eb25516a4f2c46afb6fa5ddb0 100644 --- a/gext/main.py +++ b/gext/main.py @@ -61,7 +61,7 @@ class Extrapolator: if not option in self.options: self.options[option] = default_value - if self.options["nsteps"] <= 1 or self.options["nsteps"] >= 100: + if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100: raise ValueError("Unsupported nsteps") if self.options["descriptor"] == "distance": @@ -101,25 +101,35 @@ class Extrapolator: def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray: """Get a new coefficient matrix to be used as a guess.""" + # check if we have enough data points to perform an extrapolation + count = self.descriptors.count if self.options["allow_partially_filled"]: - n = min(self.options["nsteps"], self.descriptors.count) + if count == 0: + raise ValueError("Not enough data loaded in the extrapolator") + n = min(self.options["nsteps"], count) else: n = self.options["nsteps"] + if count < n: + raise ValueError("Not enough data loaded in the extrapolator") + # use the descriptors to find the fitting coefficients prev_descriptors = self.descriptors.get(n) descriptor = self._compute_descriptor(coords) - fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor) + fit_coefficients = self._fit(prev_descriptors, descriptor) + # use the fitting coefficients and the previous gammas to + # extrapolate a new gamma gammas = self.gammas.get(n) gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients) - fit_descriptor = self.fitting_calculator.linear_combination( - prev_descriptors, fit_coefficients) - if self.options["verbose"]: + fit_descriptor = self.fitting_calculator.linear_combination( + prev_descriptors, fit_coefficients) print("error on descriptor:", \ np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) + # if the overlap is not given, use the coefficients to fit + # a new overlap if overlap is None: overlaps = self.overlaps.get(n) overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients) @@ -127,10 +137,9 @@ class Extrapolator: else: inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) + # use the overlap and gamma to find a new set of coefficients c_guess = self._grassmann_exp(gamma) - c_guess = inverse_sqrt_overlap @ c_guess - - return c_guess + return inverse_sqrt_overlap @ c_guess def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" @@ -176,3 +185,8 @@ class Extrapolator: def _compute_descriptor(self, coords) -> np.ndarray: """Given a set of coordinates compute the corresponding descriptor.""" return self.descriptor_calculator.compute(coords) + + def _fit(self, prev_descriptors, descriptor) -> np.ndarray: + """Fit the current descriptor using previous descriptors and + the specified fitting scheme.""" + return self.fitting_calculator.fit(prev_descriptors, descriptor) diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index cc8c42be6ea39e764fd41a17308e238b26fbc8d9..94e1ca9dca1f47abbcf044176fcd8321d007b3ae 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -39,7 +39,7 @@ def test_descriptor_fitting(datafile): for start in range(0, 9): vectors = descriptors[start:-1] - fit_coefficients = fitting_calculator.compute(vectors, target) + fit_coefficients = fitting_calculator.fit(vectors, target) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) errors.append(np.linalg.norm(target - fitted_target, ord=np.inf)) @@ -49,7 +49,7 @@ def test_descriptor_fitting(datafile): # used for the fitting vectors = descriptors[:-1] vectors[0] = target - fit_coefficients = fitting_calculator.compute(vectors, target) + fit_coefficients = fitting_calculator.fit(vectors, target) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index eed97dcd10f31c7bfaa8735cf1752e40c8ed1d44..a727c4c106d120b347da0a9e546350d0877c92be 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -106,3 +106,37 @@ def test_coefficient_extrapolation(datafile): assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) \ /np.linalg.norm(density, ord=np.inf) < THRESHOLD + +@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) +def test_errors(datafile): + + # load test data from json file + data = utils.load_json(f"tests/{datafile}") + nelectrons = data["nelectrons"] + natoms = data["trajectory"].shape[1] + nbasis = data["overlaps"].shape[1] + nframes = data["trajectory"].shape[0] + + # amount of data we want to use for fitting + n = 9 + assert n < nframes + + # initialize an extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) + + with pytest.raises(ValueError): + extrapolator.guess(data["trajectory"][0]) + + # initialize a new extrapolator + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n, + allow_partially_filled=False) + + # load data in the extrapolator up to index m - 1 + m = 4 + for (coords, coeff, overlap) in zip(data["trajectory"][:m], + data["coefficients"][:m], data["overlaps"][:m]): + extrapolator.load_data(coords, coeff, overlap) + + # check an extrapolation at index m + with pytest.raises(ValueError): + extrapolator.guess(data["trajectory"][m])