From a7068085b16e966c7e67f35470bd024418ebecfe Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Mon, 4 Mar 2024 11:46:44 +0100 Subject: [PATCH] Checked descriptors. --- gext/descriptors.py | 6 +++--- gext/fitting.py | 3 +-- tests/test_descriptor_fitting.py | 7 ++++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/gext/descriptors.py b/gext/descriptors.py index 2163412..838446f 100644 --- a/gext/descriptors.py +++ b/gext/descriptors.py @@ -3,7 +3,7 @@ import numpy as np from scipy.spatial.distance import pdist -class BaseFitting: +class BaseDescriptor: supported_options = {} @@ -16,7 +16,7 @@ class BaseFitting: if len(kwargs) > 0: raise ValueError("Invalid arguments given to the descriptor class.") -class Distance(BaseFitting): +class Distance(BaseDescriptor): """Distance matrix descriptors.""" @@ -34,7 +34,7 @@ class Coulomb(Distance): """Compute the Coulomb matrix as a descriptor.""" return 1.0/super().compute(coords) -class FlattenMatrix(BaseFitting): +class FlattenMatrix(BaseDescriptor): """Use the quantity as it is, just flatten it.""" diff --git a/gext/fitting.py b/gext/fitting.py index d67c7c5..0e50e2f 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -117,9 +117,8 @@ class LeastSquare(AbstractFitting): if self.options["regularization"] > 0.0: a += np.identity(len(b))*self.options["regularization"] coefficients = np.linalg.solve(a, b) - print(coefficients) return np.array(coefficients, dtype=np.float64) - + class QuasiTimeReversible(AbstractFitting): """Quasi time reversible fitting scheme.""" diff --git a/tests/test_descriptor_fitting.py b/tests/test_descriptor_fitting.py index f0ae573..d8afede 100644 --- a/tests/test_descriptor_fitting.py +++ b/tests/test_descriptor_fitting.py @@ -10,7 +10,7 @@ import gext.fitting import gext.grassmann import utils -SMALL = 1e-8 +SMALL = 2e-8 THRESHOLD = 5e-2 @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) @@ -27,7 +27,7 @@ def test_least_square(datafile, regularization): # initialize an extrapolator extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes, fitting_regularization=regularization, - fitting="leastsquare") + fitting="leastsquare", descriptor="distance") # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], @@ -69,7 +69,8 @@ def test_quasi_time_reversible(datafile, regularization): # initialize an extrapolator extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, - nsteps=nframes, fitting="qtr", fitting_regularization=regularization) + nsteps=nframes, fitting="qtr", fitting_regularization=regularization, + descriptor="distance") # load data in the extrapolator for (coords, coeff, overlap) in zip(data["trajectory"], -- GitLab