diff --git a/gext/fitting.py b/gext/fitting.py index b0a988f609fddbf633fd170bf2d0b19dac0c9d20..ca0730a074d3dfd464fe4c014fe37072b4274530 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -3,6 +3,7 @@ import abc from typing import List import numpy as np +import scipy class AbstractFitting(abc.ABC): @@ -13,7 +14,6 @@ class AbstractFitting(abc.ABC): def __init__(self, **kwargs): self.set_options(**kwargs) - @abc.abstractmethod def set_options(self, **kwargs): """Base method for setting options.""" self.options = {} @@ -341,3 +341,30 @@ class PolynomialRegression(AbstractFitting): if self.options["ref"]: gamma += self.gamma_ref return np.reshape(gamma, self.gamma_shape) + +class LagrangeFitting(AbstractFitting): + + supported_options = {} + + def __init__(self): + super().__init__() + self.gammas = [] + + def train(self, descriptor_list: List[np.ndarray], gamma_list: List[np.ndarray]): + self.gammas = gamma_list + + def extrapolate(self, _): + tokens = [] + q = len(self.gammas) + tokens = np.array(tokens) + + result = np.zeros(self.gammas[0].shape) + for i, gamma in enumerate(self.gammas): + l = 1.0 + for m in range(1, q+1): + if m == i + 1: + continue + l *= (q+1-m)/(i+1-m) + result += l*gamma + return result + diff --git a/gext/main.py b/gext/main.py index e442371353849704139368cf22f754622e095174..154264acfaae22811c7c5af65954aa3cf291f85f 100644 --- a/gext/main.py +++ b/gext/main.py @@ -4,7 +4,7 @@ from typing import Optional import numpy as np from . import grassmann -from .fitting import LeastSquare, QuasiTimeReversible, PolynomialRegression, DiffFitting +from .fitting import LeastSquare, QuasiTimeReversible, PolynomialRegression, DiffFitting, LagrangeFitting from .descriptors import Distance, Coulomb, Flatten from .buffer import CircularBuffer @@ -92,6 +92,8 @@ class Extrapolator: self.fitting_calculator = QuasiTimeReversible() elif self.options["fitting"] == "polynomialregression": self.fitting_calculator = PolynomialRegression() + elif self.options["fitting"] == "lagrange": + self.fitting_calculator = LagrangeFitting() else: raise ValueError("Unsupported fitting") self.fitting_calculator.set_options(**fitting_options) @@ -118,6 +120,34 @@ class Extrapolator: c_guess = self.guess_coefficients(coords, overlap) return c_guess @ c_guess.T + def guess_no_mapping(self, coords: np.ndarray, overlap): + # check if we have enough data points to perform an extrapolation + count = self.descriptors.count + if self.options["allow_partially_filled"]: + if count == 0: + raise ValueError("Not enough data loaded in the extrapolator") + n = min(self.options["nsteps"], count) + else: + n = self.options["nsteps"] + if count < n: + raise ValueError("Not enough data loaded in the extrapolator") + + if overlap is None and not self.options["store_overlap"]: + raise ValueError("Guessing without overlap requires `store_overlap` true.") + + # get the required quantities + prev_descriptors = self.descriptors.get(n) + coefficients = self.coefficients.get(n) + + ds = [c @ c.T for c in coefficients] + descriptor = self._compute_descriptor(coords) + self.fitting_calculator.train(prev_descriptors, ds) + inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) + d = self.fitting_calculator.extrapolate(descriptor) + + return inverse_sqrt_overlap @ d @ inverse_sqrt_overlap + + def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray: """Get a new coefficient matrix to be used as a guess."""