From 0c054fa47d325cdb78af827e83e90778c9044892 Mon Sep 17 00:00:00 2001
From: Zahra Askarpour <Zahra.Askarpour@mathematik.uni-stuttgart.de>
Date: Thu, 15 Feb 2024 11:36:48 +0100
Subject: [PATCH] DiffFitting

---
 gext/fitting.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++++-
 gext/main.py    |  9 +++++---
 2 files changed, 61 insertions(+), 4 deletions(-)

diff --git a/gext/fitting.py b/gext/fitting.py
index efd09f0..84bac2d 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -41,6 +41,58 @@ class AbstractFitting(abc.ABC):
             result += vector*coeff
         return result
 
+
+class DiffFitting(AbstractFitting):
+
+    """Simple least square minimization fitting."""
+
+    supported_options = {
+        "regularization": 0.0,
+    }
+
+    def set_options(self, **kwargs):
+        """Set options for least square minimization"""
+        super().set_options(**kwargs)
+
+        if self.options["regularization"] < 0 \
+                or self.options["regularization"] > 100:
+            raise ValueError("Unsupported value for regularization")
+
+    def fit(self, vectors: List[np.ndarray], target: np.ndarray):
+        """Given a set of vectors and a target return the fitting
+        coefficients."""
+        target=target-vectors[-1]
+        VECTORS=[]
+        print("lenvector", len(vectors))
+        if len(vectors)>1:
+            for i in range(2, len(vectors)+1):
+                print("lenvector", len(vectors))
+                VECTORS.append(vectors[i-2]-vectors[-1])
+                print(len(VECTORS))
+            matrix = np.array(VECTORS).T
+            a = matrix.T @ matrix
+            b = matrix.T @ target
+            if self.options["regularization"] > 0.0:
+                a += np.identity(len(b))*self.options["regularization"]
+            coefficients = np.linalg.solve(a, b)
+            print("coefficients", coefficients)
+            return np.array(coefficients, dtype=np.float64)
+
+    def linear_combination(self, vectors: List[np.ndarray],
+            coefficients: np. ndarray) -> np.ndarray:
+        """Given a set of vectors (or matrices) and the corresponding
+        coefficients, build their linear combination."""
+        result = np.zeros(vectors[0].shape, dtype=np.float64)
+        VECTORS_DiffFitting=[]
+        if len(vectors)>1:
+            for i in range(2,len(vectors)+1):
+                  VECTORS_DiffFitting.append(vectors[i-2]-vectors[-1])
+            for coeff, vector in zip(coefficients, vectors):
+                result += vector*coeff
+                result=result+vectors[-1]
+                print(result.shape)
+            return result
+
 class LeastSquare(AbstractFitting):
 
     """Simple least square minimization fitting."""
@@ -66,8 +118,9 @@ class LeastSquare(AbstractFitting):
         if self.options["regularization"] > 0.0:
             a += np.identity(len(b))*self.options["regularization"]
         coefficients = np.linalg.solve(a, b)
+        print(coefficients)
         return np.array(coefficients, dtype=np.float64)
-
+        
 class QuasiTimeReversible(AbstractFitting):
 
     """Quasi time reversible fitting scheme."""
@@ -115,3 +168,4 @@ class QuasiTimeReversible(AbstractFitting):
             full_coefficients = np.concatenate(([-1.0], coefficients[:-1],
                 2.0*coefficients[-1:], coefficients[-2::-1]))
         return np.array(full_coefficients, dtype=np.float64)
+
diff --git a/gext/main.py b/gext/main.py
index f1546f3..e09140c 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -4,7 +4,7 @@ from typing import Optional
 import numpy as np
 
 from . import grassmann
-from .fitting import LeastSquare, QuasiTimeReversible
+from .fitting import LeastSquare, QuasiTimeReversible,DiffFitting
 from .descriptors import Distance, Coulomb
 from .buffer import CircularBuffer
 
@@ -18,7 +18,7 @@ class Extrapolator:
         "verbose": False,
         "nsteps": 6,
         "descriptor": "distance",
-        "fitting": "leastsquare",
+        "fitting": "diff",
         "allow_partially_filled": True,
         "store_overlap": True,
     }
@@ -84,6 +84,8 @@ class Extrapolator:
 
         if self.options["fitting"] == "leastsquare":
             self.fitting_calculator = LeastSquare()
+        elif self.options["fitting"] == "diff":
+            self.fitting_calculator = DiffFitting()
         elif self.options["fitting"] == "qtr":
             self.fitting_calculator = QuasiTimeReversible()
         else:
@@ -135,12 +137,13 @@ class Extrapolator:
         prev_descriptors = self.descriptors.get(n)
         descriptor = self._compute_descriptor(coords)
         fit_coefficients = self._fit(prev_descriptors, descriptor)
+        print(fit_coefficients)
 
         # use the fitting coefficients and the previous gammas to
         # extrapolate a new gamma
         gammas = self.gammas.get(n)
         gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
-
+       
         if self.options["verbose"]:
             fit_descriptor = self.fitting_calculator.linear_combination(
                 prev_descriptors, fit_coefficients)
-- 
GitLab