From 3d456f81a7462348373dfcb0a18d4e7f5cb29032 Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Fri, 20 Oct 2023 16:24:23 +0200
Subject: [PATCH] Fixed an error in the buffers, added a test.

---
 grext/buffer.py      | 11 +++++++--
 grext/main.py        | 10 ++++++++-
 tests/test_buffer.py | 53 ++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_buffer.py

diff --git a/grext/buffer.py b/grext/buffer.py
index c5ce6bc..64c937b 100644
--- a/grext/buffer.py
+++ b/grext/buffer.py
@@ -13,16 +13,23 @@ class CircularBuffer:
         self.shape = shape
         self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)]
         self.index = 0
+        self.count = 0
 
     def push(self, data):
         """Add a new matrix to the buffer."""
         self.buffer[self.index] = data.copy()
         self.index = (self.index + 1) % self.n
+        if self.count < self.n:
+            self.count += 1
 
     def get(self, m) -> List[np.ndarray]:
         """Get the last `m` matrices."""
-        if m > self.n:
-            raise ValueError("`m` should be less than or equal to the buffer `n`")
+        if m < 0:
+            raise ValueError("`m` should be larger than 0.")
+        elif m > self.n:
+            raise ValueError("`m` should be less than or equal to the buffer `n`.")
+        elif m > self.count:
+            raise ValueError("`m` is larger than the stored matrices.")
 
         start_idx = (self.index - m) % self.n
         return [self.buffer[i] for i in range(start_idx, start_idx + m)]
diff --git a/grext/main.py b/grext/main.py
index f12718a..7249619 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -31,7 +31,7 @@ class Extrapolator:
 
         self.tangent: Optional[np.ndarray] = None
 
-        self.options = kwargs
+        self._set_options(**kwargs)
 
     def load_data(self, coords: np.ndarray, coeff: np.ndarray,
             overlap: np.ndarray):
@@ -73,6 +73,14 @@ class Extrapolator:
 
         return c_guess @ c_guess.T
 
+    def _set_options(self, **kwargs):
+        """Parse additional options from the additional keyword arguments."""
+        self.options = {}
+        if "verbose" in kwargs:
+            self.options["verbose"] = kwargs["verbose"]
+        else:
+            self.options["verbose"] = False
+
     def _get_tangent(self) -> np.ndarray:
         """Get the tangent point."""
         if self.tangent is not None:
diff --git a/tests/test_buffer.py b/tests/test_buffer.py
new file mode 100644
index 0000000..e2d31a9
--- /dev/null
+++ b/tests/test_buffer.py
@@ -0,0 +1,53 @@
+import pytest
+import os
+import sys
+import numpy as np
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+from grext.buffer import CircularBuffer
+
+def test_buffer():
+
+    shape = (5, 5)
+    buffer_size = 10
+    nframes = 100
+
+    buffer = CircularBuffer(buffer_size, shape)
+
+    # partial load
+    for i in range(buffer_size // 2):
+        matrix = np.full(shape, i)
+        buffer.push(matrix)
+
+    # do some tests on a partially filled buffer
+    with pytest.raises(ValueError):
+        buffer.get(buffer_size+1)
+
+    with pytest.raises(ValueError):
+        buffer.get(buffer_size//2 + 1)
+
+    with pytest.raises(ValueError):
+        buffer.get(-1)
+
+    assert len(buffer.get(0)) == 0
+
+    # finish the loading
+    for i in range(buffer_size // 2, nframes):
+        matrix = np.full(shape, i)
+        buffer.push(matrix)
+
+    # do some tests on a fully filled buffer
+    with pytest.raises(ValueError):
+        buffer.get(buffer_size+1)
+
+    buffer.get(buffer_size//2 + 1)
+
+    with pytest.raises(ValueError):
+        buffer.get(-1)
+
+    assert len(buffer.get(0)) == 0
+
+    for m in [buffer_size, buffer_size-1, buffer_size-2]:
+        matrices = buffer.get(m)
+        for value, matrix in zip(list(range(nframes - m, nframes)), matrices):
+            assert matrix[0,0] == value
-- 
GitLab