Skip to content
Snippets Groups Projects
Commit 0c054fa4 authored by Askarpour, Zahra's avatar Askarpour, Zahra
Browse files

DiffFitting

parent 7c4d15ed
No related branches found
No related tags found
No related merge requests found
Pipeline #2059 failed
...@@ -41,6 +41,58 @@ class AbstractFitting(abc.ABC): ...@@ -41,6 +41,58 @@ class AbstractFitting(abc.ABC):
result += vector*coeff result += vector*coeff
return result return result
class DiffFitting(AbstractFitting):
"""Simple least square minimization fitting."""
supported_options = {
"regularization": 0.0,
}
def set_options(self, **kwargs):
"""Set options for least square minimization"""
super().set_options(**kwargs)
if self.options["regularization"] < 0 \
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting
coefficients."""
target=target-vectors[-1]
VECTORS=[]
print("lenvector", len(vectors))
if len(vectors)>1:
for i in range(2, len(vectors)+1):
print("lenvector", len(vectors))
VECTORS.append(vectors[i-2]-vectors[-1])
print(len(VECTORS))
matrix = np.array(VECTORS).T
a = matrix.T @ matrix
b = matrix.T @ target
if self.options["regularization"] > 0.0:
a += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(a, b)
print("coefficients", coefficients)
return np.array(coefficients, dtype=np.float64)
def linear_combination(self, vectors: List[np.ndarray],
coefficients: np. ndarray) -> np.ndarray:
"""Given a set of vectors (or matrices) and the corresponding
coefficients, build their linear combination."""
result = np.zeros(vectors[0].shape, dtype=np.float64)
VECTORS_DiffFitting=[]
if len(vectors)>1:
for i in range(2,len(vectors)+1):
VECTORS_DiffFitting.append(vectors[i-2]-vectors[-1])
for coeff, vector in zip(coefficients, vectors):
result += vector*coeff
result=result+vectors[-1]
print(result.shape)
return result
class LeastSquare(AbstractFitting): class LeastSquare(AbstractFitting):
"""Simple least square minimization fitting.""" """Simple least square minimization fitting."""
...@@ -66,6 +118,7 @@ class LeastSquare(AbstractFitting): ...@@ -66,6 +118,7 @@ class LeastSquare(AbstractFitting):
if self.options["regularization"] > 0.0: if self.options["regularization"] > 0.0:
a += np.identity(len(b))*self.options["regularization"] a += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(a, b) coefficients = np.linalg.solve(a, b)
print(coefficients)
return np.array(coefficients, dtype=np.float64) return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting): class QuasiTimeReversible(AbstractFitting):
...@@ -115,3 +168,4 @@ class QuasiTimeReversible(AbstractFitting): ...@@ -115,3 +168,4 @@ class QuasiTimeReversible(AbstractFitting):
full_coefficients = np.concatenate(([-1.0], coefficients[:-1], full_coefficients = np.concatenate(([-1.0], coefficients[:-1],
2.0*coefficients[-1:], coefficients[-2::-1])) 2.0*coefficients[-1:], coefficients[-2::-1]))
return np.array(full_coefficients, dtype=np.float64) return np.array(full_coefficients, dtype=np.float64)
...@@ -4,7 +4,7 @@ from typing import Optional ...@@ -4,7 +4,7 @@ from typing import Optional
import numpy as np import numpy as np
from . import grassmann from . import grassmann
from .fitting import LeastSquare, QuasiTimeReversible from .fitting import LeastSquare, QuasiTimeReversible,DiffFitting
from .descriptors import Distance, Coulomb from .descriptors import Distance, Coulomb
from .buffer import CircularBuffer from .buffer import CircularBuffer
...@@ -18,7 +18,7 @@ class Extrapolator: ...@@ -18,7 +18,7 @@ class Extrapolator:
"verbose": False, "verbose": False,
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
"fitting": "leastsquare", "fitting": "diff",
"allow_partially_filled": True, "allow_partially_filled": True,
"store_overlap": True, "store_overlap": True,
} }
...@@ -84,6 +84,8 @@ class Extrapolator: ...@@ -84,6 +84,8 @@ class Extrapolator:
if self.options["fitting"] == "leastsquare": if self.options["fitting"] == "leastsquare":
self.fitting_calculator = LeastSquare() self.fitting_calculator = LeastSquare()
elif self.options["fitting"] == "diff":
self.fitting_calculator = DiffFitting()
elif self.options["fitting"] == "qtr": elif self.options["fitting"] == "qtr":
self.fitting_calculator = QuasiTimeReversible() self.fitting_calculator = QuasiTimeReversible()
else: else:
...@@ -135,6 +137,7 @@ class Extrapolator: ...@@ -135,6 +137,7 @@ class Extrapolator:
prev_descriptors = self.descriptors.get(n) prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
fit_coefficients = self._fit(prev_descriptors, descriptor) fit_coefficients = self._fit(prev_descriptors, descriptor)
print(fit_coefficients)
# use the fitting coefficients and the previous gammas to # use the fitting coefficients and the previous gammas to
# extrapolate a new gamma # extrapolate a new gamma
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment