Skip to content
Snippets Groups Projects
Commit 335969a8 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Added a test for the fitting.

parent e573ca5d
No related branches found
No related tags found
No related merge requests found
"""Module that defines fitting functions.""" """Module that defines fitting functions."""
def linear(): from typing import List
import numpy as np
def linear(vectors: List[np.ndarray], target: np.ndarray):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
A = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(A, target, rcond=None)
return np.array(coefficients, dtype=np.float64)
def quasi_time_reversible(): def quasi_time_reversible():
"""Time reversible least square minimization fitting.""" """Time reversible least square minimization fitting."""
def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
A = np.vstack(vectors).T
return A @ coefficients
import pytest
import os
import sys
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import grext
import grext.descriptors
import grext.fitting
import grext.grassmann
import utils
SMALL = 1e-10
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_descriptor_fitting(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, nframes)
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
# we check if the error goes down with a larger data set
errors = []
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
for start in range(0, 9):
vectors = descriptors[start:-1]
fit_coefficients = grext.fitting.linear(vectors, target)
fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
assert errors[0] < errors[-1]
# we check that we get a zero error if we put the target in the vectors
# used for the fitting
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = grext.fitting.linear(vectors, target)
fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment