From 76adbf0aad9781d597e3d394fd3b5a6e933d3841 Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Thu, 19 Oct 2023 13:23:57 +0200 Subject: [PATCH] Linted. --- grext/__init__.py | 4 ++++ grext/buffer.py | 5 ++++- grext/fitting.py | 5 +++-- grext/grassmann.py | 29 ++++++++++------------------- grext/main.py | 25 +++++++++++++++++-------- 5 files changed, 38 insertions(+), 30 deletions(-) diff --git a/grext/__init__.py b/grext/__init__.py index e69de29..06beea2 100644 --- a/grext/__init__.py +++ b/grext/__init__.py @@ -0,0 +1,4 @@ +"""The package grext provides tools for generating new guesses for the +self consistent field in molecular dynamics simulations.""" + +from .main import Extrapolator diff --git a/grext/buffer.py b/grext/buffer.py index e5a60c8..f734505 100644 --- a/grext/buffer.py +++ b/grext/buffer.py @@ -1,5 +1,8 @@ -import numpy as np +"""Module that defines a circular buffer for storing the last properties +in a molecular dynamics simulation.""" + from typing import List, Tuple +import numpy as np class CircularBuffer: diff --git a/grext/fitting.py b/grext/fitting.py index 7993e0b..99c607b 100644 --- a/grext/fitting.py +++ b/grext/fitting.py @@ -1,6 +1,7 @@ +"""Module that defines fitting functions.""" def linear(): - pass + """Simple least square minimization fitting.""" def quasi_time_reversible(): - pass + """Time reversible least square minimization fitting.""" diff --git a/grext/grassmann.py b/grext/grassmann.py index dd63225..f99d06c 100644 --- a/grext/grassmann.py +++ b/grext/grassmann.py @@ -1,36 +1,27 @@ +"""Module that defines the bare Grassmann operations.""" + import numpy as np -def log_plain(c: np.ndarray, c0: np.ndarray) -> np.ndarray: +def log_alt(c: np.ndarray, c0: np.ndarray) -> np.ndarray: + """Grassmann logarithm alterative version.""" c0c_inv = np.linalg.inv(c0.T @ c) - L = c @ c0c_inv - c0 - q, s, vt = np.linalg.svd(L, full_matrices=False) + l = c @ c0c_inv - c0 + q, s, vt = np.linalg.svd(l, full_matrices=False) arctan_s = np.diag(np.arctan(s)) return q @ arctan_s @ vt def log(c: np.ndarray, c0: np.ndarray) -> np.ndarray: + """Grassmann logarithm.""" psi, s, rt = np.linalg.svd(c.T @ c0, full_matrices=False) cstar = c @ psi @ rt - L = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar - u, s, vt = np.linalg.svd(L, full_matrices=False) + l = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar + u, s, vt = np.linalg.svd(l, full_matrices=False) arcsin_s = np.diag(np.arcsin(s)) return u @ arcsin_s @ vt def exp(gamma: np.ndarray, c0: np.ndarray) -> np.ndarray: + """Grassmann exponential.""" q, s, vt = np.linalg.svd(gamma, full_matrices=False) sin_s = np.diag(np.sin(s)) cos_s = np.diag(np.cos(s)) return c0 @ vt.T @ cos_s @ vt + q @ sin_s @ vt - -def psi(d: np.ndarray, n: np.ndarray) -> np.ndarray: - a = d[0:n, 0:n] - b = d[n:, 0:n] - ainv = np.linalg.inv(a) - return b @ ainv - -def phi(b: np.ndarray) -> np.ndarray: - nb_n, n = b.shape - q = np.linalg.inv(np.identity(n) + b.T @ b) - l = np.zeros((nb_n + n, n)) - l[0:n,:] = np.identity(n) - l[n:,:] = b - return l @ q @ l.T diff --git a/grext/main.py b/grext/main.py index 099311b..61b1c3f 100644 --- a/grext/main.py +++ b/grext/main.py @@ -1,12 +1,16 @@ -import numpy as np +"""Main module containing the Extrapolator class.""" + from typing import Optional +import numpy as np from . import grassmann from .buffer import CircularBuffer -class GrassmannExt: +class Extrapolator: - """Module for performing Grassmann extrapolations.""" + """Class for performing Grassmann extrapolations. On initialization + it requires the number of electrons, the number of basis functions + and the number of atoms of the molecule.""" def __init__(self, nelectrons: int, nbasis: int, natoms: int, nsteps: int = 10): @@ -20,7 +24,7 @@ class GrassmannExt: self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis)) self.coords = CircularBuffer(self.nsteps, (self.natoms, 3)) - self.is_tangent_set = False + self.tangent: Optional[np.ndarray] = None def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray): @@ -28,13 +32,15 @@ class GrassmannExt: coeff = self._crop_coeff(coeff) coeff = self._normalize(coeff, overlap) + if self.tangent is not None: + self._set_tangent(coeff) + self.gammas.push(self._grassmann_log(coeff)) self.coords.push(coords) self.overlaps.push(overlap) def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]): """Get a new electronic density to be used as a guess.""" - pass def _crop_coeff(self, coeff) -> np.ndarray: """Crop the coefficient matrix to remove the virtual orbitals.""" @@ -48,13 +54,16 @@ class GrassmannExt: def _set_tangent(self, c: np.ndarray): """Set the tangent point.""" - self.is_tangent_set = True self.tangent = c def _grassmann_log(self, coeff: np.ndarray): """Map from the manifold to the tangent plane.""" - return grassmann.log(coeff, self.tangent) + if self.tangent is not None: + return grassmann.log(coeff, self.tangent) + raise ValueError("Tangent point is not set.") def _grassmann_exp(self, gamma: np.ndarray): """Map from the tangent plane to the manifold.""" - return grassmann.exp(gamma, self.tangent) + if self.tangent is not None: + return grassmann.exp(gamma, self.tangent) + raise ValueError("Tangent point is not set.") -- GitLab