Skip to content
Snippets Groups Projects
Commit a748e52d authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Options working.

parent e32f573f
No related branches found
No related tags found
1 merge request!4Options
"""Module which provides functions to compute descriptors."""
"""Module which provides functionality to compute descriptors."""
import numpy as np
from scipy.spatial.distance import pdist
class Distance():
class Distance:
"""Distance matrix descriptors."""
......
"""Module that defines fitting functions."""
"""Module which provides functionality to perform fitting."""
from typing import List
import numpy as np
import abc
class AbstractFitting(abc.ABC):
def __init__(self, **kwargs):
self.set_options(**kwargs)
@abc.abstractmethod
def set_options(self, **kwargs):
"""Base method for setting options."""
@abc.abstractmethod
def compute(self, vectors: List[np.ndarray], target:np.ndarray):
"""Base method for computing new fitting coefficients."""
def linear_combination(self, vectors: List[np.ndarray],
coefficients: np. ndarray) -> np.ndarray:
"""Given a set of vectors (or matrices) and the corresponding
coefficients, build their linear combination."""
result = np.zeros(vectors[0].shape, dtype=np.float64)
for coeff, vector in zip(coefficients, vectors):
result += vector*coeff
return result
class LeastSquare(AbstractFitting):
def linear(vectors: List[np.ndarray], target: np.ndarray):
"""Simple least square minimization fitting."""
matrix = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
return np.array(coefficients, dtype=np.float64)
def quasi_time_reversible():
"""Time reversible least square minimization fitting."""
def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
"""Given a set of vectors (or matrices) and the corresponding
coefficients, build their linear combination."""
result = np.zeros(vectors[0].shape, dtype=np.float64)
for coeff, vector in zip(coefficients, vectors):
result += vector*coeff
return result
supported_options = {
"regularization": 0.0,
}
def set_options(self, **kwargs):
self.options = {}
for key, value in kwargs.items():
if key in self.supported_options:
self.options[key] = value
else:
raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items():
if not option in self.options:
self.options[option] = default_value
if self.options["regularization"] < 0 \
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def compute(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting
coefficients."""
matrix = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting):
def set_options(**kwargs):
"""TODO"""
def compute(self):
"""Time reversible least square minimization fitting."""
......@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np
from . import grassmann
from . import fitting
from .fitting import LeastSquare, QuasiTimeReversible
from .descriptors import Distance, Coulomb
from .buffer import CircularBuffer
......@@ -22,7 +22,7 @@ class Extrapolator:
"verbose": False,
"nsteps": 6,
"descriptor": "distance",
"fitting": "linear",
"fitting": "leastsquare",
}
self.nelectrons = nelectrons
......@@ -71,6 +71,14 @@ class Extrapolator:
raise ValueError("Unsupported descriptor")
self.descriptor_calculator.set_options(**descriptor_options)
if self.options["fitting"] == "leastsquare":
self.fitting_calculator = LeastSquare()
elif self.options["fitting"] == "qtr":
self.fitting_calculator = QuasiTimeReversible()
else:
raise ValueError("Unsupported descriptor")
self.fitting_calculator.set_options(**fitting_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray):
"""Load a new data point in the extrapolator."""
......@@ -87,22 +95,24 @@ class Extrapolator:
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess."""
prev_descriptors = self.descriptors.get(self.options["nsteps"])
n = min(self.options["nsteps"], self.descriptors.count)
prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords)
fit_coefficients = fitting.linear(prev_descriptors, descriptor)
fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor)
gammas = self.gammas.get(self.options["nsteps"])
gamma = fitting.linear_combination(gammas, fit_coefficients)
gammas = self.gammas.get(n)
gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
fit_descriptor = self.fitting_calculator.linear_combination(prev_descriptors, fit_coefficients)
if self.options["verbose"]:
print("error on descriptor:", \
np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
if overlap is None:
overlaps = self.overlaps.get(self.options["nsteps"])
overlap = fitting.linear_combination(overlaps, fit_coefficients)
overlaps = self.overlaps.get(n)
overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
......
......@@ -35,10 +35,12 @@ def test_descriptor_fitting(datafile):
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
fitting_calculator = gext.fitting.LeastSquare()
for start in range(0, 9):
vectors = descriptors[start:-1]
fit_coefficients = gext.fitting.linear(vectors, target)
fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients)
fit_coefficients = fitting_calculator.compute(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
assert errors[0] < errors[-1]
......@@ -47,7 +49,7 @@ def test_descriptor_fitting(datafile):
# used for the fitting
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = gext.fitting.linear(vectors, target)
fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients)
fit_coefficients = fitting_calculator.compute(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment