Skip to content
Snippets Groups Projects
Commit cf50f260 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Lagrange.

parent 8b4a8c08
No related branches found
No related tags found
No related merge requests found
Pipeline #2260 failed
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import abc import abc
from typing import List from typing import List
import numpy as np import numpy as np
import scipy
class AbstractFitting(abc.ABC): class AbstractFitting(abc.ABC):
...@@ -13,7 +14,6 @@ class AbstractFitting(abc.ABC): ...@@ -13,7 +14,6 @@ class AbstractFitting(abc.ABC):
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
@abc.abstractmethod
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Base method for setting options.""" """Base method for setting options."""
self.options = {} self.options = {}
...@@ -341,3 +341,30 @@ class PolynomialRegression(AbstractFitting): ...@@ -341,3 +341,30 @@ class PolynomialRegression(AbstractFitting):
if self.options["ref"]: if self.options["ref"]:
gamma += self.gamma_ref gamma += self.gamma_ref
return np.reshape(gamma, self.gamma_shape) return np.reshape(gamma, self.gamma_shape)
class LagrangeFitting(AbstractFitting):
supported_options = {}
def __init__(self):
super().__init__()
self.gammas = []
def train(self, descriptor_list: List[np.ndarray], gamma_list: List[np.ndarray]):
self.gammas = gamma_list
def extrapolate(self, _):
tokens = []
q = len(self.gammas)
tokens = np.array(tokens)
result = np.zeros(self.gammas[0].shape)
for i, gamma in enumerate(self.gammas):
l = 1.0
for m in range(1, q+1):
if m == i + 1:
continue
l *= (q+1-m)/(i+1-m)
result += l*gamma
return result
...@@ -4,7 +4,7 @@ from typing import Optional ...@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np import numpy as np
from . import grassmann from . import grassmann
from .fitting import LeastSquare, QuasiTimeReversible, PolynomialRegression, DiffFitting from .fitting import LeastSquare, QuasiTimeReversible, PolynomialRegression, DiffFitting, LagrangeFitting
from .descriptors import Distance, Coulomb, Flatten from .descriptors import Distance, Coulomb, Flatten
from .buffer import CircularBuffer from .buffer import CircularBuffer
...@@ -92,6 +92,8 @@ class Extrapolator: ...@@ -92,6 +92,8 @@ class Extrapolator:
self.fitting_calculator = QuasiTimeReversible() self.fitting_calculator = QuasiTimeReversible()
elif self.options["fitting"] == "polynomialregression": elif self.options["fitting"] == "polynomialregression":
self.fitting_calculator = PolynomialRegression() self.fitting_calculator = PolynomialRegression()
elif self.options["fitting"] == "lagrange":
self.fitting_calculator = LagrangeFitting()
else: else:
raise ValueError("Unsupported fitting") raise ValueError("Unsupported fitting")
self.fitting_calculator.set_options(**fitting_options) self.fitting_calculator.set_options(**fitting_options)
...@@ -118,6 +120,34 @@ class Extrapolator: ...@@ -118,6 +120,34 @@ class Extrapolator:
c_guess = self.guess_coefficients(coords, overlap) c_guess = self.guess_coefficients(coords, overlap)
return c_guess @ c_guess.T return c_guess @ c_guess.T
def guess_no_mapping(self, coords: np.ndarray, overlap):
# check if we have enough data points to perform an extrapolation
count = self.descriptors.count
if self.options["allow_partially_filled"]:
if count == 0:
raise ValueError("Not enough data loaded in the extrapolator")
n = min(self.options["nsteps"], count)
else:
n = self.options["nsteps"]
if count < n:
raise ValueError("Not enough data loaded in the extrapolator")
if overlap is None and not self.options["store_overlap"]:
raise ValueError("Guessing without overlap requires `store_overlap` true.")
# get the required quantities
prev_descriptors = self.descriptors.get(n)
coefficients = self.coefficients.get(n)
ds = [c @ c.T for c in coefficients]
descriptor = self._compute_descriptor(coords)
self.fitting_calculator.train(prev_descriptors, ds)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
d = self.fitting_calculator.extrapolate(descriptor)
return inverse_sqrt_overlap @ d @ inverse_sqrt_overlap
def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray: def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray:
"""Get a new coefficient matrix to be used as a guess.""" """Get a new coefficient matrix to be used as a guess."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment