diff --git a/grext/buffer.py b/grext/buffer.py index 64c937b852b4d248bfb5d4bb03e1efcd60325fa9..bf9d8cfcc09627ff0884a6067e4a473d5cc95124 100644 --- a/grext/buffer.py +++ b/grext/buffer.py @@ -32,4 +32,4 @@ class CircularBuffer: raise ValueError("`m` is larger than the stored matrices.") start_idx = (self.index - m) % self.n - return [self.buffer[i] for i in range(start_idx, start_idx + m)] + return [self.buffer[(start_idx + i) % self.n] for i in range(m)] diff --git a/tests/test_buffer.py b/tests/test_buffer.py index e2d31a9ac9911b34e01b4246a27752a7e0e926dc..659c5bd945197c2cc8f2712406d2f2823841bb4a 100644 --- a/tests/test_buffer.py +++ b/tests/test_buffer.py @@ -10,7 +10,7 @@ def test_buffer(): shape = (5, 5) buffer_size = 10 - nframes = 100 + nframes = 20 buffer = CircularBuffer(buffer_size, shape) @@ -31,6 +31,11 @@ def test_buffer(): assert len(buffer.get(0)) == 0 + for m in [buffer_size//2, buffer_size//2-1, buffer_size//2-2]: + matrices = buffer.get(m) + for value, matrix in zip(list(range(buffer_size//2 - m, buffer_size//2)), matrices): + assert matrix[0,0] == value + # finish the loading for i in range(buffer_size // 2, nframes): matrix = np.full(shape, i) @@ -51,3 +56,15 @@ def test_buffer(): matrices = buffer.get(m) for value, matrix in zip(list(range(nframes - m, nframes)), matrices): assert matrix[0,0] == value + +def test_buffer_manual(): + shape = (5, 5) + buffer = CircularBuffer(6, shape) + + for i in range(6): + matrix = np.full(shape, i) + buffer.push(matrix) + + matrices = buffer.get(6) + for matrix, value in zip(matrices, range(6)): + assert matrix[0, 0] == value