From 76adbf0aad9781d597e3d394fd3b5a6e933d3841 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Thu, 19 Oct 2023 13:23:57 +0200
Subject: [PATCH] Linted.
---
grext/__init__.py | 4 ++++
grext/buffer.py | 5 ++++-
grext/fitting.py | 5 +++--
grext/grassmann.py | 29 ++++++++++-------------------
grext/main.py | 25 +++++++++++++++++--------
5 files changed, 38 insertions(+), 30 deletions(-)
diff --git a/grext/__init__.py b/grext/__init__.py
index e69de29..06beea2 100644
--- a/grext/__init__.py
+++ b/grext/__init__.py
@@ -0,0 +1,4 @@
+"""The package grext provides tools for generating new guesses for the
+self consistent field in molecular dynamics simulations."""
+
+from .main import Extrapolator
diff --git a/grext/buffer.py b/grext/buffer.py
index e5a60c8..f734505 100644
--- a/grext/buffer.py
+++ b/grext/buffer.py
@@ -1,5 +1,8 @@
-import numpy as np
+"""Module that defines a circular buffer for storing the last properties
+in a molecular dynamics simulation."""
+
from typing import List, Tuple
+import numpy as np
class CircularBuffer:
diff --git a/grext/fitting.py b/grext/fitting.py
index 7993e0b..99c607b 100644
--- a/grext/fitting.py
+++ b/grext/fitting.py
@@ -1,6 +1,7 @@
+"""Module that defines fitting functions."""
def linear():
- pass
+ """Simple least square minimization fitting."""
def quasi_time_reversible():
- pass
+ """Time reversible least square minimization fitting."""
diff --git a/grext/grassmann.py b/grext/grassmann.py
index dd63225..f99d06c 100644
--- a/grext/grassmann.py
+++ b/grext/grassmann.py
@@ -1,36 +1,27 @@
+"""Module that defines the bare Grassmann operations."""
+
import numpy as np
-def log_plain(c: np.ndarray, c0: np.ndarray) -> np.ndarray:
+def log_alt(c: np.ndarray, c0: np.ndarray) -> np.ndarray:
+ """Grassmann logarithm alterative version."""
c0c_inv = np.linalg.inv(c0.T @ c)
- L = c @ c0c_inv - c0
- q, s, vt = np.linalg.svd(L, full_matrices=False)
+ l = c @ c0c_inv - c0
+ q, s, vt = np.linalg.svd(l, full_matrices=False)
arctan_s = np.diag(np.arctan(s))
return q @ arctan_s @ vt
def log(c: np.ndarray, c0: np.ndarray) -> np.ndarray:
+ """Grassmann logarithm."""
psi, s, rt = np.linalg.svd(c.T @ c0, full_matrices=False)
cstar = c @ psi @ rt
- L = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar
- u, s, vt = np.linalg.svd(L, full_matrices=False)
+ l = (np.identity(c.shape[0]) - c0 @ c0.T) @ cstar
+ u, s, vt = np.linalg.svd(l, full_matrices=False)
arcsin_s = np.diag(np.arcsin(s))
return u @ arcsin_s @ vt
def exp(gamma: np.ndarray, c0: np.ndarray) -> np.ndarray:
+ """Grassmann exponential."""
q, s, vt = np.linalg.svd(gamma, full_matrices=False)
sin_s = np.diag(np.sin(s))
cos_s = np.diag(np.cos(s))
return c0 @ vt.T @ cos_s @ vt + q @ sin_s @ vt
-
-def psi(d: np.ndarray, n: np.ndarray) -> np.ndarray:
- a = d[0:n, 0:n]
- b = d[n:, 0:n]
- ainv = np.linalg.inv(a)
- return b @ ainv
-
-def phi(b: np.ndarray) -> np.ndarray:
- nb_n, n = b.shape
- q = np.linalg.inv(np.identity(n) + b.T @ b)
- l = np.zeros((nb_n + n, n))
- l[0:n,:] = np.identity(n)
- l[n:,:] = b
- return l @ q @ l.T
diff --git a/grext/main.py b/grext/main.py
index 099311b..61b1c3f 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -1,12 +1,16 @@
-import numpy as np
+"""Main module containing the Extrapolator class."""
+
from typing import Optional
+import numpy as np
from . import grassmann
from .buffer import CircularBuffer
-class GrassmannExt:
+class Extrapolator:
- """Module for performing Grassmann extrapolations."""
+ """Class for performing Grassmann extrapolations. On initialization
+ it requires the number of electrons, the number of basis functions
+ and the number of atoms of the molecule."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int,
nsteps: int = 10):
@@ -20,7 +24,7 @@ class GrassmannExt:
self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
self.coords = CircularBuffer(self.nsteps, (self.natoms, 3))
- self.is_tangent_set = False
+ self.tangent: Optional[np.ndarray] = None
def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray):
@@ -28,13 +32,15 @@ class GrassmannExt:
coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap)
+ if self.tangent is not None:
+ self._set_tangent(coeff)
+
self.gammas.push(self._grassmann_log(coeff))
self.coords.push(coords)
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]):
"""Get a new electronic density to be used as a guess."""
- pass
def _crop_coeff(self, coeff) -> np.ndarray:
"""Crop the coefficient matrix to remove the virtual orbitals."""
@@ -48,13 +54,16 @@ class GrassmannExt:
def _set_tangent(self, c: np.ndarray):
"""Set the tangent point."""
- self.is_tangent_set = True
self.tangent = c
def _grassmann_log(self, coeff: np.ndarray):
"""Map from the manifold to the tangent plane."""
- return grassmann.log(coeff, self.tangent)
+ if self.tangent is not None:
+ return grassmann.log(coeff, self.tangent)
+ raise ValueError("Tangent point is not set.")
def _grassmann_exp(self, gamma: np.ndarray):
"""Map from the tangent plane to the manifold."""
- return grassmann.exp(gamma, self.tangent)
+ if self.tangent is not None:
+ return grassmann.exp(gamma, self.tangent)
+ raise ValueError("Tangent point is not set.")
--
GitLab