From 84b04f6803cccf1133dea8755d4f44eec00b9348 Mon Sep 17 00:00:00 2001
From: Zahra Askarpour <Zahra.Askarpour@mathematik.uni-stuttgart.de>
Date: Thu, 15 Feb 2024 12:06:40 +0100
Subject: [PATCH] test_fitting works

---
 test_diff_fitting.py | 54 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 test_diff_fitting.py

diff --git a/test_diff_fitting.py b/test_diff_fitting.py
new file mode 100644
index 0000000..0f25591
--- /dev/null
+++ b/test_diff_fitting.py
@@ -0,0 +1,54 @@
+import os
+import sys
+import numpy as np
+
+import gext
+import gext.descriptors
+import gext.fitting
+import gext.grassmann
+from tests import utils
+
+SMALL = 1e-8
+THRESHOLD = 5e-2
+
+regularization = 0.0
+
+# load test data from json file
+data = utils.load_json(f"tests/urea.json")
+nelectrons = data["nelectrons"]
+natoms = data["trajectory"].shape[1]
+nbasis = data["overlaps"].shape[1]
+nframes = data["trajectory"].shape[0]
+
+# initialize an extrapolator
+extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
+    nsteps=nframes, fitting_regularization=regularization,
+    fitting="diff")
+
+# load data in the extrapolator
+for (coords, coeff, overlap) in zip(data["trajectory"],
+        data["coefficients"], data["overlaps"]):
+    extrapolator.load_data(coords, coeff, overlap)
+
+descriptors = extrapolator.descriptors.get(10)
+target = descriptors[-1]
+
+fitting_calculator = extrapolator.fitting_calculator
+
+# check if things are reasonable
+for start in range(0, 8):
+    vectors = descriptors[start:-1]
+    fit_coefficients = fitting_calculator.fit(vectors, target)
+    fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
+    error = np.linalg.norm(target - fitted_target, ord=np.inf)
+    assert error < THRESHOLD
+
+# if we put the target in the vectors used for the fitting,
+# check that we get an error smaller than the regularization
+vectors = descriptors[:-1]
+vectors[0] = target
+fit_coefficients = fitting_calculator.fit(vectors, target)
+fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
+
+print(np.linalg.norm(target - fitted_target, ord=np.inf), max(SMALL, regularization))
+assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization)
-- 
GitLab