Skip to content
Snippets Groups Projects
Select Git revision
  • c76c9b8c784e785f88711f90b678a94ff032b801
  • main default protected
  • askarpza-main-patch-76094
  • polynomial_regression
  • optimization
  • v0.8.0
  • v0.7.1
  • v0.7.0
  • v0.6.0
  • v0.5.0
  • v0.4.1
  • v0.4.0
  • v0.3.0
  • v0.2.0
14 results

buffer.py

Blame
  • user avatar
    Michele Nottoli authored
    c76c9b8c
    History
    buffer.py 923 B
    """Module that defines a circular buffer for storing the most recent
    properties (matrices) of a molecular dynamics simulation."""
    
    from typing import List, Tuple
    import numpy as np
    
    class CircularBuffer:

        """Circular buffer to store the last `n` matrices.

        The buffer is pre-filled with `n` zero matrices of the given shape, so
        until `n` matrices have been pushed, `get` may return some of these
        zero placeholders (always ordered before the pushed matrices).

        Attributes:
            n: capacity of the buffer (number of matrices kept).
            shape: shape of the zero matrices used to initialize the buffer.
            buffer: the `n` stored matrices, in storage order (not time order).
            index: slot that the next `push` will overwrite.
        """

        def __init__(self, n: int, shape: Tuple[int, ...]):
            """Create a buffer holding `n` zero-initialized matrices of `shape`.

            Raises:
                ValueError: if `n` is smaller than 1.
            """
            # A capacity below 1 can never work: `push`/`get` take indices
            # modulo `n` (ZeroDivisionError for 0) or index an empty list.
            if n < 1:
                raise ValueError("`n` should be a positive integer")
            self.n = n
            self.shape = shape
            self.buffer = [np.zeros(shape, dtype=np.float64) for _ in range(n)]
            self.index = 0

        def push(self, data: np.ndarray) -> None:
            """Add a new matrix to the buffer, overwriting the oldest one.

            A copy of `data` is stored, so later in-place changes made by the
            caller to `data` do not affect the buffer.
            """
            self.buffer[self.index] = data.copy()
            self.index = (self.index + 1) % self.n

        def get(self, m: int) -> List[np.ndarray]:
            """Get the last `m` matrices, ordered from oldest to newest.

            The returned list holds the stored arrays themselves (not copies);
            modifying them in place modifies the buffer contents.

            Raises:
                ValueError: if `m` is negative or larger than `n`.
            """
            if m > self.n:
                raise ValueError("`m` should be less than or equal to the buffer `n`")
            if m < 0:
                raise ValueError("`m` should be non-negative")

            start_idx = (self.index - m) % self.n
            # The requested window may cross the end of the storage list, so
            # every index is wrapped modulo `n` (a plain range would overrun).
            return [self.buffer[(start_idx + i) % self.n] for i in range(m)]