Skip to content
Snippets Groups Projects
Commit a7068085 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Checked descriptors.

parent 7da1d3f3
No related branches found
No related tags found
No related merge requests found
Pipeline #2064 failed
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import numpy as np import numpy as np
from scipy.spatial.distance import pdist from scipy.spatial.distance import pdist
class BaseFitting: class BaseDescriptor:
supported_options = {} supported_options = {}
...@@ -16,7 +16,7 @@ class BaseFitting: ...@@ -16,7 +16,7 @@ class BaseFitting:
if len(kwargs) > 0: if len(kwargs) > 0:
raise ValueError("Invalid arguments given to the descriptor class.") raise ValueError("Invalid arguments given to the descriptor class.")
class Distance(BaseFitting): class Distance(BaseDescriptor):
"""Distance matrix descriptors.""" """Distance matrix descriptors."""
...@@ -34,7 +34,7 @@ class Coulomb(Distance): ...@@ -34,7 +34,7 @@ class Coulomb(Distance):
"""Compute the Coulomb matrix as a descriptor.""" """Compute the Coulomb matrix as a descriptor."""
return 1.0/super().compute(coords) return 1.0/super().compute(coords)
class FlattenMatrix(BaseFitting): class FlattenMatrix(BaseDescriptor):
"""Use the quantity as it is, just flatten it.""" """Use the quantity as it is, just flatten it."""
......
...@@ -117,9 +117,8 @@ class LeastSquare(AbstractFitting): ...@@ -117,9 +117,8 @@ class LeastSquare(AbstractFitting):
if self.options["regularization"] > 0.0: if self.options["regularization"] > 0.0:
a += np.identity(len(b))*self.options["regularization"] a += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(a, b) coefficients = np.linalg.solve(a, b)
print(coefficients)
return np.array(coefficients, dtype=np.float64) return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting): class QuasiTimeReversible(AbstractFitting):
"""Quasi time reversible fitting scheme.""" """Quasi time reversible fitting scheme."""
......
...@@ -10,7 +10,7 @@ import gext.fitting ...@@ -10,7 +10,7 @@ import gext.fitting
import gext.grassmann import gext.grassmann
import utils import utils
SMALL = 1e-8 SMALL = 2e-8
THRESHOLD = 5e-2 THRESHOLD = 5e-2
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
...@@ -27,7 +27,7 @@ def test_least_square(datafile, regularization): ...@@ -27,7 +27,7 @@ def test_least_square(datafile, regularization):
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting_regularization=regularization, nsteps=nframes, fitting_regularization=regularization,
fitting="leastsquare") fitting="leastsquare", descriptor="distance")
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
...@@ -69,7 +69,8 @@ def test_quasi_time_reversible(datafile, regularization): ...@@ -69,7 +69,8 @@ def test_quasi_time_reversible(datafile, regularization):
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting="qtr", fitting_regularization=regularization) nsteps=nframes, fitting="qtr", fitting_regularization=regularization,
descriptor="distance")
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment