diff --git a/gext/fitting.py b/gext/fitting.py index 3c6adb732b43070cdefc09125b50eb720c99ff1f..4d058aaa349e2e3d764bafc4b9c64b435afcdbd3 100644 --- a/gext/fitting.py +++ b/gext/fitting.py @@ -1,11 +1,13 @@ """Module which provides functionality to perform fitting.""" +import abc from typing import List import numpy as np -import abc class AbstractFitting(abc.ABC): + """Base class for fitting schemes.""" + def __init__(self, **kwargs): self.set_options(**kwargs) @@ -35,6 +37,7 @@ class LeastSquare(AbstractFitting): } def set_options(self, **kwargs): + """Set options for least square minimization""" self.options = {} for key, value in kwargs.items(): if key in self.supported_options: @@ -43,7 +46,7 @@ class LeastSquare(AbstractFitting): raise ValueError(f"Unsupported option: {key}") for option, default_value in self.supported_options.items(): - if not option in self.options: + if option not in self.options: self.options[option] = default_value if self.options["regularization"] < 0 \ @@ -59,8 +62,10 @@ class LeastSquare(AbstractFitting): class QuasiTimeReversible(AbstractFitting): - def set_options(**kwargs): - """TODO""" + """Quasi time reversible fitting scheme. Not yet implemented.""" - def compute(self): + def set_options(self, **kwargs): + """Set options for quasi time reversible fitting""" + + def compute(self, vectors: List[np.ndarray], target: np.ndarray): """Time reversible least square minimization fitting.""" diff --git a/gext/main.py b/gext/main.py index b4f4288fa7e88e7d8bd9da334787b6d6a2d03525..7b4e3e569ab6049c4f03622f8deb3b5ba8e2191c 100644 --- a/gext/main.py +++ b/gext/main.py @@ -108,7 +108,8 @@ class Extrapolator: gammas = self.gammas.get(n) gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients) - fit_descriptor = self.fitting_calculator.linear_combination(prev_descriptors, fit_coefficients) + fit_descriptor = self.fitting_calculator.linear_combination( + prev_descriptors, fit_coefficients) if self.options["verbose"]: print("error on descriptor:", \