diff --git a/grext/buffer.py b/grext/buffer.py index f734505b24d75581019f33cbf2223e6c844deb6f..c5ce6bc9611ee63490b53011c84c82820714809e 100644 --- a/grext/buffer.py +++ b/grext/buffer.py @@ -8,7 +8,7 @@ class CircularBuffer: """Circular buffer to store the last `n` matrices.""" - def __init__(self, n: int, shape: Tuple[int, int]): + def __init__(self, n: int, shape: Tuple[int, ...]): self.n = n self.shape = shape self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)] diff --git a/grext/descriptors.py b/grext/descriptors.py new file mode 100644 index 0000000000000000000000000000000000000000..87fbd99355bb15e1b9b4acb7aed353e3ce6128fe --- /dev/null +++ b/grext/descriptors.py @@ -0,0 +1,12 @@ +"""Module which provides functions to compute descriptors.""" + +import numpy as np +from scipy.spatial.distance import pdist + +def distance(coords: np.ndarray) -> np.ndarray: + """Compute the distance matric as a descriptor.""" + return pdist(coords, metric="euclidean") + +def coulomb(coords: np.ndarray) -> np.ndarray: + """Compute the Coulomb matrix as a descriptor.""" + return 1.0/distance(coords) diff --git a/grext/main.py b/grext/main.py index 25d9861e2c84b6f7119008699d88ccf4e3fe0237..1bf9adc5c0c0476a17beba304d2bdb3ddbc5ade7 100644 --- a/grext/main.py +++ b/grext/main.py @@ -5,6 +5,7 @@ import numpy as np from . import grassmann from . import fitting +from . import descriptors from .buffer import CircularBuffer class Extrapolator: @@ -25,7 +26,8 @@ class Extrapolator: self.gammas = CircularBuffer(self.nsteps, (self.nelectrons, self.nbasis)) self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis)) - self.coords = CircularBuffer(self.nsteps, (self.natoms, 3)) + self.descriptors = CircularBuffer(self.nsteps, + ((self.natoms - 1)*self.natoms//2, )) self.tangent: Optional[np.ndarray] = None @@ -39,11 +41,12 @@ class Extrapolator: self._set_tangent(coeff) self.gammas.push(self._grassmann_log(coeff)) - self.coords.push(coords) + self.descriptors.push(self._compute_descriptor(coords)) self.overlaps.push(overlap) def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]): """Get a new electronic density to be used as a guess.""" + descriptor = self._compute_descriptor(coords) coefficients = fitting.linear() def _get_tangent(self) -> np.ndarray: @@ -77,7 +80,11 @@ class Extrapolator: tangent = self._get_tangent() return grassmann.exp(gamma, tangent) - def _sqrt_overlap(self, overlap): + def _sqrt_overlap(self, overlap) -> np.ndarray: """Compute the square root of the overlap matrix.""" q, s, vt = np.linalg.svd(overlap, full_matrices=False) return q @ np.diag(np.sqrt(s)) @ vt + + def _compute_descriptor(self, coords) -> np.ndarray: + """Given a set of coordinates compute the corresponding descriptor.""" + return descriptors.distance(coords)