Skip to content
Snippets Groups Projects
Commit 335969a8 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Added a test for the fitting.

parent e573ca5d
Branches
Tags
No related merge requests found
"""Module that defines fitting functions.""" """Module that defines fitting functions."""
def linear(): from typing import List
import numpy as np
def linear(vectors: List[np.ndarray], target: np.ndarray):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
A = np.vstack(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(A, target, rcond=None)
return np.array(coefficients, dtype=np.float64)
def quasi_time_reversible(): def quasi_time_reversible():
"""Time reversible least square minimization fitting.""" """Time reversible least square minimization fitting."""
def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
A = np.vstack(vectors).T
return A @ coefficients
import pytest
import os
import sys
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import grext
import grext.descriptors
import grext.fitting
import grext.grassmann
import utils
SMALL = 1e-10
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_descriptor_fitting(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, nframes)
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
# we check if the error goes down with a larger data set
errors = []
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
for start in range(0, 9):
vectors = descriptors[start:-1]
fit_coefficients = grext.fitting.linear(vectors, target)
fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
assert errors[0] < errors[-1]
# we check that we get a zero error if we put the target in the vectors
# used for the fitting
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = grext.fitting.linear(vectors, target)
fitted_target = grext.fitting.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment