Skip to content
Snippets Groups Projects
Commit f60da40e authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Reshape.

parent bfb73bcb
Branches
Tags
1 merge request!6QTR
...@@ -9,11 +9,23 @@ class AbstractFitting(abc.ABC): ...@@ -9,11 +9,23 @@ class AbstractFitting(abc.ABC):
"""Base class for fitting schemes.""" """Base class for fitting schemes."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.supported_options = {}
self.set_options(**kwargs) self.set_options(**kwargs)
@abc.abstractmethod @abc.abstractmethod
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Base method for setting options.""" """Base method for setting options."""
self.options = {}
for key, value in kwargs.items():
if key in self.supported_options:
self.options[key] = value
else:
raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items():
if option not in self.options:
self.options[option] = default_value
@abc.abstractmethod @abc.abstractmethod
def fit(self, vectors: List[np.ndarray], target:np.ndarray): def fit(self, vectors: List[np.ndarray], target:np.ndarray):
...@@ -33,22 +45,15 @@ class LeastSquare(AbstractFitting): ...@@ -33,22 +45,15 @@ class LeastSquare(AbstractFitting):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
supported_options = { def __init__(self, **kwargs):
"regularization": 0.0, self.supported_options = {
} "regularization": 0.0,
}
super().__init__(**kwargs)
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Set options for least square minimization""" """Set options for least square minimization"""
self.options = {} super().set_options(**kwargs)
for key, value in kwargs.items():
if key in self.supported_options:
self.options[key] = value
else:
raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items():
if option not in self.options:
self.options[option] = default_value
if self.options["regularization"] < 0 \ if self.options["regularization"] < 0 \
or self.options["regularization"] > 100: or self.options["regularization"] > 100:
...@@ -65,8 +70,19 @@ class QuasiTimeReversible(AbstractFitting): ...@@ -65,8 +70,19 @@ class QuasiTimeReversible(AbstractFitting):
"""Quasi time reversible fitting scheme. Not yet implemented.""" """Quasi time reversible fitting scheme. Not yet implemented."""
def __init__(self, **kwargs):
self.supported_options = {
"regularization": 0.0,
}
super().__init__(**kwargs)
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Set options for quasi time reversible fitting""" """Set options for quasi time reversible fitting"""
super().set_options(**kwargs)
if self.options["regularization"] < 0 \
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def fit(self, vectors: List[np.ndarray], target: np.ndarray): def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Time reversible least square minimization fitting.""" """Time reversible least square minimization fitting."""
......
...@@ -12,9 +12,7 @@ class Extrapolator: ...@@ -12,9 +12,7 @@ class Extrapolator:
"""Class for performing Grassmann extrapolations. On initialization """Class for performing Grassmann extrapolations. On initialization
it requires the number of electrons, the number of basis functions it requires the number of electrons, the number of basis functions
and the number of atoms of the molecule. The number of previous and the number of atoms of the molecule."""
steps used by the extrapolator is an optional argument with default
value of 6."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
...@@ -83,12 +81,17 @@ class Extrapolator: ...@@ -83,12 +81,17 @@ class Extrapolator:
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray): overlap: np.ndarray):
"""Load a new data point in the extrapolator.""" """Load a new data point in the extrapolator."""
# Crop the coefficient matrix up to the number of electron
# pairs, then apply S^1/2
coeff = self._crop_coeff(coeff) coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap) coeff = self._normalize(coeff, overlap)
# if it is the first time we load data, set the tangent point
if self.tangent is None: if self.tangent is None:
self._set_tangent(coeff) self._set_tangent(coeff)
# push the new data to the corresponding vectors
self.gammas.push(self._grassmann_log(coeff)) self.gammas.push(self._grassmann_log(coeff))
self.descriptors.push(self._compute_descriptor(coords)) self.descriptors.push(self._compute_descriptor(coords))
self.overlaps.push(overlap) self.overlaps.push(overlap)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment