Skip to content
Snippets Groups Projects
Commit e32f573f authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Descriptor options working.

parent 831490c2
No related branches found
No related tags found
1 merge request!4Options
Pipeline #1949 passed
...@@ -10,7 +10,7 @@ class Distance(): ...@@ -10,7 +10,7 @@ class Distance():
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
def set_options(self, kwargs): def set_options(self, **kwargs):
"""Given an option dictionary set the valid options and """Given an option dictionary set the valid options and
raise an error if there are invalid ones.""" raise an error if there are invalid ones."""
if len(kwargs) > 0: if len(kwargs) > 0:
......
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
from . import grassmann from . import grassmann
from . import fitting from . import fitting
from . import descriptors from .descriptors import Distance, Coulomb
from .buffer import CircularBuffer from .buffer import CircularBuffer
class Extrapolator: class Extrapolator:
...@@ -16,22 +16,19 @@ class Extrapolator: ...@@ -16,22 +16,19 @@ class Extrapolator:
steps used by the extrapolator is an optional argument with default steps used by the extrapolator is an optional argument with default
value of 6.""" value of 6."""
def _update_docstring(self):
options_str = "\n".join(f" - '{key}': {value}" for key, value in self.supported_options.items())
self.__doc__ = self.__doc__.format(options=options_str)
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
self.supported_options = { self.supported_options = {
"verbose": False, "verbose": False,
"nsteps": 6 "nsteps": 6,
"descriptor": "distance",
"fitting": "linear",
} }
self._update_docstring()
self.nelectrons = nelectrons self.nelectrons = nelectrons
self.nbasis = nbasis self.nbasis = nbasis
self.natoms = natoms self.natoms = natoms
self.set_options(**kwargs)
self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis)) self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis))
self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis)) self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis))
...@@ -40,8 +37,6 @@ class Extrapolator: ...@@ -40,8 +37,6 @@ class Extrapolator:
self.tangent: Optional[np.ndarray] = None self.tangent: Optional[np.ndarray] = None
self.set_options(**kwargs)
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Given an arbitrary amount of keyword arguments, parse them if """Given an arbitrary amount of keyword arguments, parse them if
specified, set default values if not specified and raise an error specified, set default values if not specified and raise an error
...@@ -62,8 +57,19 @@ class Extrapolator: ...@@ -62,8 +57,19 @@ class Extrapolator:
raise ValueError(f"Unsupported option: {key}") raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items(): for option, default_value in self.supported_options.items():
if not hasattr(self.options, option): if not option in self.options:
setattr(self.options, option, default_value) self.options[option] = default_value
if self.options["nsteps"] <= 1 or self.options["nsteps"] >= 100:
raise ValueError("Unsupported nsteps")
if self.options["descriptor"] == "distance":
self.descriptor_calculator = Distance()
elif self.options["descriptor"] == "coulomb":
self.descriptor_calculator = Coulomb()
else:
raise ValueError("Unsupported descriptor")
self.descriptor_calculator.set_options(**descriptor_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray): overlap: np.ndarray):
...@@ -81,20 +87,21 @@ class Extrapolator: ...@@ -81,20 +87,21 @@ class Extrapolator:
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density to be used as a guess."""
prev_descriptors = self.descriptors.get(self.nsteps) prev_descriptors = self.descriptors.get(self.options["nsteps"])
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
fit_coefficients = fitting.linear(prev_descriptors, descriptor) fit_coefficients = fitting.linear(prev_descriptors, descriptor)
gammas = self.gammas.get(self.nsteps) gammas = self.gammas.get(self.options["nsteps"])
gamma = fitting.linear_combination(gammas, fit_coefficients) gamma = fitting.linear_combination(gammas, fit_coefficients)
fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients) fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
if self.options["verbose"]: if self.options["verbose"]:
print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) print("error on descriptor:", \
np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
if overlap is None: if overlap is None:
overlaps = self.overlaps.get(self.nsteps) overlaps = self.overlaps.get(self.options["nsteps"])
overlap = fitting.linear_combination(overlaps, fit_coefficients) overlap = fitting.linear_combination(overlaps, fit_coefficients)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
else: else:
...@@ -105,7 +112,6 @@ class Extrapolator: ...@@ -105,7 +112,6 @@ class Extrapolator:
return c_guess @ c_guess.T return c_guess @ c_guess.T
def _get_tangent(self) -> np.ndarray: def _get_tangent(self) -> np.ndarray:
"""Get the tangent point.""" """Get the tangent point."""
if self.tangent is not None: if self.tangent is not None:
...@@ -149,4 +155,4 @@ class Extrapolator: ...@@ -149,4 +155,4 @@ class Extrapolator:
def _compute_descriptor(self, coords) -> np.ndarray: def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor.""" """Given a set of coordinates compute the corresponding descriptor."""
return descriptors.distance(coords) return self.descriptor_calculator.compute(coords)
...@@ -23,7 +23,7 @@ def test_descriptor_fitting(datafile): ...@@ -23,7 +23,7 @@ def test_descriptor_fitting(datafile):
nframes = data["trajectory"].shape[0] nframes = data["trajectory"].shape[0]
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nframes) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes)
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
......
...@@ -26,7 +26,7 @@ def test_extrapolation(datafile): ...@@ -26,7 +26,7 @@ def test_extrapolation(datafile):
assert n < nframes assert n < nframes
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, n) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
# load data in the extrapolator up to index n - 1 # load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:n], for (coords, coeff, overlap) in zip(data["trajectory"][:n],
......
...@@ -21,7 +21,7 @@ def test_grassmann(datafile): ...@@ -21,7 +21,7 @@ def test_grassmann(datafile):
nframes = data["trajectory"].shape[0] nframes = data["trajectory"].shape[0]
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nframes) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes)
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment