Skip to content
Snippets Groups Projects
Commit 952b9d3f authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Now a simple extrapolation is working.

parent 44b65723
Branches
Tags
No related merge requests found
......@@ -13,5 +13,7 @@ def quasi_time_reversible():
"""Time reversible least square minimization fitting."""
def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
A = np.vstack(vectors).T
return A @ coefficients
result = np.zeros(vectors[0].shape, dtype=np.float64)
for coeff, vector in zip(coefficients, vectors):
result += vector*coeff
return result
......@@ -17,20 +17,22 @@ class Extrapolator:
value of 10."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int,
nsteps: int = 10):
nsteps: int = 10, **kwargs):
self.nelectrons = nelectrons
self.nbasis = nbasis
self.natoms = natoms
self.nsteps = nsteps
self.gammas = CircularBuffer(self.nsteps, (self.nelectrons, self.nbasis))
self.gammas = CircularBuffer(self.nsteps, (self.nelectrons//2, self.nbasis))
self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
self.descriptors = CircularBuffer(self.nsteps,
((self.natoms - 1)*self.natoms//2, ))
self.tangent: Optional[np.ndarray] = None
self.options = kwargs
def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray):
"""Load a new data point in the extrapolator."""
......@@ -44,19 +46,24 @@ class Extrapolator:
self.descriptors.push(self._compute_descriptor(coords))
self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray:
def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray] = None) -> np.ndarray:
"""Get a new electronic density to be used as a guess."""
prev_descriptors = self.descriptors.get(self.nsteps)
gammas = self.gammas.get(self.nsteps)
descriptor = self._compute_descriptor(coords)
coefficients = fitting.linear(prev_descriptors, descriptor)
fit_coefficients = fitting.linear(prev_descriptors, descriptor)
gammas = self.gammas.get(self.nsteps)
gamma = fitting.linear_combination(gammas, fit_coefficients)
fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
gamma = fitting.linear_combination(gammas, coefficients)
if self.options["verbose"]:
print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
if overlap is None:
overlaps = self.overlaps.get(self.nsteps)
overlap = fitting.linear_combination(overlaps, coefficients)
overlap = fitting.linear_combination(overlaps, fit_coefficients)
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
......@@ -74,7 +81,7 @@ class Extrapolator:
def _crop_coeff(self, coeff) -> np.ndarray:
"""Crop the coefficient matrix to remove the virtual orbitals."""
return coeff[:, :self.nelectrons]
return coeff[:, :self.nelectrons//2]
def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
"""Normalize the coefficients such that C.T * C = 1, D * D = D."""
......
import pytest
import os
import sys
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import grext
import grext.grassmann
import utils
SMALL = 1e-10
THRESHOLD = 1e-2
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_extrapolation(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
assert n < nframes
# initialize an extrapolator
extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, n)
# load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:n],
data["coefficients"][:n], data["overlaps"][:n]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index n
guessed_density = extrapolator.guess(data["trajectory"][n], data["overlaps"][n])
coeff = data["coefficients"][n][:, :nelectrons//2]
density = coeff @ coeff.T
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment