Skip to content
Snippets Groups Projects
Select Git revision
  • 0c054fa47d325cdb78af827e83e90778c9044892
  • main default protected
  • askarpza-main-patch-76094
  • polynomial_regression
  • optimization
  • v0.8.0
  • v0.7.1
  • v0.7.0
  • v0.6.0
  • v0.5.0
  • v0.4.1
  • v0.4.0
  • v0.3.0
  • v0.2.0
14 results

main.py

Blame
  • main.py 8.48 KiB
    """Main module containing the Extrapolator class."""
    
    from typing import Optional
    import numpy as np
    
    from . import grassmann
    from .fitting import LeastSquare, QuasiTimeReversible,DiffFitting
    from .descriptors import Distance, Coulomb
    from .buffer import CircularBuffer
    
    class Extrapolator:
        """Class for performing Grassmann extrapolations.

        On initialization it requires the number of electrons, the number of
        basis functions and the number of atoms of the molecule. Additional
        keyword options are parsed by :meth:`set_options`.

        Typical use: call :meth:`load_data` once per step with the current
        coordinates, coefficients and overlap, then call :meth:`guess` (or
        :meth:`guess_coefficients`) with new coordinates to extrapolate.
        """

        # Recognized (non-prefixed) options and their default values.
        supported_options = {
            "verbose": False,
            "nsteps": 6,
            "descriptor": "distance",
            "fitting": "diff",
            "allow_partially_filled": True,
            "store_overlap": True,
        }

        def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
            """Validate the dimensions, parse options and allocate the buffers.

            Raises:
                ValueError: if any dimension is not an ``int``, or if an option
                    is invalid (see :meth:`set_options`).
            """
            if not all(isinstance(dim, int) for dim in (nelectrons, nbasis, natoms)):
                raise ValueError("Dimensions are not integers")

            self.nelectrons = nelectrons
            self.nbasis = nbasis
            self.natoms = natoms
            self.set_options(**kwargs)

            nsteps = self.options["nsteps"]
            # NOTE(review): this buffer is (nelectrons//2, nbasis) while the
            # cropped coefficients are (nbasis, nelectrons//2); presumably
            # grassmann.log returns the transposed layout -- confirm.
            self.gammas = CircularBuffer(nsteps,
                (self.nelectrons//2, self.nbasis))
            # natoms*(natoms-1)/2 entries: one per unordered pair of atoms.
            self.descriptors = CircularBuffer(nsteps,
                ((self.natoms - 1)*self.natoms//2, ))
            if self.options["store_overlap"]:
                self.overlaps = CircularBuffer(nsteps,
                    (self.nbasis, self.nbasis))

            # Set by the first call to load_data().
            self.tangent: Optional[np.ndarray] = None

        def set_options(self, **kwargs):
            """Parse keyword options, fill in defaults and configure submodules.

            Keys present in ``supported_options`` are stored in ``self.options``
            (unspecified ones take their default). Keys prefixed with
            ``descriptor_`` / ``fitting_`` are forwarded, with the prefix
            stripped, to the descriptor / fitting calculator's ``set_options``.

            Raises:
                ValueError: for an unknown option, ``nsteps`` outside [1, 99],
                    or an unsupported ``descriptor`` / ``fitting`` name.
            """
            descriptor_prefix = "descriptor_"
            fitting_prefix = "fitting_"

            # Start from the defaults, then override with the specified values.
            self.options = dict(self.supported_options)
            descriptor_options = {}
            fitting_options = {}

            for key, value in kwargs.items():
                if key in self.supported_options:
                    self.options[key] = value
                elif key.startswith(descriptor_prefix):
                    descriptor_options[key[len(descriptor_prefix):]] = value
                elif key.startswith(fitting_prefix):
                    fitting_options[key[len(fitting_prefix):]] = value
                else:
                    raise ValueError(f"Unsupported option: {key}")

            # do some checks on the options, then set up and pipe options
            # to the submodules
            if not 1 <= self.options["nsteps"] < 100:
                raise ValueError("Unsupported nsteps")

            if self.options["descriptor"] == "distance":
                self.descriptor_calculator = Distance()
            elif self.options["descriptor"] == "coulomb":
                self.descriptor_calculator = Coulomb()
            else:
                raise ValueError("Unsupported descriptor")
            self.descriptor_calculator.set_options(**descriptor_options)

            if self.options["fitting"] == "leastsquare":
                self.fitting_calculator = LeastSquare()
            elif self.options["fitting"] == "diff":
                self.fitting_calculator = DiffFitting()
            elif self.options["fitting"] == "qtr":
                self.fitting_calculator = QuasiTimeReversible()
            else:
                raise ValueError("Unsupported fitting")
            self.fitting_calculator.set_options(**fitting_options)

        def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray):
            """Load a new data point in the extrapolator.

            The coefficients are cropped to the first ``nelectrons//2`` columns,
            multiplied by S^1/2 and mapped to the tangent plane. The very first
            call also fixes the tangent point. The descriptor of ``coords`` is
            stored, and so is ``overlap`` when ``store_overlap`` is enabled.
            """
            # Crop the coefficient matrix up to the number of electron
            # pairs, then apply S^1/2
            coeff = self._crop_coeff(coeff)
            coeff = self._normalize(coeff, overlap)

            # if it is the first time we load data, set the tangent point
            if self.tangent is None:
                self._set_tangent(coeff)

            # push the new data to the corresponding buffers
            self.gammas.push(self._grassmann_log(coeff))
            self.descriptors.push(self._compute_descriptor(coords))

            if self.options["store_overlap"]:
                self.overlaps.push(overlap)

        def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray:
            """Get a new electronic density matrix to be used as a guess.

            Returns ``C @ C.T`` where ``C`` comes from :meth:`guess_coefficients`.
            """
            c_guess = self.guess_coefficients(coords, overlap)
            return c_guess @ c_guess.T

        def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray:
            """Get a new coefficient matrix to be used as a guess.

            If ``overlap`` is None, it is extrapolated from the stored overlaps
            with the same fitting coefficients used for the gammas.

            Raises:
                ValueError: if not enough data points are loaded, or if
                    ``overlap`` is None while ``store_overlap`` is false.
            """
            # check if we have enough data points to perform an extrapolation
            count = self.descriptors.count
            if self.options["allow_partially_filled"]:
                if count == 0:
                    raise ValueError("Not enough data loaded in the extrapolator")
                n = min(self.options["nsteps"], count)
            else:
                n = self.options["nsteps"]
                if count < n:
                    raise ValueError("Not enough data loaded in the extrapolator")

            if overlap is None and not self.options["store_overlap"]:
                raise ValueError("Guessing without overlap requires `store_overlap` true.")

            # use the descriptors to find the fitting coefficients
            prev_descriptors = self.descriptors.get(n)
            descriptor = self._compute_descriptor(coords)
            fit_coefficients = self._fit(prev_descriptors, descriptor)

            # use the fitting coefficients and the previous gammas to
            # extrapolate a new gamma
            gammas = self.gammas.get(n)
            gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)

            # Diagnostics only when requested (previously the coefficients
            # were printed unconditionally on every call).
            if self.options["verbose"]:
                print("fit coefficients:", fit_coefficients)
                fit_descriptor = self.fitting_calculator.linear_combination(
                    prev_descriptors, fit_coefficients)
                print("error on descriptor:",
                    np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))

            # if the overlap is not given, use the coefficients to fit
            # a new overlap
            if overlap is None:
                overlaps = self.overlaps.get(n)
                overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
            inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)

            # use the overlap and gamma to find a new set of coefficients
            c_guess = self._grassmann_exp(gamma)
            return inverse_sqrt_overlap @ c_guess

        def _get_tangent(self) -> np.ndarray:
            """Return the tangent point; raise ValueError if it is not set."""
            if self.tangent is not None:
                return self.tangent
            raise ValueError("Tangent point is not set.")

        def _crop_coeff(self, coeff) -> np.ndarray:
            """Keep only the first ``nelectrons//2`` columns of ``coeff``.

            Intended to remove the virtual orbitals; this assumes the columns
            are ordered occupied-first (TODO confirm with the caller).
            """
            return coeff[:, :self.nelectrons//2]

        def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
            """Return S^1/2 @ coeff.

            The intent is that the result satisfies C.T @ C = 1 (so D @ D = D)
            when ``coeff`` is orthonormal with respect to ``overlap``.
            """
            return self._sqrt_overlap(overlap).dot(coeff)

        def _set_tangent(self, c: np.ndarray):
            """Set the tangent point; raise ValueError if it is already set."""
            if self.tangent is None:
                self.tangent = c
            else:
                raise ValueError("Resetting the tangent.")

        def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
            """Map from the manifold to the tangent plane at the tangent point."""
            tangent = self._get_tangent()
            return grassmann.log(coeff, tangent)

        def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray:
            """Map from the tangent plane at the tangent point to the manifold."""
            tangent = self._get_tangent()
            return grassmann.exp(gamma, tangent)

        def _sqrt_overlap(self, overlap) -> np.ndarray:
            """Compute the square root S^1/2 of the overlap matrix via SVD.

            Assumes ``overlap`` is symmetric positive definite, in which case
            U @ diag(sqrt(s)) @ Vt is its principal square root.
            """
            q, s, vt = np.linalg.svd(overlap, full_matrices=False)
            return q @ np.diag(np.sqrt(s)) @ vt

        def _inverse_sqrt_overlap(self, overlap) -> np.ndarray:
            """Compute the inverse square root S^-1/2 of the overlap matrix via SVD.

            Assumes ``overlap`` is symmetric positive definite (all singular
            values non-zero); otherwise the division produces inf.
            """
            q, s, vt = np.linalg.svd(overlap, full_matrices=False)
            return q @ np.diag(1.0/np.sqrt(s)) @ vt

        def _compute_descriptor(self, coords) -> np.ndarray:
            """Given a set of coordinates compute the corresponding descriptor."""
            return self.descriptor_calculator.compute(coords)

        def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
            """Fit the current descriptor using previous descriptors and
            the configured fitting calculator."""
            return self.fitting_calculator.fit(prev_descriptors, descriptor)