diff --git a/grext/buffer.py b/grext/buffer.py index c5ce6bc9611ee63490b53011c84c82820714809e..64c937b852b4d248bfb5d4bb03e1efcd60325fa9 100644 --- a/grext/buffer.py +++ b/grext/buffer.py @@ -13,16 +13,23 @@ class CircularBuffer: self.shape = shape self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)] self.index = 0 + self.count = 0 def push(self, data): """Add a new matrix to the buffer.""" self.buffer[self.index] = data.copy() self.index = (self.index + 1) % self.n + if self.count < self.n: + self.count += 1 def get(self, m) -> List[np.ndarray]: """Get the last `m` matrices.""" - if m > self.n: - raise ValueError("`m` should be less than or equal to the buffer `n`") + if m < 0: + raise ValueError("`m` should be larger than 0.") + elif m > self.n: + raise ValueError("`m` should be less than or equal to the buffer `n`.") + elif m > self.count: + raise ValueError("`m` is larger than the stored matrices.") start_idx = (self.index - m) % self.n return [self.buffer[i] for i in range(start_idx, start_idx + m)] diff --git a/grext/main.py b/grext/main.py index f12718a1ae197500272c31e29b0613cc9fffb11c..7249619887fff8ac1846b22fa17860ffcc181488 100644 --- a/grext/main.py +++ b/grext/main.py @@ -31,7 +31,7 @@ class Extrapolator: self.tangent: Optional[np.ndarray] = None - self.options = kwargs + self._set_options(**kwargs) def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap: np.ndarray): @@ -73,6 +73,14 @@ class Extrapolator: return c_guess @ c_guess.T + def _set_options(self, **kwargs): + """Parse additional options from the additional keyword arguments.""" + self.options = {} + if "verbose" in kwargs: + self.options["verbose"] = kwargs["verbose"] + else: + self.options["verbose"] = False + def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" if self.tangent is not None: diff --git a/tests/test_buffer.py b/tests/test_buffer.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d31a9ac9911b34e01b4246a27752a7e0e926dc --- /dev/null +++ b/tests/test_buffer.py @@ -0,0 +1,53 @@ +import pytest +import os +import sys +import numpy as np + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +from grext.buffer import CircularBuffer + +def test_buffer(): + + shape = (5, 5) + buffer_size = 10 + nframes = 100 + + buffer = CircularBuffer(buffer_size, shape) + + # partial load + for i in range(buffer_size // 2): + matrix = np.full(shape, i) + buffer.push(matrix) + + # do some tests on a partially filled buffer + with pytest.raises(ValueError): + buffer.get(buffer_size+1) + + with pytest.raises(ValueError): + buffer.get(buffer_size//2 + 1) + + with pytest.raises(ValueError): + buffer.get(-1) + + assert len(buffer.get(0)) == 0 + + # finish the loading + for i in range(buffer_size // 2, nframes): + matrix = np.full(shape, i) + buffer.push(matrix) + + # do some tests on a fully filled buffer + with pytest.raises(ValueError): + buffer.get(buffer_size+1) + + buffer.get(buffer_size//2 + 1) + + with pytest.raises(ValueError): + buffer.get(-1) + + assert len(buffer.get(0)) == 0 + + for m in [buffer_size, buffer_size-1, buffer_size-2]: + matrices = buffer.get(m) + for value, matrix in zip(list(range(nframes - m, nframes)), matrices): + assert matrix[0,0] == value