Skip to content
Snippets Groups Projects
Commit 76adbf0a authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Linted.

parent 9e313cdf
No related branches found
No related tags found
No related merge requests found
"""The package grext provides tools for generating new guesses for the
self consistent field in molecular dynamics simulations."""
from .main import Extrapolator
import numpy as np """Module that defines a circular buffer for storing the last properties
in a molecular dynamics simulation."""
from typing import List, Tuple from typing import List, Tuple
import numpy as np
class CircularBuffer: class CircularBuffer:
......
"""Module that defines fitting functions."""
def linear(): def linear():
pass """Simple least square minimization fitting."""
def quasi_time_reversible(): def quasi_time_reversible():
pass """Time reversible least square minimization fitting."""
"""Module that defines the bare Grassmann operations."""
import numpy as np import numpy as np
def log_plain(c: np.ndarray, c0: np.ndarray) -> np.ndarray: def log_alt(c: np.ndarray, c0: np.ndarray) -> np.ndarray:
"""Grassmann logarithm alterative version."""
c0c_inv = np.linalg.inv(c0.T @ c) c0c_inv = np.linalg.inv(c0.T @ c)
L = c @ c0c_inv - c0 l = c @ c0c_inv - c0
q, s, vt = np.linalg.svd(L, full_matrices=False) q, s, vt = np.linalg.svd(l, full_matrices=False)
arctan_s = np.diag(np.arctan(s)) arctan_s = np.diag(np.arctan(s))
return q @ arctan_s @ vt return q @ arctan_s @ vt
def log(c: np.ndarray, c0: np.ndarray) -> np.ndarray: def log(c: np.ndarray, c0: np.ndarray) -> np.ndarray:
"""Grassmann logarithm."""
psi, s, rt = np.linalg.svd(c.T @ c0, full_matrices=False) psi, s, rt = np.linalg.svd(c.T @ c0, full_matrices=False)
cstar = c @ psi @ rt cstar = c @ psi @ rt
L = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar l = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar
u, s, vt = np.linalg.svd(L, full_matrices=False) u, s, vt = np.linalg.svd(l, full_matrices=False)
arcsin_s = np.diag(np.arcsin(s)) arcsin_s = np.diag(np.arcsin(s))
return u @ arcsin_s @ vt return u @ arcsin_s @ vt
def exp(gamma: np.ndarray, c0: np.ndarray) -> np.ndarray: def exp(gamma: np.ndarray, c0: np.ndarray) -> np.ndarray:
"""Grassmann exponential."""
q, s, vt = np.linalg.svd(gamma, full_matrices=False) q, s, vt = np.linalg.svd(gamma, full_matrices=False)
sin_s = np.diag(np.sin(s)) sin_s = np.diag(np.sin(s))
cos_s = np.diag(np.cos(s)) cos_s = np.diag(np.cos(s))
return c0 @ vt.T @ cos_s @ vt + q @ sin_s @ vt return c0 @ vt.T @ cos_s @ vt + q @ sin_s @ vt
def psi(d: np.ndarray, n: np.ndarray) -> np.ndarray:
a = d[0:n, 0:n]
b = d[n:, 0:n]
ainv = np.linalg.inv(a)
return b @ ainv
def phi(b: np.ndarray) -> np.ndarray:
nb_n, n = b.shape
q = np.linalg.inv(np.identity(n) + b.T @ b)
l = np.zeros((nb_n + n, n))
l[0:n,:] = np.identity(n)
l[n:,:] = b
return l @ q @ l.T
import numpy as np """Main module containing the Extrapolator class."""
from typing import Optional from typing import Optional
import numpy as np
from . import grassmann from . import grassmann
from .buffer import CircularBuffer from .buffer import CircularBuffer
class GrassmannExt: class Extrapolator:
"""Module for performing Grassmann extrapolations.""" """Class for performing Grassmann extrapolations. On initialization
it requires the number of electrons, the number of basis functions
and the number of atoms of the molecule."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int, def __init__(self, nelectrons: int, nbasis: int, natoms: int,
nsteps: int = 10): nsteps: int = 10):
...@@ -20,7 +24,7 @@ class GrassmannExt: ...@@ -20,7 +24,7 @@ class GrassmannExt:
self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis)) self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
self.coords = CircularBuffer(self.nsteps, (self.natoms, 3)) self.coords = CircularBuffer(self.nsteps, (self.natoms, 3))
self.is_tangent_set = False self.tangent: Optional[np.ndarray] = None
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray): overlap: np.ndarray):
...@@ -28,13 +32,15 @@ class GrassmannExt: ...@@ -28,13 +32,15 @@ class GrassmannExt:
coeff = self._crop_coeff(coeff) coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap) coeff = self._normalize(coeff, overlap)
if self.tangent is not None:
self._set_tangent(coeff)
self.gammas.push(self._grassmann_log(coeff)) self.gammas.push(self._grassmann_log(coeff))
self.coords.push(coords) self.coords.push(coords)
self.overlaps.push(overlap) self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]): def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]):
"""Get a new electronic density to be used as a guess.""" """Get a new electronic density to be used as a guess."""
pass
def _crop_coeff(self, coeff) -> np.ndarray: def _crop_coeff(self, coeff) -> np.ndarray:
"""Crop the coefficient matrix to remove the virtual orbitals.""" """Crop the coefficient matrix to remove the virtual orbitals."""
...@@ -48,13 +54,16 @@ class GrassmannExt: ...@@ -48,13 +54,16 @@ class GrassmannExt:
def _set_tangent(self, c: np.ndarray): def _set_tangent(self, c: np.ndarray):
"""Set the tangent point.""" """Set the tangent point."""
self.is_tangent_set = True
self.tangent = c self.tangent = c
def _grassmann_log(self, coeff: np.ndarray): def _grassmann_log(self, coeff: np.ndarray):
"""Map from the manifold to the tangent plane.""" """Map from the manifold to the tangent plane."""
if self.tangent is not None:
return grassmann.log(coeff, self.tangent) return grassmann.log(coeff, self.tangent)
raise ValueError("Tangent point is not set.")
def _grassmann_exp(self, gamma: np.ndarray): def _grassmann_exp(self, gamma: np.ndarray):
"""Map from the tangent plane to the manifold.""" """Map from the tangent plane to the manifold."""
if self.tangent is not None:
return grassmann.exp(gamma, self.tangent) return grassmann.exp(gamma, self.tangent)
raise ValueError("Tangent point is not set.")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment