From 092a1c02c1f39f9a1f6880fd18c968b80299c4fc Mon Sep 17 00:00:00 2001 From: Michele Nottoli <michele.nottoli@gmail.com> Date: Thu, 19 Oct 2023 16:13:27 +0200 Subject: [PATCH] Made get_tangent private. --- grext/main.py | 12 +++++------- tests/test_grassmann.py | 2 +- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/grext/main.py b/grext/main.py index 7011a00..2382c02 100644 --- a/grext/main.py +++ b/grext/main.py @@ -46,7 +46,7 @@ class Extrapolator: """Get a new electronic density to be used as a guess.""" coefficients = fitting.linear() - def get_tangent(self) -> np.ndarray: + def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" if self.tangent is not None: return self.tangent @@ -66,15 +66,13 @@ class Extrapolator: def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray: """Map from the manifold to the tangent plane.""" - if self.tangent is not None: - return grassmann.log(coeff, self.tangent) - raise ValueError("Tangent point is not set.") + tangent = self._get_tangent() + return grassmann.log(coeff, tangent) def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray: """Map from the tangent plane to the manifold.""" - if self.tangent is not None: - return grassmann.exp(gamma, self.tangent) - raise ValueError("Tangent point is not set.") + tangent = self._get_tangent() + return grassmann.exp(gamma, tangent) def _sqrt_overlap(self, overlap): """Compute the square root of the overlap matrix.""" diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py index e510ded..cfb4c9c 100644 --- a/tests/test_grassmann.py +++ b/tests/test_grassmann.py @@ -29,7 +29,7 @@ def test_grassmann_urea(datafile): extrapolator.load_data(coords, coeff, overlap) # check the Grassmann projections - c0 = extrapolator.get_tangent() + c0 = extrapolator._get_tangent() for (coeff, gamma, overlap) in zip(data["coefficients"], extrapolator.gammas.get(nframes), extrapolator.overlaps.get(nframes)): -- GitLab