From a748e52dacb2f85d5b4cdfd53af9ab7e97cddd66 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Fri, 3 Nov 2023 14:47:46 +0100
Subject: [PATCH] Options working.

---
Review fixes folded into this patch (hunk line counts are unchanged; the
"index" blob ids in the diff headers predate these edits):
 - QuasiTimeReversible.set_options() now takes self, so the class can be
   instantiated through AbstractFitting.__init__().
 - QuasiTimeReversible.compute() now has the abstract (vectors, target)
   signature that Extrapolator.guess() calls it with.
 - An unknown "fitting" option now raises "Unsupported fitting" instead
   of the copy-pasted "Unsupported descriptor".
 - "not option in" rewritten as "option not in".

 gext/descriptors.py              |  4 +-
 gext/fitting.py                  | 77 +++++++++++++++++++++++++-------
 gext/main.py                     | 28 ++++++++----
 tests/test_descriptor_fitting.py | 10 +++--
 4 files changed, 88 insertions(+), 31 deletions(-)

diff --git a/gext/descriptors.py b/gext/descriptors.py
index 1d6fe48..f1f6657 100644
--- a/gext/descriptors.py
+++ b/gext/descriptors.py
@@ -1,9 +1,9 @@
-"""Module which provides functions to compute descriptors."""
+"""Module which provides functionality to compute descriptors."""
 
 import numpy as np
 from scipy.spatial.distance import pdist
 
-class Distance():
+class Distance:
 
     """Distance matrix descriptors."""
 
diff --git a/gext/fitting.py b/gext/fitting.py
index 6dfd0de..3c6adb7 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -1,21 +1,66 @@
-"""Module that defines fitting functions."""
+"""Module which provides functionality to perform fitting."""
 
 from typing import List
 import numpy as np
+import abc
+
+class AbstractFitting(abc.ABC):
+
+    def __init__(self, **kwargs):
+        self.set_options(**kwargs)
+
+    @abc.abstractmethod
+    def set_options(self, **kwargs):
+        """Base method for setting options."""
+
+    @abc.abstractmethod
+    def compute(self, vectors: List[np.ndarray], target: np.ndarray):
+        """Base method for computing new fitting coefficients."""
+
+    def linear_combination(self, vectors: List[np.ndarray],
+            coefficients: np.ndarray) -> np.ndarray:
+        """Given a set of vectors (or matrices) and the corresponding
+        coefficients, build their linear combination."""
+        result = np.zeros(vectors[0].shape, dtype=np.float64)
+        for coeff, vector in zip(coefficients, vectors):
+            result += vector*coeff
+        return result
+
+class LeastSquare(AbstractFitting):
 
-def linear(vectors: List[np.ndarray], target: np.ndarray):
     """Simple least square minimization fitting."""
-    matrix = np.vstack(vectors).T
-    coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
-    return np.array(coefficients, dtype=np.float64)
-
-def quasi_time_reversible():
-    """Time reversible least square minimization fitting."""
-
-def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
-    """Given a set of vectors (or matrices) and the corresponding
-    coefficients, build their linear combination."""
-    result = np.zeros(vectors[0].shape, dtype=np.float64)
-    for coeff, vector in zip(coefficients, vectors):
-        result += vector*coeff
-    return result
+
+    supported_options = {
+        "regularization": 0.0,
+    }
+
+    def set_options(self, **kwargs):
+        self.options = {}
+        for key, value in kwargs.items():
+            if key in self.supported_options:
+                self.options[key] = value
+            else:
+                raise ValueError(f"Unsupported option: {key}")
+
+        for option, default_value in self.supported_options.items():
+            if option not in self.options:
+                self.options[option] = default_value
+
+        if self.options["regularization"] < 0 \
+                or self.options["regularization"] > 100:
+            raise ValueError("Unsupported value for regularization")
+
+    def compute(self, vectors: List[np.ndarray], target: np.ndarray):
+        """Given a set of vectors and a target return the fitting
+        coefficients."""
+        matrix = np.vstack(vectors).T
+        coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None)
+        return np.array(coefficients, dtype=np.float64)
+
+class QuasiTimeReversible(AbstractFitting):
+
+    def set_options(self, **kwargs):
+        """TODO: options are accepted but currently ignored."""
+
+    def compute(self, vectors: List[np.ndarray], target: np.ndarray):
+        """Time reversible least square minimization fitting."""
diff --git a/gext/main.py b/gext/main.py
index b924b3b..c7aac97 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -4,7 +4,7 @@ from typing import Optional
 import numpy as np
 
 from . import grassmann
-from . import fitting
+from .fitting import LeastSquare, QuasiTimeReversible
 from .descriptors import Distance, Coulomb
 from .buffer import CircularBuffer
 
@@ -22,7 +22,7 @@ class Extrapolator:
             "verbose": False,
             "nsteps": 6,
             "descriptor": "distance",
-            "fitting": "linear",
+            "fitting": "leastsquare",
         }
 
         self.nelectrons = nelectrons
@@ -71,6 +71,14 @@ class Extrapolator:
             raise ValueError("Unsupported descriptor")
         self.descriptor_calculator.set_options(**descriptor_options)
 
+        if self.options["fitting"] == "leastsquare":
+            self.fitting_calculator = LeastSquare()
+        elif self.options["fitting"] == "qtr":
+            self.fitting_calculator = QuasiTimeReversible()
+        else:
+            raise ValueError("Unsupported fitting")
+        self.fitting_calculator.set_options(**fitting_options)
+
     def load_data(self, coords: np.ndarray, coeff: np.ndarray,
             overlap: np.ndarray):
         """Load a new data point in the extrapolator."""
@@ -87,22 +95,24 @@ class Extrapolator:
     def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
         """Get a new electronic density to be used as a guess."""
 
-        prev_descriptors = self.descriptors.get(self.options["nsteps"])
+        n = min(self.options["nsteps"], self.descriptors.count)
+
+        prev_descriptors = self.descriptors.get(n)
         descriptor = self._compute_descriptor(coords)
-        fit_coefficients = fitting.linear(prev_descriptors, descriptor)
+        fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor)
 
-        gammas = self.gammas.get(self.options["nsteps"])
-        gamma = fitting.linear_combination(gammas, fit_coefficients)
+        gammas = self.gammas.get(n)
+        gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
 
-        fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
+        fit_descriptor = self.fitting_calculator.linear_combination(prev_descriptors, fit_coefficients)
 
         if self.options["verbose"]:
             print("error on descriptor:", \
                 np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
 
         if overlap is None:
-            overlaps = self.overlaps.get(self.options["nsteps"])
-            overlap = fitting.linear_combination(overlaps, fit_coefficients)
+            overlaps = self.overlaps.get(n)
+            overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
         else:
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py
index 493a532..cc8c42b 100644
--- a/tests/test_descriptor_fitting.py
+++ b/tests/test_descriptor_fitting.py
@@ -35,10 +35,12 @@ def test_descriptor_fitting(datafile):
     descriptors = extrapolator.descriptors.get(10)
     target = descriptors[-1]
 
+    fitting_calculator = gext.fitting.LeastSquare()
+
     for start in range(0, 9):
         vectors = descriptors[start:-1]
-        fit_coefficients = gext.fitting.linear(vectors, target)
-        fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients)
+        fit_coefficients = fitting_calculator.compute(vectors, target)
+        fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
         errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
 
     assert errors[0] < errors[-1]
@@ -47,7 +49,7 @@ def test_descriptor_fitting(datafile):
     # used for the fitting
     vectors = descriptors[:-1]
     vectors[0] = target
-    fit_coefficients = gext.fitting.linear(vectors, target)
-    fitted_target = gext.fitting.linear_combination(vectors, fit_coefficients)
+    fit_coefficients = fitting_calculator.compute(vectors, target)
+    fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
 
     assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
-- 
GitLab