Select Git revision
main.py 8.48 KiB
"""Main module containing the Extrapolator class."""
from typing import Optional
import numpy as np
from . import grassmann
from .fitting import LeastSquare, QuasiTimeReversible,DiffFitting
from .descriptors import Distance, Coulomb
from .buffer import CircularBuffer
class Extrapolator:
    """Class for performing Grassmann extrapolations.

    On initialization it requires the number of electrons, the number of
    basis functions and the number of atoms of the molecule.

    Keyword options (defaults are listed in ``supported_options``):

    * ``verbose``: print diagnostic information while guessing.
    * ``nsteps``: number of previous data points kept and used
      (must satisfy ``1 <= nsteps < 100``).
    * ``descriptor``: ``"distance"`` or ``"coulomb"``.
    * ``fitting``: ``"leastsquare"``, ``"diff"`` or ``"qtr"``.
    * ``allow_partially_filled``: allow a guess with fewer than
      ``nsteps`` stored data points (at least one is still required).
    * ``store_overlap``: keep the overlap matrices so that a guess can
      be requested without providing a new overlap.

    Any keyword starting with ``descriptor_`` or ``fitting_`` is
    forwarded, with the prefix stripped, to the ``set_options`` method
    of the descriptor or fitting calculator respectively.
    """

    supported_options = {
        "verbose": False,
        "nsteps": 6,
        "descriptor": "distance",
        "fitting": "diff",
        "allow_partially_filled": True,
        "store_overlap": True,
    }

    def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
        """Validate the dimensions, parse the options and allocate buffers.

        Raises:
            ValueError: if a dimension is not an integer, or if an
                option is unsupported or has an invalid value.
        """
        dimensions = (nelectrons, nbasis, natoms)
        # NumPy integer scalars (e.g. np.int64) are accepted alongside
        # built-in ints and normalised to int for later shape arithmetic.
        if not all(isinstance(dim, (int, np.integer)) for dim in dimensions):
            raise ValueError("Dimensions are not integers")

        self.nelectrons = int(nelectrons)
        self.nbasis = int(nbasis)
        self.natoms = int(natoms)

        self.set_options(**kwargs)

        nsteps = self.options["nsteps"]
        self.gammas = CircularBuffer(
            nsteps, (self.nelectrons//2, self.nbasis))
        # one descriptor entry per unordered pair of atoms
        self.descriptors = CircularBuffer(
            nsteps, ((self.natoms - 1)*self.natoms//2, ))

        # NOTE: ``self.overlaps`` only exists when ``store_overlap`` is set.
        if self.options["store_overlap"]:
            self.overlaps = CircularBuffer(
                nsteps, (self.nbasis, self.nbasis))

        self.tangent: Optional[np.ndarray] = None

    def set_options(self, **kwargs):
        """Given an arbitrary amount of keyword arguments, parse them if
        specified, set default values if not specified and raise an error
        if invalid arguments are passed.

        Raises:
            ValueError: for an unknown keyword, an ``nsteps`` outside
                ``[1, 100)``, or an unknown descriptor/fitting name.
        """
        descriptor_prefix = "descriptor_"
        fitting_prefix = "fitting_"

        self.options = {}
        descriptor_options = {}
        fitting_options = {}

        # set specified options
        for key, value in kwargs.items():
            if key in self.supported_options:
                self.options[key] = value
            elif key.startswith(descriptor_prefix):
                descriptor_options[key[len(descriptor_prefix):]] = value
            elif key.startswith(fitting_prefix):
                fitting_options[key[len(fitting_prefix):]] = value
            else:
                raise ValueError(f"Unsupported option: {key}")

        # set unspecified options with defaults
        for option, default_value in self.supported_options.items():
            if option not in self.options:
                self.options[option] = default_value

        # do some check on the options, set things and pipe options
        # to submodules
        if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100:
            raise ValueError("Unsupported nsteps")

        if self.options["descriptor"] == "distance":
            self.descriptor_calculator = Distance()
        elif self.options["descriptor"] == "coulomb":
            self.descriptor_calculator = Coulomb()
        else:
            raise ValueError("Unsupported descriptor")
        self.descriptor_calculator.set_options(**descriptor_options)

        if self.options["fitting"] == "leastsquare":
            self.fitting_calculator = LeastSquare()
        elif self.options["fitting"] == "diff":
            self.fitting_calculator = DiffFitting()
        elif self.options["fitting"] == "qtr":
            self.fitting_calculator = QuasiTimeReversible()
        else:
            raise ValueError("Unsupported fitting")
        self.fitting_calculator.set_options(**fitting_options)

    def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap):
        """Load a new data point in the extrapolator.

        The first coefficient matrix ever loaded becomes the tangent
        point used by all subsequent Grassmann log/exp maps.
        """
        # Crop the coefficient matrix up to the number of electron
        # pairs, then apply S^1/2
        coeff = self._crop_coeff(coeff)
        coeff = self._normalize(coeff, overlap)

        # if it is the first time we load data, set the tangent point
        if self.tangent is None:
            self._set_tangent(coeff)

        # push the new data to the corresponding vectors
        self.gammas.push(self._grassmann_log(coeff))
        self.descriptors.push(self._compute_descriptor(coords))

        if self.options["store_overlap"]:
            self.overlaps.push(overlap)

    def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray:
        """Get a new electronic density matrix to be used as a guess.

        The result is ``C @ C.T`` where ``C`` is the matrix returned by
        ``guess_coefficients`` for the same arguments.
        """
        c_guess = self.guess_coefficients(coords, overlap)
        return c_guess @ c_guess.T

    def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray:
        """Get a new coefficient matrix to be used as a guess.

        Args:
            coords: coordinates at which the guess is requested.
            overlap: overlap matrix at ``coords``. If omitted, it is
                extrapolated from the stored overlaps, which requires
                the ``store_overlap`` option.

        Raises:
            ValueError: if not enough data points have been loaded, or
                if ``overlap`` is omitted while ``store_overlap`` is off.
        """
        # check if we have enough data points to perform an extrapolation
        count = self.descriptors.count
        if self.options["allow_partially_filled"]:
            if count == 0:
                raise ValueError("Not enough data loaded in the extrapolator")
            n = min(self.options["nsteps"], count)
        else:
            n = self.options["nsteps"]
            if count < n:
                raise ValueError("Not enough data loaded in the extrapolator")

        if overlap is None and not self.options["store_overlap"]:
            raise ValueError("Guessing without overlap requires `store_overlap` true.")

        # use the descriptors to find the fitting coefficients
        prev_descriptors = self.descriptors.get(n)
        descriptor = self._compute_descriptor(coords)
        fit_coefficients = self._fit(prev_descriptors, descriptor)

        # use the fitting coefficients and the previous gammas to
        # extrapolate a new gamma
        gammas = self.gammas.get(n)
        gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)

        # diagnostics are only printed on request
        if self.options["verbose"]:
            print("fit coefficients:", fit_coefficients)
            fit_descriptor = self.fitting_calculator.linear_combination(
                prev_descriptors, fit_coefficients)
            print("error on descriptor:",
                  np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))

        # if the overlap is not given, use the coefficients to fit
        # a new overlap
        if overlap is None:
            overlaps = self.overlaps.get(n)
            overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
        inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)

        # use the overlap and gamma to find a new set of coefficients
        c_guess = self._grassmann_exp(gamma)
        return inverse_sqrt_overlap @ c_guess

    def _get_tangent(self) -> np.ndarray:
        """Get the tangent point, raising ValueError if it is not set."""
        if self.tangent is not None:
            return self.tangent
        raise ValueError("Tangent point is not set.")

    def _crop_coeff(self, coeff) -> np.ndarray:
        """Crop the coefficient matrix to remove the virtual orbitals.

        Only the first ``nelectrons // 2`` columns are kept.
        """
        return coeff[:, :self.nelectrons//2]

    def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
        """Normalize the coefficients such that C.T * C = 1, D * D = D.

        This is done by left-multiplying ``coeff`` by the square root of
        the overlap matrix.
        """
        return self._sqrt_overlap(overlap).dot(coeff)

    def _set_tangent(self, c: np.ndarray):
        """Set the tangent point; it can only be set once.

        Raises:
            ValueError: if the tangent point is already set.
        """
        if self.tangent is None:
            self.tangent = c
        else:
            raise ValueError("Resetting the tangent.")

    def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
        """Map from the manifold to the tangent plane."""
        tangent = self._get_tangent()
        return grassmann.log(coeff, tangent)

    def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray:
        """Map from the tangent plane to the manifold."""
        tangent = self._get_tangent()
        return grassmann.exp(gamma, tangent)

    def _sqrt_overlap(self, overlap) -> np.ndarray:
        """Compute the square root of the overlap matrix.

        The root is assembled from the singular value decomposition of
        ``overlap``; this assumes a symmetric positive (semi)definite
        matrix, for which the SVD gives the matrix square root.
        """
        q, s, vt = np.linalg.svd(overlap, full_matrices=False)
        return q @ np.diag(np.sqrt(s)) @ vt

    def _inverse_sqrt_overlap(self, overlap) -> np.ndarray:
        """Compute the inverse square root of the overlap matrix.

        Same construction as ``_sqrt_overlap`` with the reciprocal of
        the square-rooted singular values; a singular ``overlap``
        therefore leads to a division by zero.
        """
        q, s, vt = np.linalg.svd(overlap, full_matrices=False)
        return q @ np.diag(1.0/np.sqrt(s)) @ vt

    def _compute_descriptor(self, coords) -> np.ndarray:
        """Given a set of coordinates compute the corresponding descriptor."""
        return self.descriptor_calculator.compute(coords)

    def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
        """Fit the current descriptor using previous descriptors and
        the specified fitting scheme."""
        return self.fitting_calculator.fit(prev_descriptors, descriptor)