diff --git a/grext/main.py b/grext/main.py index 7011a00c5ba60493de675c7a7c9aea51ce9b0356..2382c02bc3fc29b48454713e246ececeaa060ce1 100644 --- a/grext/main.py +++ b/grext/main.py @@ -46,7 +46,7 @@ class Extrapolator: """Get a new electronic density to be used as a guess.""" coefficients = fitting.linear() - def get_tangent(self) -> np.ndarray: + def _get_tangent(self) -> np.ndarray: """Get the tangent point.""" if self.tangent is not None: return self.tangent @@ -66,15 +66,13 @@ class Extrapolator: def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray: """Map from the manifold to the tangent plane.""" - if self.tangent is not None: - return grassmann.log(coeff, self.tangent) - raise ValueError("Tangent point is not set.") + tangent = self._get_tangent() + return grassmann.log(coeff, tangent) def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray: """Map from the tangent plane to the manifold.""" - if self.tangent is not None: - return grassmann.exp(gamma, self.tangent) - raise ValueError("Tangent point is not set.") + tangent = self._get_tangent() + return grassmann.exp(gamma, tangent) def _sqrt_overlap(self, overlap): """Compute the square root of the overlap matrix.""" diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py index e510ded1629e25f0ffb284c6bef9046571659548..cfb4c9c3b5ebb5eccb311a97923a565be8653211 100644 --- a/tests/test_grassmann.py +++ b/tests/test_grassmann.py @@ -29,7 +29,7 @@ def test_grassmann_urea(datafile): extrapolator.load_data(coords, coeff, overlap) # check the Grassmann projections - c0 = extrapolator.get_tangent() + c0 = extrapolator._get_tangent() for (coeff, gamma, overlap) in zip(data["coefficients"], extrapolator.gammas.get(nframes), extrapolator.overlaps.get(nframes)):