From a7068085b16e966c7e67f35470bd024418ebecfe Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Mon, 4 Mar 2024 11:46:44 +0100
Subject: [PATCH] Checked descriptors.

---
 gext/descriptors.py              | 6 +++---
 gext/fitting.py                  | 3 +--
 tests/test_descriptor_fitting.py | 7 ++++---
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/gext/descriptors.py b/gext/descriptors.py
index 2163412..838446f 100644
--- a/gext/descriptors.py
+++ b/gext/descriptors.py
@@ -3,7 +3,7 @@
 import numpy as np
 from scipy.spatial.distance import pdist
 
-class BaseFitting:
+class BaseDescriptor:
 
     supported_options = {}
 
@@ -16,7 +16,7 @@ class BaseFitting:
         if len(kwargs) > 0:
             raise ValueError("Invalid arguments given to the descriptor class.")
 
-class Distance(BaseFitting):
+class Distance(BaseDescriptor):
 
     """Distance matrix descriptors."""
 
@@ -34,7 +34,7 @@ class Coulomb(Distance):
         """Compute the Coulomb matrix as a descriptor."""
         return 1.0/super().compute(coords)
 
-class FlattenMatrix(BaseFitting):
+class FlattenMatrix(BaseDescriptor):
 
     """Use the quantity as it is, just flatten it."""
 
diff --git a/gext/fitting.py b/gext/fitting.py
index d67c7c5..0e50e2f 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -117,9 +117,8 @@ class LeastSquare(AbstractFitting):
         if self.options["regularization"] > 0.0:
             a += np.identity(len(b))*self.options["regularization"]
         coefficients = np.linalg.solve(a, b)
-        print(coefficients)
         return np.array(coefficients, dtype=np.float64)
-        
+
 class QuasiTimeReversible(AbstractFitting):
 
     """Quasi time reversible fitting scheme."""
diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py
index f0ae573..d8afede 100644
--- a/tests/test_descriptor_fitting.py
+++ b/tests/test_descriptor_fitting.py
@@ -10,7 +10,7 @@ import gext.fitting
 import gext.grassmann
 import utils
 
-SMALL = 1e-8
+SMALL = 2e-8
 THRESHOLD = 5e-2
 
 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
@@ -27,7 +27,7 @@ def test_least_square(datafile, regularization):
     # initialize an extrapolator
     extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
         nsteps=nframes, fitting_regularization=regularization,
-        fitting="leastsquare")
+        fitting="leastsquare", descriptor="distance")
 
     # load data in the extrapolator
     for (coords, coeff, overlap) in zip(data["trajectory"],
@@ -69,7 +69,8 @@ def test_quasi_time_reversible(datafile, regularization):
 
     # initialize an extrapolator
     extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
-        nsteps=nframes, fitting="qtr", fitting_regularization=regularization)
+        nsteps=nframes, fitting="qtr", fitting_regularization=regularization,
+        descriptor="distance")
 
     # load data in the extrapolator
     for (coords, coeff, overlap) in zip(data["trajectory"],
-- 
GitLab