From b5b642585e4896de2c427f2a6bd73000b6579ba9 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Wed, 13 Mar 2024 13:42:50 +0100
Subject: [PATCH] Initial version of a polynomial regression.

---
 gext/fitting.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++++
 gext/main.py    | 33 ++++++++++++++--------
 2 files changed, 95 insertions(+), 12 deletions(-)

diff --git a/gext/fitting.py b/gext/fitting.py
index d9b7997..fc5224f 100644
--- a/gext/fitting.py
+++ b/gext/fitting.py
@@ -115,3 +115,77 @@ class QuasiTimeReversible(AbstractFitting):
             full_coefficients = np.concatenate(([-1.0], coefficients[:-1],
                 2.0*coefficients[-1:], coefficients[-2::-1]))
         return np.array(full_coefficients, dtype=np.float64)
+
+class PolynomialRegression(AbstractFitting):
+
+    """Polynomial regression (optionally regularized) of gammas on descriptors."""
+
+    supported_options = {
+        "regularization": 1e-3,
+        "minorder": 1,
+        "maxorder": 1,
+        "outerprod": False}
+
+    def set_options(self, **kwargs):
+        """Set the polynomial regression options and reset the fitted state.
+        Raises ValueError unless 0 <= regularization <= 100 and
+        0 <= minorder <= maxorder <= 3."""
+        super().set_options(**kwargs)
+
+        if not 0 <= self.options["regularization"] <= 100:
+            raise ValueError("Unsupported value for regularization")
+        if not 0 <= self.options["minorder"] <= 3:
+            raise ValueError("minorder must be >= 0 and <= 3")
+        if self.options["minorder"] > self.options["maxorder"]:
+            raise ValueError("minorder must be <= maxorder")
+        if self.options["maxorder"] > 3:
+            raise ValueError("maxorder must be <= 3")
+
+        self.matrix = np.zeros(0, dtype=np.float64)
+        self.gamma_shape = None
+
+    def get_orders(self, descriptors):
+        """Return the features: the powers of `descriptors` from minorder to
+        maxorder joined by np.hstack. Order 0 is a block of ones, order 2 the
+        flattened outer product of each row if `outerprod`, else the square."""
+        low, high = self.options["minorder"], self.options["maxorder"]
+        orders = []
+        if low <= 0 <= high:
+            ones_shape = (descriptors.shape[0], 1) if descriptors.ndim > 1 else 1
+            orders.append(np.ones(ones_shape))
+        if low <= 1 <= high:
+            orders.append(descriptors)
+        if low <= 2 <= high:
+            if self.options["outerprod"]:
+                orders.append(np.array([np.outer(d, d).flatten() for d in descriptors]))
+            else:
+                orders.append(descriptors**2)
+        if low <= 3 <= high:
+            orders.append(descriptors**3)
+        return np.hstack(orders) if len(orders) > 1 else orders[0]
+
+    def fit(self, descriptor_list: List[np.ndarray], gamma_list: List[np.ndarray]):
+        """Build the regression matrix from the descriptors and the gammas by
+        solving the normal equations (V^T V + r^2 I) M = V^T G, where V is
+        get_orders(descriptors), G holds one flattened gamma per row and r
+        is the `regularization` option."""
+        # refreshed at every fit so that a stale shape is never reused
+        self.gamma_shape = gamma_list[0].shape
+        descriptors = np.array(descriptor_list, dtype=np.float64)
+        # -1 flattens gammas of any rank, not only two dimensional ones
+        gammas = np.reshape(gamma_list, (len(gamma_list), -1))
+
+        vander = self.get_orders(descriptors)
+        a = vander.T @ vander
+        b = vander.T @ gammas
+        if self.options["regularization"] > 0.0:
+            a += np.identity(len(b))*self.options["regularization"]**2
+
+        self.matrix = np.linalg.solve(a, b)
+
+    def apply(self, descriptor):
+        """Return the gamma predicted for `descriptor`; needs a previous fit."""
+        if self.gamma_shape is None:
+            raise ValueError("fit must be called before apply")
+        gamma = self.get_orders(np.array([descriptor])) @ self.matrix
+        return np.reshape(gamma, self.gamma_shape)
diff --git a/gext/main.py b/gext/main.py
index f1546f3..a9b2b43 100644
--- a/gext/main.py
+++ b/gext/main.py
@@ -4,7 +4,7 @@ from typing import Optional
 import numpy as np
 
 from . import grassmann
-from .fitting import LeastSquare, QuasiTimeReversible
+from .fitting import LeastSquare, QuasiTimeReversible, PolynomialRegression
 from .descriptors import Distance, Coulomb
 from .buffer import CircularBuffer
 
@@ -86,6 +86,8 @@ class Extrapolator:
             self.fitting_calculator = LeastSquare()
         elif self.options["fitting"] == "qtr":
             self.fitting_calculator = QuasiTimeReversible()
+        elif self.options["fitting"] == "polynomialregression":
+            self.fitting_calculator = PolynomialRegression()
         else:
             raise ValueError("Unsupported fitting")
         self.fitting_calculator.set_options(**fitting_options)
@@ -131,25 +133,32 @@ class Extrapolator:
         if overlap is None and not self.options["store_overlap"]:
             raise ValueError("Guessing without overlap requires `store_overlap` true.")
 
-        # use the descriptors to find the fitting coefficients
+        # get the required quantities
         prev_descriptors = self.descriptors.get(n)
+        gammas = self.gammas.get(n)
         descriptor = self._compute_descriptor(coords)
-        fit_coefficients = self._fit(prev_descriptors, descriptor)
 
-        # use the fitting coefficients and the previous gammas to
-        # extrapolate a new gamma
-        gammas = self.gammas.get(n)
-        gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
+        # use the descriptors to find the fitting coefficients
+        if self.options["fitting"] == "polynomialregression":
+            self.fitting_calculator.fit(prev_descriptors, gammas)
+            gamma = self.fitting_calculator.apply(descriptor)
+        else:
+            fit_coefficients = self._fit(prev_descriptors, descriptor)
+            # use the fitting coefficients and the previous gammas to
+            # extrapolate a new gamma
+            gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
 
-        if self.options["verbose"]:
-            fit_descriptor = self.fitting_calculator.linear_combination(
-                prev_descriptors, fit_coefficients)
-            print("error on descriptor:", \
-                np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
+            if self.options["verbose"]:
+                fit_descriptor = self.fitting_calculator.linear_combination(
+                    prev_descriptors, fit_coefficients)
+                print("error on descriptor:", \
+                    np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
 
         # if the overlap is not given, use the coefficients to fit
         # a new overlap
         if overlap is None:
+            if self.options["fitting"] == "polynomialregression":
+                raise ValueError("The option polynomial regression needs the overlap")
             overlaps = self.overlaps.get(n)
             overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
-- 
GitLab