From 092a1c02c1f39f9a1f6880fd18c968b80299c4fc Mon Sep 17 00:00:00 2001
From: Michele Nottoli <michele.nottoli@gmail.com>
Date: Thu, 19 Oct 2023 16:13:27 +0200
Subject: [PATCH] Made get_tangent private.

---
 grext/main.py           | 12 +++++-------
 tests/test_grassmann.py |  2 +-
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/grext/main.py b/grext/main.py
index 7011a00..2382c02 100644
--- a/grext/main.py
+++ b/grext/main.py
@@ -46,7 +46,7 @@ class Extrapolator:
         """Get a new electronic density to be used as a guess."""
         coefficients = fitting.linear()
 
-    def get_tangent(self) -> np.ndarray:
+    def _get_tangent(self) -> np.ndarray:
         """Get the tangent point."""
         if self.tangent is not None:
             return self.tangent
@@ -66,15 +66,13 @@ class Extrapolator:
 
     def _grassmann_log(self, coeff: np.ndarray) -> np.ndarray:
         """Map from the manifold to the tangent plane."""
-        if self.tangent is not None:
-            return grassmann.log(coeff, self.tangent)
-        raise ValueError("Tangent point is not set.")
+        tangent = self._get_tangent()
+        return grassmann.log(coeff, tangent)
 
     def _grassmann_exp(self, gamma: np.ndarray) -> np.ndarray:
         """Map from the tangent plane to the manifold."""
-        if self.tangent is not None:
-            return grassmann.exp(gamma, self.tangent)
-        raise ValueError("Tangent point is not set.")
+        tangent = self._get_tangent()
+        return grassmann.exp(gamma, tangent)
 
     def _sqrt_overlap(self, overlap):
         """Compute the square root of the overlap matrix."""
diff --git a/tests/test_grassmann.py b/tests/test_grassmann.py
index e510ded..cfb4c9c 100644
--- a/tests/test_grassmann.py
+++ b/tests/test_grassmann.py
@@ -29,7 +29,7 @@ def test_grassmann_urea(datafile):
         extrapolator.load_data(coords, coeff, overlap)
 
     # check the Grassmann projections
-    c0 = extrapolator.get_tangent()
+    c0 = extrapolator._get_tangent()
     for (coeff, gamma, overlap) in zip(data["coefficients"],
             extrapolator.gammas.get(nframes), extrapolator.overlaps.get(nframes)):
 
-- 
GitLab