Skip to content
Snippets Groups Projects
Commit c76c9b8c authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Switched to descriptors.

parent f5c831c7
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@ class CircularBuffer:
"""Circular buffer to store the last `n` matrices."""
def __init__(self, n: int, shape: Tuple[int, int]):
def __init__(self, n: int, shape: Tuple[int, ...]):
self.n = n
self.shape = shape
self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)]
......
"""Module which provides functions to compute descriptors."""
import numpy as np
from scipy.spatial.distance import pdist
def distance(coords: np.ndarray) -> np.ndarray:
"""Compute the distance matric as a descriptor."""
return pdist(coords, metric="euclidean")
def coulomb(coords: np.ndarray) -> np.ndarray:
"""Compute the Coulomb matrix as a descriptor."""
return 1.0/distance(coords)
......@@ -5,6 +5,7 @@ import numpy as np
from . import grassmann
from . import fitting
from . import descriptors
from .buffer import CircularBuffer
class Extrapolator:
......@@ -25,7 +26,8 @@ class Extrapolator:
self.gammas = CircularBuffer(self.nsteps, (self.nelectrons, self.nbasis))
self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
self.coords = CircularBuffer(self.nsteps, (self.natoms, 3))
self.descriptors = CircularBuffer(self.nsteps,
((self.natoms - 1)*self.natoms//2, ))
self.tangent: Optional[np.ndarray] = None
......@@ -39,11 +41,12 @@ class Extrapolator:
self._set_tangent(coeff)
self.gammas.push(self._grassmann_log(coeff))
self.coords.push(coords)
self.descriptors.push(self._compute_descriptor(coords))
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]):
"""Get a new electronic density to be used as a guess."""
descriptor = self._compute_descriptor(coords)
coefficients = fitting.linear()
def _get_tangent(self) -> np.ndarray:
......@@ -77,7 +80,11 @@ class Extrapolator:
tangent = self._get_tangent()
return grassmann.exp(gamma, tangent)
def _sqrt_overlap(self, overlap):
def _sqrt_overlap(self, overlap) -> np.ndarray:
"""Compute the square root of the overlap matrix."""
q, s, vt = np.linalg.svd(overlap, full_matrices=False)
return q @ np.diag(np.sqrt(s)) @ vt
def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor."""
return descriptors.distance(coords)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment