Skip to content
Snippets Groups Projects
Commit e2b4ca59 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Now it is possible to skip overlap storing.

parent f60da40e
Branches
Tags
1 merge request!6QTR
Pipeline #1957 failed
...@@ -22,6 +22,7 @@ class Extrapolator: ...@@ -22,6 +22,7 @@ class Extrapolator:
"descriptor": "distance", "descriptor": "distance",
"fitting": "leastsquare", "fitting": "leastsquare",
"allow_partially_filled": True, "allow_partially_filled": True,
"store_overlap": True,
} }
self.nelectrons = nelectrons self.nelectrons = nelectrons
...@@ -30,9 +31,10 @@ class Extrapolator: ...@@ -30,9 +31,10 @@ class Extrapolator:
self.set_options(**kwargs) self.set_options(**kwargs)
self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis)) self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis))
self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis))
self.descriptors = CircularBuffer(self.options["nsteps"], self.descriptors = CircularBuffer(self.options["nsteps"],
((self.natoms - 1)*self.natoms//2, )) ((self.natoms - 1)*self.natoms//2, ))
if self.options["store_overlap"]:
self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis))
self.tangent: Optional[np.ndarray] = None self.tangent: Optional[np.ndarray] = None
...@@ -78,8 +80,7 @@ class Extrapolator: ...@@ -78,8 +80,7 @@ class Extrapolator:
raise ValueError("Unsupported descriptor") raise ValueError("Unsupported descriptor")
self.fitting_calculator.set_options(**fitting_options) self.fitting_calculator.set_options(**fitting_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap):
overlap: np.ndarray):
"""Load a new data point in the extrapolator.""" """Load a new data point in the extrapolator."""
# Crop the coefficient matrix up to the number of electron # Crop the coefficient matrix up to the number of electron
...@@ -94,14 +95,16 @@ class Extrapolator: ...@@ -94,14 +95,16 @@ class Extrapolator:
# push the new data to the corresponding vectors # push the new data to the corresponding vectors
self.gammas.push(self._grassmann_log(coeff)) self.gammas.push(self._grassmann_log(coeff))
self.descriptors.push(self._compute_descriptor(coords)) self.descriptors.push(self._compute_descriptor(coords))
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: if self.options["store_overlap"]:
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray:
"""Get a new electronic density matrix to be used as a guess.""" """Get a new electronic density matrix to be used as a guess."""
c_guess = self.guess_coefficients(coords, overlap) c_guess = self.guess_coefficients(coords, overlap)
return c_guess @ c_guess.T return c_guess @ c_guess.T
def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray: def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray:
"""Get a new coefficient matrix to be used as a guess.""" """Get a new coefficient matrix to be used as a guess."""
# check if we have enough data points to perform an extrapolation # check if we have enough data points to perform an extrapolation
...@@ -115,6 +118,9 @@ class Extrapolator: ...@@ -115,6 +118,9 @@ class Extrapolator:
if count < n: if count < n:
raise ValueError("Not enough data loaded in the extrapolator") raise ValueError("Not enough data loaded in the extrapolator")
if overlap is None and not self.options["store_overlap"]:
raise ValueError("Guessing without overlap requires `store_overlap` true.")
# use the descriptors to find the fitting coefficients # use the descriptors to find the fitting coefficients
prev_descriptors = self.descriptors.get(n) prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment