Skip to content
Snippets Groups Projects
Commit 3d456f81 authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Fixed an error in the buffers, added a test.

parent a60d562e
No related branches found
No related tags found
No related merge requests found
...@@ -13,16 +13,23 @@ class CircularBuffer: ...@@ -13,16 +13,23 @@ class CircularBuffer:
self.shape = shape self.shape = shape
self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)] self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)]
self.index = 0 self.index = 0
self.count = 0
def push(self, data): def push(self, data):
"""Add a new matrix to the buffer.""" """Add a new matrix to the buffer."""
self.buffer[self.index] = data.copy() self.buffer[self.index] = data.copy()
self.index = (self.index + 1) % self.n self.index = (self.index + 1) % self.n
if self.count < self.n:
self.count += 1
def get(self, m) -> List[np.ndarray]: def get(self, m) -> List[np.ndarray]:
"""Get the last `m` matrices.""" """Get the last `m` matrices."""
if m > self.n: if m < 0:
raise ValueError("`m` should be less than or equal to the buffer `n`") raise ValueError("`m` should be larger than 0.")
elif m > self.n:
raise ValueError("`m` should be less than or equal to the buffer `n`.")
elif m > self.count:
raise ValueError("`m` is larger than the stored matrices.")
start_idx = (self.index - m) % self.n start_idx = (self.index - m) % self.n
return [self.buffer[i] for i in range(start_idx, start_idx + m)] return [self.buffer[i] for i in range(start_idx, start_idx + m)]
...@@ -31,7 +31,7 @@ class Extrapolator: ...@@ -31,7 +31,7 @@ class Extrapolator:
self.tangent: Optional[np.ndarray] = None self.tangent: Optional[np.ndarray] = None
self.options = kwargs self._set_options(**kwargs)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray,
overlap: np.ndarray): overlap: np.ndarray):
...@@ -73,6 +73,14 @@ class Extrapolator: ...@@ -73,6 +73,14 @@ class Extrapolator:
return c_guess @ c_guess.T return c_guess @ c_guess.T
def _set_options(self, **kwargs):
"""Parse additional options from the additional keyword arguments."""
self.options = {}
if "verbose" in kwargs:
self.options["verbose"] = kwargs["verbose"]
else:
self.options["verbose"] = False
def _get_tangent(self) -> np.ndarray: def _get_tangent(self) -> np.ndarray:
"""Get the tangent point.""" """Get the tangent point."""
if self.tangent is not None: if self.tangent is not None:
......
import pytest
import os
import sys
import numpy as np
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from grext.buffer import CircularBuffer
def test_buffer():
shape = (5, 5)
buffer_size = 10
nframes = 100
buffer = CircularBuffer(buffer_size, shape)
# partial load
for i in range(buffer_size // 2):
matrix = np.full(shape, i)
buffer.push(matrix)
# do some tests on a partially filled buffer
with pytest.raises(ValueError):
buffer.get(buffer_size+1)
with pytest.raises(ValueError):
buffer.get(buffer_size//2 + 1)
with pytest.raises(ValueError):
buffer.get(-1)
assert len(buffer.get(0)) == 0
# finish the loading
for i in range(buffer_size // 2, nframes):
matrix = np.full(shape, i)
buffer.push(matrix)
# do some tests on a fully filled buffer
with pytest.raises(ValueError):
buffer.get(buffer_size+1)
buffer.get(buffer_size//2 + 1)
with pytest.raises(ValueError):
buffer.get(-1)
assert len(buffer.get(0)) == 0
for m in [buffer_size, buffer_size-1, buffer_size-2]:
matrices = buffer.get(m)
for value, matrix in zip(list(range(nframes - m, nframes)), matrices):
assert matrix[0,0] == value
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment