Skip to content
Snippets Groups Projects
Commit 0a72fd08 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Lint.

parent e7af1061
Branches
Tags
1 merge request!4Options
Pipeline #1951 passed
"""Module which provides functionality to perform fitting.""" """Module which provides functionality to perform fitting."""
import abc
from typing import List from typing import List
import numpy as np import numpy as np
import abc
class AbstractFitting(abc.ABC): class AbstractFitting(abc.ABC):
"""Base class for fitting schemes."""
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
...@@ -35,6 +37,7 @@ class LeastSquare(AbstractFitting): ...@@ -35,6 +37,7 @@ class LeastSquare(AbstractFitting):
} }
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Set options for least square minimization"""
self.options = {} self.options = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in self.supported_options: if key in self.supported_options:
...@@ -43,7 +46,7 @@ class LeastSquare(AbstractFitting): ...@@ -43,7 +46,7 @@ class LeastSquare(AbstractFitting):
raise ValueError(f"Unsupported option: {key}") raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items(): for option, default_value in self.supported_options.items():
if not option in self.options: if option not in self.options:
self.options[option] = default_value self.options[option] = default_value
if self.options["regularization"] < 0 \ if self.options["regularization"] < 0 \
...@@ -59,8 +62,10 @@ class LeastSquare(AbstractFitting): ...@@ -59,8 +62,10 @@ class LeastSquare(AbstractFitting):
class QuasiTimeReversible(AbstractFitting): class QuasiTimeReversible(AbstractFitting):
def set_options(**kwargs): """Quasi time reversible fitting scheme. Not yet implemented."""
"""TODO"""
def compute(self): def set_options(self, **kwargs):
"""Set options for quasi time reversible fitting"""
def compute(self, vectors: List[np.ndarray], target: np.ndarray):
"""Time reversible least square minimization fitting.""" """Time reversible least square minimization fitting."""
...@@ -108,7 +108,8 @@ class Extrapolator: ...@@ -108,7 +108,8 @@ class Extrapolator:
gammas = self.gammas.get(n) gammas = self.gammas.get(n)
gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients) gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
fit_descriptor = self.fitting_calculator.linear_combination(prev_descriptors, fit_coefficients) fit_descriptor = self.fitting_calculator.linear_combination(
prev_descriptors, fit_coefficients)
if self.options["verbose"]: if self.options["verbose"]:
print("error on descriptor:", \ print("error on descriptor:", \
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment