diff --git a/grext/fitting.py b/grext/fitting.py index 431581cc042e1ac9ec186fa7614ac6d7337a3c8b..6dfd0de00fa993a62e46e68a9cb5d9619302bed3 100644 --- a/grext/fitting.py +++ b/grext/fitting.py @@ -5,14 +5,16 @@ import numpy as np def linear(vectors: List[np.ndarray], target: np.ndarray): """Simple least square minimization fitting.""" - A = np.vstack(vectors).T - coefficients, _, _, _ = np.linalg.lstsq(A, target, rcond=None) + matrix = np.vstack(vectors).T + coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None) return np.array(coefficients, dtype=np.float64) def quasi_time_reversible(): """Time reversible least square minimization fitting.""" def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray: + """Given a set of vectors (or matrices) and the corresponding + coefficients, build their linear combination.""" result = np.zeros(vectors[0].shape, dtype=np.float64) for coeff, vector in zip(coefficients, vectors): result += vector*coeff