Select Git revision
buffer.py 923 B
"""Module that defines a circular buffer for storing the last properties
in a molecular dynamics simulation."""
from typing import List, Tuple
import numpy as np
class CircularBuffer:
    """Circular buffer to store the last `n` matrices.

    Every slot starts as a zero array of the given ``shape`` with dtype
    ``float64``. Until ``n`` matrices have been pushed, the oldest entries
    returned by :meth:`get` are these zero placeholders.

    Args:
        n: Capacity of the buffer (number of slots). Must be at least 1.
        shape: Shape of the zero placeholder arrays.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """

    def __init__(self, n: int, shape: Tuple[int, ...]):
        # With a capacity below 1, every later `% self.n` or slot write fails.
        # Reject it up front instead.
        if n < 1:
            raise ValueError("`n` should be a positive integer")
        self.n = n
        self.shape = shape
        self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)]
        # Slot that the next push overwrites, i.e. the oldest stored entry.
        self.index = 0

    def push(self, data: np.ndarray) -> None:
        """Add a new matrix to the buffer, overwriting the oldest one.

        Args:
            data: Matrix to store. A copy is stored, so later in-place changes
                made by the caller do not affect the buffer.
        """
        self.buffer[self.index] = data.copy()
        self.index = (self.index + 1) % self.n

    def get(self, m: int) -> List[np.ndarray]:
        """Get the last `m` matrices, ordered from oldest to newest.

        The returned arrays are the buffer's own objects, not copies.
        Mutating them in place also changes the stored entries.

        Args:
            m: Number of most recent matrices to return (0 <= m <= n).

        Returns:
            A list of ``m`` arrays. The last element is the most recently
            pushed matrix.

        Raises:
            ValueError: If ``m`` is negative or greater than ``n``.
        """
        if m > self.n:
            raise ValueError("`m` should be less than or equal to the buffer `n`")
        if m < 0:
            raise ValueError("`m` should be non-negative")
        start_idx = (self.index - m) % self.n
        # Wrap every position back into [0, n). The requested window can
        # straddle the end of the underlying list.
        return [self.buffer[(start_idx + k) % self.n] for k in range(m)]