diff --git a/grext/fitting.py b/grext/fitting.py
index 9e4d26dc6279a2a34b40deab21abebe9afab110f..431581cc042e1ac9ec186fa7614ac6d7337a3c8b 100644
--- a/grext/fitting.py
+++ b/grext/fitting.py
@@ -13,5 +13,7 @@ def quasi_time_reversible():
     """Time reversible least square minimization fitting."""
 
 def linear_combination(vectors: List[np.ndarray], coefficients: np.ndarray) -> np.ndarray:
-    A = np.vstack(vectors).T
-    return A @ coefficients
+    result = np.zeros(vectors[0].shape, dtype=np.float64)
+    for coeff, vector in zip(coefficients, vectors):
+        result += vector*coeff
+    return result
diff --git a/grext/main.py b/grext/main.py
index e17766dbf8787fb08a96bc8a3f65f92b85a76555..f12718a1ae197500272c31e29b0613cc9fffb11c 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -17,20 +17,22 @@ class Extrapolator:
     value of 10."""
 
     def __init__(self, nelectrons: int, nbasis: int, natoms: int,
-            nsteps: int = 10):
+            nsteps: int = 10, **kwargs):
 
         self.nelectrons = nelectrons
         self.nbasis = nbasis
         self.natoms = natoms
         self.nsteps = nsteps
 
-        self.gammas = CircularBuffer(self.nsteps, (self.nelectrons, self.nbasis))
+        self.gammas = CircularBuffer(self.nsteps, (self.nelectrons//2, self.nbasis))
         self.overlaps = CircularBuffer(self.nsteps, (self.nbasis, self.nbasis))
         self.descriptors = CircularBuffer(self.nsteps,
             ((self.natoms - 1)*self.natoms//2, ))
 
         self.tangent: Optional[np.ndarray] = None
 
+        self.options = kwargs
+
     def load_data(self, coords: np.ndarray, coeff: np.ndarray,
             overlap: np.ndarray):
         """Load a new data point in the extrapolator."""
@@ -44,19 +46,24 @@ class Extrapolator:
         self.descriptors.push(self._compute_descriptor(coords))
         self.overlaps.push(overlap)
 
-    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray]) -> np.ndarray:
+    def guess(self, coords: np.ndarray, overlap: Optional[np.ndarray] = None) -> np.ndarray:
         """Get a new electronic density to be used as a guess."""
 
         prev_descriptors = self.descriptors.get(self.nsteps)
-        gammas = self.gammas.get(self.nsteps)
         descriptor = self._compute_descriptor(coords)
-        coefficients = fitting.linear(prev_descriptors, descriptor)
+        fit_coefficients = fitting.linear(prev_descriptors, descriptor)
+
+        gammas = self.gammas.get(self.nsteps)
+        gamma = fitting.linear_combination(gammas, fit_coefficients)
+
+        fit_descriptor = fitting.linear_combination(prev_descriptors, fit_coefficients)
 
-        gamma = fitting.linear_combination(gammas, coefficients)
+        if self.options["verbose"]:
+            print("error on descriptor:", np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
 
         if overlap is None:
             overlaps = self.overlaps.get(self.nsteps)
-            overlap = fitting.linear_combination(overlaps, coefficients)
+            overlap = fitting.linear_combination(overlaps, fit_coefficients)
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
         else:
             inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
@@ -74,7 +81,7 @@ class Extrapolator:
 
     def _crop_coeff(self, coeff) -> np.ndarray:
         """Crop the coefficient matrix to remove the virtual orbitals."""
-        return coeff[:, :self.nelectrons]
+        return coeff[:, :self.nelectrons//2]
 
     def _normalize(self, coeff: np.ndarray, overlap: np.ndarray) -> np.ndarray:
         """Normalize the coefficients such that C.T * C = 1, D * D = D."""
diff --git a/tests/test_extrapolation.py b/tests/test_extrapolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcb992a4d9986ac31826d77746a192203bf8054
--- /dev/null
+++ b/tests/test_extrapolation.py
@@ -0,0 +1,43 @@
+import pytest
+import os
+import sys
+import numpy as np
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+import grext
+import grext.grassmann
+import utils
+
+SMALL = 1e-10
+THRESHOLD = 1e-2
+
+@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
+def test_extrapolation(datafile):
+
+    # load test data from json file
+    data = utils.load_json(f"tests/{datafile}")
+    nelectrons = data["nelectrons"]
+    natoms = data["trajectory"].shape[1]
+    nbasis = data["overlaps"].shape[1]
+    nframes = data["trajectory"].shape[0]
+
+    # amount of data we want to use for fitting
+    n = 9
+    assert n < nframes
+
+    # initialize an extrapolator
+    extrapolator = grext.Extrapolator(nelectrons, nbasis, natoms, n)
+
+    # load data in the extrapolator up to index n - 1
+    for (coords, coeff, overlap) in zip(data["trajectory"][:n],
+            data["coefficients"][:n], data["overlaps"][:n]):
+        extrapolator.load_data(coords, coeff, overlap)
+
+    # check an extrapolation at index n
+    guessed_density = extrapolator.guess(data["trajectory"][n], data["overlaps"][n])
+    coeff = data["coefficients"][n][:, :nelectrons//2]
+    density = coeff @ coeff.T
+
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
+    assert np.linalg.norm(guessed_density - density, ord=np.inf) \
+          /np.linalg.norm(density, ord=np.inf) < THRESHOLD