Select Git revision
main.py 4.72 KiB
"""Main module containing the Extrapolator class."""
from typing import Optional
import numpy as np
from . import grassmann
from . import fitting
from . import descriptors
from .buffer import CircularBuffer
class Extrapolator:
    """Class for performing Grassmann extrapolations.

    On initialization it requires the number of electrons, the number of
    basis functions and the number of atoms of the molecule. The number of
    previous steps used by the extrapolator is an optional argument with a
    default value of 10.

    Args:
        nelectrons: number of electrons. Only the first ``nelectrons // 2``
            columns of each coefficient matrix are kept (see
            ``_crop_coeff``); presumably a closed-shell assumption, so
            confirm the intended handling of odd electron counts.
        nbasis: number of basis functions.
        natoms: number of atoms of the molecule.
        nsteps: number of previous steps stored and used by the fit.
        **kwargs: additional options. Currently only ``verbose`` (default
            ``False``) is recognised; other keys are ignored.
    """

    def __init__(self, nelectrons: int, nbasis: int, natoms: int,
                 nsteps: int = 10, **kwargs):
        self.nelectrons = nelectrons
        self.nbasis = nbasis
        self.natoms = natoms
        self.nsteps = nsteps
        # NOTE(review): CircularBuffer presumably retains only the most
        # recent ``nsteps`` entries of the given shape; confirm in buffer.py.
        # NOTE(review): the cropped coefficients are (nbasis, nelectrons//2)
        # slices, while this buffer is declared (nelectrons//2, nbasis).
        # Confirm that grassmann.log returns the transposed layout, or that
        # CircularBuffer does not enforce the shape.
        self.gammas = CircularBuffer(self.nsteps,
                                     (self.nelectrons//2, self.nbasis))
        self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
        # One entry per unordered atom pair, natoms*(natoms-1)/2 in total.
        # This presumably matches descriptors.distance; TODO confirm.
        self.descriptors = CircularBuffer(
            self.nsteps, ((self.natoms - 1)*self.natoms//2, ))
        # Reference point on the manifold; set by the first load_data call.
        self.tangent: Optional[np.ndarray] = None
        self._set_options(**kwargs)

    def load_data(self, coords: np.ndarray, coeff: np.ndarray,
                  overlap: np.ndarray):
        """Load a new data point in the extrapolator.

        The coefficients are cropped to the first ``nelectrons // 2``
        columns and multiplied by the square root of ``overlap``. The first
        matrix ever loaded becomes the tangent point. The Grassmann log of
        the coefficients, the descriptor of ``coords`` and ``overlap`` are
        then pushed into their respective buffers.
        """
        coeff = self._crop_coeff(coeff)
        coeff = self._normalize(coeff, overlap)
        if self.tangent is None:
            self._set_tangent(coeff)
        self.gammas.push(self._grassmann_log(coeff))
        self.descriptors.push(self._compute_descriptor(coords))
        self.overlaps.push(overlap)

    def guess(self, coords: np.ndarray,
              overlap: Optional[np.ndarray] = None) -> np.ndarray:
        """Get a new electronic density to be used as a guess.

        The descriptor of ``coords`` is fitted against the stored
        descriptors. The resulting coefficients combine the stored tangent
        vectors, which are then mapped back onto the manifold.

        Args:
            coords: coordinates of the new geometry.
            overlap: overlap matrix at the new geometry. If ``None``, it is
                extrapolated from the stored overlaps using the same fit
                coefficients.

        Returns:
            ``C @ C.T``, where ``C`` is the extrapolated coefficient matrix
            transformed by the inverse square root of the overlap.

        Raises:
            ValueError: if no data has been loaded yet (tangent not set).
        """
        # Fail fast with the same error _grassmann_exp would eventually raise,
        # rather than running the fit on buffers that hold no data yet.
        self._get_tangent()

        prev_descriptors = self.descriptors.get(self.nsteps)
        descriptor = self._compute_descriptor(coords)
        fit_coefficients = fitting.linear(prev_descriptors, descriptor)

        gammas = self.gammas.get(self.nsteps)
        gamma = fitting.linear_combination(gammas, fit_coefficients)

        if self.options["verbose"]:
            # The fitted descriptor is only needed for this diagnostic.
            fit_descriptor = fitting.linear_combination(prev_descriptors,
                                                        fit_coefficients)
            print("error on descriptor:",
                  np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))

        if overlap is None:
            # No overlap supplied: extrapolate it with the same coefficients.
            overlaps = self.overlaps.get(self.nsteps)
            overlap = fitting.linear_combination(overlaps, fit_coefficients)
        inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)

        c_guess = self._grassmann_exp(gamma)
        # Undo the S^(1/2) transformation applied in _normalize.
        c_guess = inverse_sqrt_overlap @ c_guess
        return c_guess @ c_guess.T

    def _set_options(self, **kwargs):
        """Parse additional options from the keyword arguments.

        Recognised options: ``verbose`` (default ``False``). Unknown keys
        are ignored.
        """
        self.options = {"verbose": kwargs.get("verbose", False)}

    def _get_tangent(self) -> np.ndarray:
        """Return the tangent point.

        Raises:
            ValueError: if the tangent point has not been set yet.
        """
        if self.tangent is not None:
            return self.tangent
        raise ValueError("Tangent point is not set.")

    def _crop_coeff(self, coeff) -> np.ndarray:
        """Keep only the first ``nelectrons // 2`` columns of ``coeff``.

        Presumably these are the occupied orbitals, which drops the virtual
        ones; this assumes columns are ordered occupied-first (TODO confirm).
        """
        return coeff[:, :self.nelectrons//2]

    def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
        """Return ``S^(1/2) @ coeff`` with ``S = overlap``.

        If ``coeff.T @ S @ coeff`` is the identity, the result has
        orthonormal columns (C.T * C = 1), so the associated density is
        idempotent (D * D = D).
        """
        return self._sqrt_overlap(overlap).dot(coeff)

    def _set_tangent(self, c: np.ndarray):
        """Set the tangent point.

        Raises:
            ValueError: if the tangent point has already been set.
        """
        if self.tangent is None:
            self.tangent = c
        else:
            raise ValueError("Resetting the tangent.")

    def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
        """Map from the manifold to the tangent plane at the tangent point."""
        tangent = self._get_tangent()
        return grassmann.log(coeff, tangent)

    def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray:
        """Map from the tangent plane at the tangent point to the manifold."""
        tangent = self._get_tangent()
        return grassmann.exp(gamma, tangent)

    def _sqrt_overlap(self, overlap) -> np.ndarray:
        """Compute the square root of the overlap matrix via SVD.

        For a symmetric positive definite ``overlap`` (assumed, not checked),
        this is the principal square root.
        """
        q, s, vt = np.linalg.svd(overlap, full_matrices=False)
        return q @ np.diag(np.sqrt(s)) @ vt

    def _inverse_sqrt_overlap(self, overlap) -> np.ndarray:
        """Compute the inverse square root of the overlap matrix via SVD.

        A zero singular value (singular ``overlap``) leads to a division by
        zero; ``overlap`` is assumed to be positive definite.
        """
        q, s, vt = np.linalg.svd(overlap, full_matrices=False)
        return q @ np.diag(1.0/np.sqrt(s)) @ vt

    def _compute_descriptor(self, coords) -> np.ndarray:
        """Given a set of coordinates, compute the corresponding descriptor.

        Delegates to ``descriptors.distance``. This presumably returns the
        interatomic distances, matching the buffer size set in ``__init__``
        (TODO confirm).
        """
        return descriptors.distance(coords)