Skip to content
Snippets Groups Projects
Commit 84b04f68 authored by Askarpour, Zahra's avatar Askarpour, Zahra
Browse files

test_fitting works

parent 8b774296
No related branches found
No related tags found
No related merge requests found
Pipeline #2060 failed
import os
import sys
import numpy as np
import gext
import gext.descriptors
import gext.fitting
import gext.grassmann
from tests import utils
SMALL = 1e-8
THRESHOLD = 5e-2
regularization = 0.0
# load test data from json file
data = utils.load_json(f"tests/urea.json")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting_regularization=regularization,
fitting="diff")
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
fitting_calculator = extrapolator.fitting_calculator
# check if things are reasonable
for start in range(0, 8):
vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
error = np.linalg.norm(target - fitted_target, ord=np.inf)
assert error < THRESHOLD
# if we put the target in the vectors used for the fitting,
# check that we get an error smaller than the regularization
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
print(np.linalg.norm(target - fitted_target, ord=np.inf), max(SMALL, regularization))
assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment