Skip to content
Snippets Groups Projects
Select Git revision
  • 71b1c3aa1f6e90b286c69026870ef4a63bf4bbd8
  • main default protected
  • askarpza-main-patch-76094
  • polynomial_regression
  • optimization
  • v0.8.0
  • v0.7.1
  • v0.7.0
  • v0.6.0
  • v0.5.0
  • v0.4.1
  • v0.4.0
  • v0.3.0
  • v0.2.0
14 results

main.py

Blame
  • main.py 4.72 KiB
    """Main module containing the Extrapolator class."""
    
    from typing import Optional
    import numpy as np
    
    from . import grassmann
    from . import fitting
    from . import descriptors
    from .buffer import CircularBuffer
    
    class Extrapolator:

        """Class for performing Grassmann extrapolations.

        On initialization it requires the number of electrons, the number of
        basis functions and the number of atoms of the molecule. The number of
        previous steps used by the extrapolator is an optional argument with
        default value of 10. Extra keyword arguments are parsed as options
        (currently only ``verbose``, default False).

        Only the first ``nelectrons // 2`` columns of each coefficient matrix
        are kept, which presumes doubly-occupied orbitals.
        """

        def __init__(self, nelectrons: int, nbasis: int, natoms: int,
                nsteps: int = 10, **kwargs):

            self.nelectrons = nelectrons
            self.nbasis = nbasis
            self.natoms = natoms
            self.nsteps = nsteps

            # Histories of the last `nsteps` data points.
            # NOTE(review): gammas is sized (nocc, nbasis) while the cropped
            # coefficients are (nbasis, nocc); presumably grassmann.log returns
            # the transposed layout -- confirm.
            self.gammas = CircularBuffer(self.nsteps, (self.nelectrons//2, self.nbasis))
            self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
            # Sized for one value per atom pair: natoms*(natoms-1)/2.
            self.descriptors = CircularBuffer(self.nsteps,
                ((self.natoms - 1)*self.natoms//2, ))

            # Reference point of the Grassmann log/exp maps. It is fixed by the
            # first call to load_data and never changed afterwards.
            self.tangent: Optional[np.ndarray] = None

            self._set_options(**kwargs)

        def load_data(self, coords: np.ndarray, coeff: np.ndarray,
                overlap: np.ndarray):
            """Load a new data point in the extrapolator.

            Parameters
            ----------
            coords : np.ndarray
                Coordinates of the geometry; converted to a descriptor.
            coeff : np.ndarray
                Coefficient matrix; only the first ``nelectrons // 2`` columns
                are used.
            overlap : np.ndarray
                Overlap matrix ``S`` for this geometry.

            The first point ever loaded becomes the tangent point.
            """
            coeff = self._crop_coeff(coeff)
            coeff = self._normalize(coeff, overlap)

            if self.tangent is None:
                self._set_tangent(coeff)

            self.gammas.push(self._grassmann_log(coeff))
            self.descriptors.push(self._compute_descriptor(coords))
            self.overlaps.push(overlap)

        def guess(self, coords: np.ndarray,
                overlap: Optional[np.ndarray] = None) -> np.ndarray:
            """Get a new electronic density to be used as a guess.

            The descriptor of ``coords`` is fitted as a linear combination of
            the stored descriptors. The same coefficients combine the stored
            tangent-space points, and the result is mapped back with the
            Grassmann exponential.

            Parameters
            ----------
            coords : np.ndarray
                Coordinates of the new geometry.
            overlap : np.ndarray, optional
                Overlap matrix at the new geometry. When omitted, it is
                approximated by the same linear combination of the stored
                overlap matrices.

            Returns
            -------
            np.ndarray
                Density matrix ``C @ C.T`` built from the extrapolated
                coefficients ``C``.

            Raises
            ------
            ValueError
                If no data point has been loaded yet (tangent point unset).
            """
            prev_descriptors = self.descriptors.get(self.nsteps)
            descriptor = self._compute_descriptor(coords)
            fit_coefficients = fitting.linear(prev_descriptors, descriptor)

            gammas = self.gammas.get(self.nsteps)
            gamma = fitting.linear_combination(gammas, fit_coefficients)

            if self.options["verbose"]:
                # The fitted descriptor is only needed for this diagnostic.
                fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
                print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))

            if overlap is None:
                overlaps = self.overlaps.get(self.nsteps)
                overlap = fitting.linear_combination(overlaps, fit_coefficients)
            inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)

            # Stored points were multiplied by S^(1/2) in _normalize;
            # S^(-1/2) undoes that for the extrapolated coefficients.
            c_guess = self._grassmann_exp(gamma)
            c_guess = inverse_sqrt_overlap @ c_guess

            return c_guess @ c_guess.T

        def _set_options(self, **kwargs):
            """Parse additional options from the keyword arguments.

            Recognized options:
            - ``verbose`` (default False): print the descriptor fit error in
              ``guess``.

            Unrecognized keys are ignored.
            """
            self.options = {"verbose": kwargs.get("verbose", False)}

        def _get_tangent(self) -> np.ndarray:
            """Return the tangent point, raising ValueError if it is not set."""
            if self.tangent is not None:
                return self.tangent
            raise ValueError("Tangent point is not set.")

        def _crop_coeff(self, coeff: np.ndarray) -> np.ndarray:
            """Crop the coefficient matrix to remove the virtual orbitals.

            Keeps the first ``nelectrons // 2`` columns.
            """
            return coeff[:, :self.nelectrons//2]

        def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
            """Return ``S^(1/2) @ coeff``.

            If ``coeff.T @ S @ coeff = 1``, the result satisfies
            ``C.T @ C = 1`` (idempotent density, D * D = D).
            """
            return self._sqrt_overlap(overlap).dot(coeff)

        def _set_tangent(self, c: np.ndarray):
            """Set the tangent point; raise ValueError if it is already set."""
            if self.tangent is None:
                self.tangent = c
            else:
                raise ValueError("Resetting the tangent.")

        def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
            """Map from the manifold to the tangent plane."""
            tangent = self._get_tangent()
            return grassmann.log(coeff, tangent)

        def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray:
            """Map from the tangent plane to the manifold."""
            tangent = self._get_tangent()
            return grassmann.exp(gamma, tangent)

        @staticmethod
        def _svd_matrix_function(matrix: np.ndarray, func) -> np.ndarray:
            """Return ``U @ diag(func(s)) @ Vt`` for the thin SVD of ``matrix``.

            For a symmetric positive-definite matrix this equals the matrix
            function ``func(matrix)``.
            """
            u, s, vt = np.linalg.svd(matrix, full_matrices=False)
            # Scaling the columns of u is u @ np.diag(func(s)) without
            # materializing the n x n diagonal matrix.
            return (u * func(s)) @ vt

        def _sqrt_overlap(self, overlap: np.ndarray) -> np.ndarray:
            """Compute the square root ``S^(1/2)`` of the overlap matrix."""
            return self._svd_matrix_function(overlap, np.sqrt)

        def _inverse_sqrt_overlap(self, overlap: np.ndarray) -> np.ndarray:
            """Compute the inverse square root ``S^(-1/2)`` of the overlap matrix."""
            return self._svd_matrix_function(overlap, lambda s: 1.0/np.sqrt(s))

        def _compute_descriptor(self, coords: np.ndarray) -> np.ndarray:
            """Given a set of coordinates compute the corresponding descriptor."""
            return descriptors.distance(coords)