Skip to content
Snippets Groups Projects
Commit a748e52d authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Options working.

parent e32f573f
Branches
Tags
1 merge request!4Options
"""Module which provides functions to compute descriptors.""" """Module which provides functionality to compute descriptors."""
import numpy as np import numpy as np
from scipy.spatial.distance import pdist from scipy.spatial.distance import pdist
class Distance(): class Distance:
"""Distance matrix descriptors.""" """Distance matrix descriptors."""
......
"""Module that defines fitting functions.""" """Module which provides functionality to perform fitting."""
from typing import List from typing import List
import numpy as np import numpy as np
import abc
class AbstractFitting(abc.ABC):
def __init__(self, **kwargs):
self.set_options(**kwargs)
@abc.abstractmethod
def set_options(self, **kwargs):
"""Base method for setting options."""
@abc.abstractmethod
def compute(self, vectors: List[np.ndarray], target:np.ndarray):
"""Base method for computing new fitting coefficients."""
def linear_combination(self, vectors: List[np.ndarray],
coefficients: np. ndarray) -> np.ndarray:
"""Given a set of vectors (or matrices) and the corresponding
coefficients, build their linear combination."""
result = np.zeros(vectors[0].shape, dtype=np.float64)
for coeff, vector in zip(coefficients, vectors):
result += vector*coeff
return result
class LeastSquare(AbstractFitting):
def linear(vectors: List[np.ndarray], target: np.ndarray):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
matrix = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None) supported_options = {
return np.array(coefficients, dtype=np.float64) "regularization": 0.0,
}
def quasi_time_reversible():
"""Time reversible least square minimization fitting.""" def set_options(self, **kwargs):
self.options = {}
def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray: for key, value in kwargs.items():
"""Given a set of vectors (or matrices) and the corresponding if key in self.supported_options:
coefficients, build their linear combination.""" self.options[key] = value
result = np.zeros(vectors[0].shape, dtype=np.float64) else:
for coeff, vector in zip(coefficients, vectors): raise ValueError(f"Unsupported option: {key}")
result += vector*coeff
return result for option, default_value in self.supported_options.items():
if not option in self.options:
self.options[option] = default_value
if self.options["regularization"] < 0 \
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def compute(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting
coefficients."""
matrix = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting):
def set_options(**kwargs):
"""TODO"""
def compute(self):
"""Time reversible least square minimization fitting."""
...@@ -4,7 +4,7 @@ from typing import Optional ...@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np import numpy as np
from . import grassmann from . import grassmann
from . import fitting from .fitting import LeastSquare, QuasiTimeReversible
from .descriptors import Distance, Coulomb from .descriptors import Distance, Coulomb
from .buffer import CircularBuffer from .buffer import CircularBuffer
...@@ -22,7 +22,7 @@ class Extrapolator: ...@@ -22,7 +22,7 @@ class Extrapolator:
"verbose": False, "verbose": False,
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
"fitting": "linear", "fitting": "leastsquare",
} }
self.nelectrons = nelectrons self.nelectrons = nelectrons
...@@ -71,6 +71,14 @@ class Extrapolator: ...@@ -71,6 +71,14 @@ class Extrapolator:
raise ValueError("Unsupported descriptor") raise ValueError("Unsupported descriptor")
self.descriptor_calculator.set_options(**descriptor_options) self.descriptor_calculator.set_options(**descriptor_options)
if self.options["fitting"] == "leastsquare":
self.fitting_calculator = LeastSquare()
elif self.options["fitting"] == "qtr":
self.fitting_calculator = QuasiTimeReversible()
else:
raise ValueError("Unsupported descriptor")
self.fitting_calculator.set_options(**fitting_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray): overlap: np.ndarray):
"""Load a new data point in the extrapolator.""" """Load a new data point in the extrapolator."""
...@@ -87,22 +95,24 @@ class Extrapolator: ...@@ -87,22 +95,24 @@ class Extrapolator:
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density to be used as a guess."""
prev_descriptors = self.descriptors.get(self.options["nsteps"]) n = min(self.options["nsteps"], self.descriptors.count)
prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
fit_coefficients = fitting.linear(prev_descriptors, descriptor) fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor)
gammas = self.gammas.get(self.options["nsteps"]) gammas = self.gammas.get(n)
gamma = fitting.linear_combination(gammas, fit_coefficients) gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients) fit_descriptor = self.fitting_calculator.linear_combination(prev_descriptors, fit_coefficients)
if self.options["verbose"]: if self.options["verbose"]:
print("error on descriptor:", \ print("error on descriptor:", \
np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
if overlap is None: if overlap is None:
overlaps = self.overlaps.get(self.options["nsteps"]) overlaps = self.overlaps.get(n)
overlap = fitting.linear_combination(overlaps, fit_coefficients) overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
else: else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
......
...@@ -35,10 +35,12 @@ def test_descriptor_fitting(datafile): ...@@ -35,10 +35,12 @@ def test_descriptor_fitting(datafile):
descriptors = extrapolator.descriptors.get(10) descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1] target = descriptors[-1]
fitting_calculator = gext.fitting.LeastSquare()
for start in range(0, 9): for start in range(0, 9):
vectors = descriptors[start:-1] vectors = descriptors[start:-1]
fit_coefficients = gext.fitting.linear(vectors, target) fit_coefficients = fitting_calculator.compute(vectors, target)
fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf)) errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
assert errors[0] < errors[-1] assert errors[0] < errors[-1]
...@@ -47,7 +49,7 @@ def test_descriptor_fitting(datafile): ...@@ -47,7 +49,7 @@ def test_descriptor_fitting(datafile):
# used for the fitting # used for the fitting
vectors = descriptors[:-1] vectors = descriptors[:-1]
vectors[0] = target vectors[0] = target
fit_coefficients = gext.fitting.linear(vectors, target) fit_coefficients = fitting_calculator.compute(vectors, target)
fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment