diff --git a/gext/descriptors.py b/gext/descriptors.py index 7a31240c2296214eb28ce7661e753f33561bcec9..1d6fe4821cfc707deaf22ef397e852cbe5ab9472 100644 --- a/gext/descriptors.py +++ b/gext/descriptors.py @@ -10,7 +10,7 @@ class Distance(): def __init__(self, **kwargs): self.set_options(**kwargs) - def set_options(self, kwargs): + def set_options(self, **kwargs): """Given an option dictionary set the valid options and raise an error if there are invalid ones.""" if len(kwargs) > 0: diff --git a/gext/main.py b/gext/main.py index 86ad8178eb5380d2f73c559b68d1111e1d9a196e..b924b3b068ef33241ffd694ebc0bc84ecfff1f00 100644 --- a/gext/main.py +++ b/gext/main.py @@ -5,7 +5,7 @@ import numpy as np from . import grassmann from . import fitting -from . import descriptors +from .descriptors import Distance, Coulomb from .buffer import CircularBuffer class Extrapolator: @@ -16,22 +16,19 @@ class Extrapolator: steps used by the extrapolator is an optional argument with default value of 6.""" - def _update_docstring(self): - options_str = "\n".join(f" - '{key}': {value}" for key, value in self.supported_options.items()) - self.__doc__ = self.__doc__.format(options=options_str) - def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): self.supported_options = { "verbose": False, - "nsteps": 6 + "nsteps": 6, + "descriptor": "distance", + "fitting": "linear", } - self._update_docstring() - self.nelectrons = nelectrons self.nbasis = nbasis self.natoms = natoms + self.set_options(**kwargs) self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis)) self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis)) @@ -40,8 +37,6 @@ class Extrapolator: self.tangent: Optional[np.ndarray] = None - self.set_options(**kwargs) - def set_options(self, **kwargs): """Given an arbitrary amount of keyword arguments, parse them if specified, set default values if not specified and raise an error @@ -62,8 +57,19 @@ class Extrapolator: raise ValueError(f"Unsupported option: {key}") for option, default_value in self.supported_options.items(): - if not hasattr(self.options, option): - setattr(self.options, option, default_value) + if not option in self.options: + self.options[option] = default_value + + if self.options["nsteps"] <= 1 or self.options["nsteps"] >= 100: + raise ValueError("Unsupported nsteps") + + if self.options["descriptor"] == "distance": + self.descriptor_calculator = Distance() + elif self.options["descriptor"] == "coulomb": + self.descriptor_calculator = Coulomb() + else: + raise ValueError("Unsupported descriptor") + self.descriptor_calculator.set_options(**descriptor_options) def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray): @@ -81,20 +87,21 @@ class Extrapolator: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: """Get a new electronic density to be used as a guess.""" - prev_descriptors = self.descriptors.get(self.nsteps) + prev_descriptors = self.descriptors.get(self.options["nsteps"]) descriptor = self._compute_descriptor(coords) fit_coefficients = fitting.linear(prev_descriptors, descriptor) - gammas = self.gammas.get(self.nsteps) + gammas = self.gammas.get(self.options["nsteps"]) gamma = fitting.linear_combination(gammas, fit_coefficients) fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients) if self.options["verbose"]: - print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) + print("error on descriptor:", \ + np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) if overlap is None: - overlaps = self.overlaps.get(self.nsteps) + overlaps = self.overlaps.get(self.options["nsteps"]) overlap = fitting.linear_combination(overlaps, fit_coefficients) inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) else: @@ -105,7 +112,6 @@ class Extrapolator: return c_guess @ c_guess.T - def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" if self.tangent is not None: @@ -149,4 +155,4 @@ class Extrapolator: def _compute_descriptor(self, coords) -> np.ndarray: """Given a set of coordinates compute the corresponding descriptor.""" - return descriptors.distance(coords) + return self.descriptor_calculator.compute(coords) diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index e7d84b3a866bb2a8ac12fbaf506465f2d6aedffd..493a5329f8af3bec73ec2fb1ca7d9ad31b86c577 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -23,7 +23,7 @@ def test_descriptor_fitting(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py index 69914ef676b19e68880bfb000651901f92d5b603..4adc67a2be15d261c0f713670be3ae28f992ea92 100644 --- a/tests/test_extrapolation.py +++ b/tests/test_extrapolation.py @@ -26,7 +26,7 @@ def test_extrapolation(datafile): assert n < nframes # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, n) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) # load data in the extrapolator up to index n - 1 for (coords, coeff, overlap) in zip(data["trajectory"][:n], diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py index 04d3fb9cf63fa1a937b614334174de0e66fd1d06..69786b275375912718b09782840b9dc1e232ae33 100644 --- a/tests/test_grassmann.py +++ b/tests/test_grassmann.py @@ -21,7 +21,7 @@ def test_grassmann(datafile): nframes = data["trajectory"].shape[0] # initialize an extrapolator - extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nframes) + extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"],