diff --git a/gext/main.py b/gext/main.py index b48594a0962b3360d8a7b3c2725ea7d6bd504833..76af6a80b510b9f702b7290a71bc3526ae7473aa 100644 --- a/gext/main.py +++ b/gext/main.py @@ -22,6 +22,7 @@ class Extrapolator: "descriptor": "distance", "fitting": "leastsquare", "allow_partially_filled": True, + "store_overlap": True, } self.nelectrons = nelectrons @@ -30,9 +31,10 @@ class Extrapolator: self.set_options(**kwargs) self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis)) - self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis)) self.descriptors = CircularBuffer(self.options["nsteps"], ((self.natoms - 1)*self.natoms//2, )) + if self.options["store_overlap"]: + self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis)) self.tangent: Optional[np.ndarray] = None @@ -78,8 +80,7 @@ class Extrapolator: raise ValueError("Unsupported descriptor") self.fitting_calculator.set_options(**fitting_options) - def load_data(self, coords: np.ndarray, coeff: np.ndarray, - overlap: np.ndarray): + def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap): """Load a new data point in the extrapolator.""" # Crop the coefficient matrix up to the number of electron @@ -94,14 +95,16 @@ class Extrapolator: # push the new data to the corresponding vectors self.gammas.push(self._grassmann_log(coeff)) self.descriptors.push(self._compute_descriptor(coords)) - self.overlaps.push(overlap) - def guess(self, coords: np.ndarray, overlap = None) -> np.ndarray: + if self.options["store_overlap"]: + self.overlaps.push(overlap) + + def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray: """Get a new electronic density matrix to be used as a guess.""" c_guess = self.guess_coefficients(coords, overlap) return c_guess @ c_guess.T - def guess_coefficients(self, coords: np.ndarray, overlap = None) -> np.ndarray: + def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray: """Get a new coefficient matrix to be used as a guess.""" # check if we have enough data points to perform an extrapolation @@ -115,6 +118,9 @@ class Extrapolator: if count < n: raise ValueError("Not enough data loaded in the extrapolator") + if overlap is None and not self.options["store_overlap"]: + raise ValueError("Guessing without overlap requires `store_overlap` true.") + # use the descriptors to find the fitting coefficients prev_descriptors = self.descriptors.get(n) descriptor = self._compute_descriptor(coords)