Skip to content
Snippets Groups Projects
Commit 7dfaccc0 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Time reversible working.

parent 0f0a4592
No related branches found
No related tags found
1 merge request!6QTR
......@@ -86,4 +86,31 @@ class QuasiTimeReversible(AbstractFitting):
def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Time reversible least square minimization fitting."""
return np.zeros(0)
past_target = vectors[0]
matrix = np.array(vectors[1:]).T
q = matrix.shape[1]
if q == 1:
time_reversible_matrix = matrix
elif q%2 == 0:
time_reversible_matrix = matrix[:, :q//2] + matrix[:, :q//2-1:-1]
else:
time_reversible_matrix = matrix[:, :q//2+1] + matrix[:, :q//2-1:-1]
A = time_reversible_matrix.T @ time_reversible_matrix
b = time_reversible_matrix.T @ (target + past_target)
if self.options["regularization"] > 0.0:
A += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(A, b)
if q == 1:
full_coefficients = np.concatenate(([-1.0], coefficients))
elif q%2 == 0:
full_coefficients = np.concatenate(([-1.0], coefficients,
coefficients[::-1]))
else:
full_coefficients = np.concatenate(([-1.0], coefficients[:-1],
2.0*coefficients[-1:], coefficients[-2::-1]))
return np.array(full_coefficients, dtype=np.float64)
......@@ -77,7 +77,7 @@ class Extrapolator:
elif self.options["fitting"] == "qtr":
self.fitting_calculator = QuasiTimeReversible()
else:
raise ValueError("Unsupported descriptor")
raise ValueError("Unsupported fitting")
self.fitting_calculator.set_options(**fitting_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap):
......
......@@ -11,10 +11,11 @@ import gext.grassmann
import utils
SMALL = 1e-8
THRESHOLD = 5e-2
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.1])
def test_descriptor_fitting(datafile, regularization):
@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05])
def test_least_square(datafile, regularization):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
......@@ -25,33 +26,65 @@ def test_descriptor_fitting(datafile, regularization):
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting_regularization=regularization)
nsteps=nframes, fitting_regularization=regularization,
fitting="leastsquare")
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
# we check if the error goes down with a larger data set
errors = []
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
fitting_calculator = gext.fitting.LeastSquare()
fitting_calculator = extrapolator.fitting_calculator
# check if things are reasonable
for start in range(0, 9):
vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf))
error = np.linalg.norm(target - fitted_target, ord=np.inf)
assert error < THRESHOLD
assert errors[0] < errors[-1]
# we check that we get a zero error if we put the target in the vectors
# used for the fitting
# if we put the target in the vectors used for the fitting,
# check that we get an error smaller than the regularization
vectors = descriptors[:-1]
vectors[0] = target
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL
assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization)
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05])
def test_quasi_time_reversible(datafile, regularization):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting="qtr", fitting_regularization=regularization)
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
fitting_calculator = extrapolator.fitting_calculator
# check if things are reasonable
for start in range(0, 8):
vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
error = np.linalg.norm(target - fitted_target, ord=np.inf)
assert error < THRESHOLD
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment