Skip to content
Snippets Groups Projects
Commit 6d7c535a authored by Michele Nottoli's avatar Michele Nottoli
Browse files

Merge branch 'update' into 'main'

QTR

See merge request !6
parents 4e708ccc 9d524710
No related branches found
No related tags found
1 merge request!6QTR
Pipeline #1964 passed
...@@ -43,8 +43,16 @@ The behavior can be finely controlled by passing additional keyword arguments to ...@@ -43,8 +43,16 @@ The behavior can be finely controlled by passing additional keyword arguments to
This is an up to date list of available keyword options: This is an up to date list of available keyword options:
- `nsteps`: integer parameter, number of steps to be used in the extrapolation. - `nsteps`: integer, default 6, number of steps to be used in the extrapolation.
**Note:** Calling `guess` before loading `nsteps` data points will cause a `ValueError`. - `verbose`: boolean, default False, if True print additional information.
- `descriptor`: string, default "distance", possible options are "distance" and "coulomb".
- `fitting`: string, default "leastsquare", possible options are "leastsquare" and "qtr".
- `allow_partially_filled`: bool, default True. If True allow to do a guess before `nsteps` data points have been loaded, if False asking for a guess before `nsteps` data points will cause a `ValueError`.
- `store_overlap`: bool, default True. Store the overlaps for later usage in calling guess without passing the current overlap. It can be disabled for performance, but calling guess will require passing the overlap.
Some options can be piped to the fitting modules.
- `fitting_regularization`: float, default 0.0. Controls the regularization for both the "leastsquare" and "qtr" fitting schemes.
## Acknowledgments ## Acknowledgments
......
...@@ -7,6 +7,8 @@ class Distance: ...@@ -7,6 +7,8 @@ class Distance:
"""Distance matrix descriptors.""" """Distance matrix descriptors."""
supported_options = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
...@@ -24,6 +26,8 @@ class Coulomb(Distance): ...@@ -24,6 +26,8 @@ class Coulomb(Distance):
"""Coulomb matrix descriptors.""" """Coulomb matrix descriptors."""
supported_options = {}
def compute(self, coords: np.ndarray) -> np.ndarray: def compute(self, coords: np.ndarray) -> np.ndarray:
"""Compute the Coulomb matrix as a descriptor.""" """Compute the Coulomb matrix as a descriptor."""
return 1.0/super().compute(coords) return 1.0/super().compute(coords)
...@@ -8,16 +8,29 @@ class AbstractFitting(abc.ABC): ...@@ -8,16 +8,29 @@ class AbstractFitting(abc.ABC):
"""Base class for fitting schemes.""" """Base class for fitting schemes."""
supported_options = {}
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.set_options(**kwargs) self.set_options(**kwargs)
@abc.abstractmethod @abc.abstractmethod
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Base method for setting options.""" """Base method for setting options."""
self.options = {}
for key, value in kwargs.items():
if key in self.supported_options:
self.options[key] = value
else:
raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items():
if option not in self.options:
self.options[option] = default_value
@abc.abstractmethod @abc.abstractmethod
def compute(self, vectors: List[np.ndarray], target:np.ndarray): def fit(self, vectors: List[np.ndarray], target:np.ndarray):
"""Base method for computing new fitting coefficients.""" """Base method for computing new fitting coefficients."""
return np.zeros(0)
def linear_combination(self, vectors: List[np.ndarray], def linear_combination(self, vectors: List[np.ndarray],
coefficients: np. ndarray) -> np.ndarray: coefficients: np. ndarray) -> np.ndarray:
...@@ -38,34 +51,66 @@ class LeastSquare(AbstractFitting): ...@@ -38,34 +51,66 @@ class LeastSquare(AbstractFitting):
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Set options for least square minimization""" """Set options for least square minimization"""
self.options = {} super().set_options(**kwargs)
for key, value in kwargs.items():
if key in self.supported_options:
self.options[key] = value
else:
raise ValueError(f"Unsupported option: {key}")
for option, default_value in self.supported_options.items():
if option not in self.options:
self.options[option] = default_value
if self.options["regularization"] < 0 \ if self.options["regularization"] < 0 \
or self.options["regularization"] > 100: or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization") raise ValueError("Unsupported value for regularization")
def compute(self, vectors: List[np.ndarray], target: np.ndarray): def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Given a set of vectors and a target return the fitting """Given a set of vectors and a target return the fitting
coefficients.""" coefficients."""
matrix = np.vstack(vectors).T matrix = np.array(vectors).T
coefficients, _, _, _ = np.linalg.lstsq(matrix, target, rcond=None) a = matrix.T @ matrix
b = matrix.T @ target
if self.options["regularization"] > 0.0:
a += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(a, b)
return np.array(coefficients, dtype=np.float64) return np.array(coefficients, dtype=np.float64)
class QuasiTimeReversible(AbstractFitting): class QuasiTimeReversible(AbstractFitting):
"""Quasi time reversible fitting scheme. Not yet implemented.""" """Quasi time reversible fitting scheme. Not yet implemented."""
supported_options = {
"regularization": 0.0,
}
def set_options(self, **kwargs): def set_options(self, **kwargs):
"""Set options for quasi time reversible fitting""" """Set options for quasi time reversible fitting"""
super().set_options(**kwargs)
if self.options["regularization"] < 0 \
or self.options["regularization"] > 100:
raise ValueError("Unsupported value for regularization")
def compute(self, vectors: List[np.ndarray], target: np.ndarray): def fit(self, vectors: List[np.ndarray], target: np.ndarray):
"""Time reversible least square minimization fitting.""" """Time reversible least square minimization fitting."""
past_target = vectors[0]
matrix = np.array(vectors[1:]).T
q = matrix.shape[1]
if q == 1:
time_reversible_matrix = matrix
elif q%2 == 0:
time_reversible_matrix = matrix[:, :q//2] + matrix[:, :q//2-1:-1]
else:
time_reversible_matrix = matrix[:, :q//2+1] + matrix[:, :q//2-1:-1]
a = time_reversible_matrix.T @ time_reversible_matrix
b = time_reversible_matrix.T @ (target + past_target)
if self.options["regularization"] > 0.0:
a += np.identity(len(b))*self.options["regularization"]
coefficients = np.linalg.solve(a, b)
if q == 1:
full_coefficients = np.concatenate(([-1.0], coefficients))
elif q%2 == 0:
full_coefficients = np.concatenate(([-1.0], coefficients,
coefficients[::-1]))
else:
full_coefficients = np.concatenate(([-1.0], coefficients[:-1],
2.0*coefficients[-1:], coefficients[-2::-1]))
return np.array(full_coefficients, dtype=np.float64)
...@@ -12,29 +12,32 @@ class Extrapolator: ...@@ -12,29 +12,32 @@ class Extrapolator:
"""Class for performing Grassmann extrapolations. On initialization """Class for performing Grassmann extrapolations. On initialization
it requires the number of electrons, the number of basis functions it requires the number of electrons, the number of basis functions
and the number of atoms of the molecule. The number of previous and the number of atoms of the molecule."""
steps used by the extrapolator is an optional argument with default
value of 6."""
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): supported_options = {
self.supported_options = {
"verbose": False, "verbose": False,
"nsteps": 6, "nsteps": 6,
"descriptor": "distance", "descriptor": "distance",
"fitting": "leastsquare", "fitting": "leastsquare",
"allow_partially_filled": True, "allow_partially_filled": True,
"store_overlap": True,
} }
def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs):
if not (type(nelectrons) == int and type(nbasis) == int and type(natoms) == int):
raise ValueError("Dimensions are not integers")
self.nelectrons = nelectrons self.nelectrons = nelectrons
self.nbasis = nbasis self.nbasis = nbasis
self.natoms = natoms self.natoms = natoms
self.set_options(**kwargs) self.set_options(**kwargs)
self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis)) self.gammas = CircularBuffer(self.options["nsteps"], (self.nelectrons//2, self.nbasis))
self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis))
self.descriptors = CircularBuffer(self.options["nsteps"], self.descriptors = CircularBuffer(self.options["nsteps"],
((self.natoms - 1)*self.natoms//2, )) ((self.natoms - 1)*self.natoms//2, ))
if self.options["store_overlap"]:
self.overlaps = CircularBuffer(self.options["nsteps"], (self.nbasis, self.nbasis))
self.tangent: Optional[np.ndarray] = None self.tangent: Optional[np.ndarray] = None
...@@ -47,6 +50,7 @@ class Extrapolator: ...@@ -47,6 +50,7 @@ class Extrapolator:
descriptor_options = {} descriptor_options = {}
fitting_options = {} fitting_options = {}
# set specified options
for key, value in kwargs.items(): for key, value in kwargs.items():
if key in self.supported_options: if key in self.supported_options:
self.options[key] = value self.options[key] = value
...@@ -57,11 +61,14 @@ class Extrapolator: ...@@ -57,11 +61,14 @@ class Extrapolator:
else: else:
raise ValueError(f"Unsupported option: {key}") raise ValueError(f"Unsupported option: {key}")
# set unspecified options with defaults
for option, default_value in self.supported_options.items(): for option, default_value in self.supported_options.items():
if not option in self.options: if not option in self.options:
self.options[option] = default_value self.options[option] = default_value
if self.options["nsteps"] <= 1 or self.options["nsteps"] >= 100: # do some check on the options, set things and pipe options
# to submodules
if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100:
raise ValueError("Unsupported nsteps") raise ValueError("Unsupported nsteps")
if self.options["descriptor"] == "distance": if self.options["descriptor"] == "distance":
...@@ -77,20 +84,26 @@ class Extrapolator: ...@@ -77,20 +84,26 @@ class Extrapolator:
elif self.options["fitting"] == "qtr": elif self.options["fitting"] == "qtr":
self.fitting_calculator = QuasiTimeReversible() self.fitting_calculator = QuasiTimeReversible()
else: else:
raise ValueError("Unsupported descriptor") raise ValueError("Unsupported fitting")
self.fitting_calculator.set_options(**fitting_options) self.fitting_calculator.set_options(**fitting_options)
def load_data(self, coords: np.ndarray, coeff: np.ndarray, def load_data(self, coords: np.ndarray, coeff: np.ndarray, overlap):
overlap: np.ndarray):
"""Load a new data point in the extrapolator.""" """Load a new data point in the extrapolator."""
# Crop the coefficient matrix up to the number of electron
# pairs, then apply S^1/2
coeff = self._crop_coeff(coeff) coeff = self._crop_coeff(coeff)
coeff = self._normalize(coeff, overlap) coeff = self._normalize(coeff, overlap)
# if it is the first time we load data, set the tangent point
if self.tangent is None: if self.tangent is None:
self._set_tangent(coeff) self._set_tangent(coeff)
# push the new data to the corresponding vectors
self.gammas.push(self._grassmann_log(coeff)) self.gammas.push(self._grassmann_log(coeff))
self.descriptors.push(self._compute_descriptor(coords)) self.descriptors.push(self._compute_descriptor(coords))
if self.options["store_overlap"]:
self.overlaps.push(overlap) self.overlaps.push(overlap)
def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray: def guess(self, coords: np.ndarray, overlap=None) -> np.ndarray:
...@@ -101,25 +114,38 @@ class Extrapolator: ...@@ -101,25 +114,38 @@ class Extrapolator:
def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray: def guess_coefficients(self, coords: np.ndarray, overlap=None) -> np.ndarray:
"""Get a new coefficient matrix to be used as a guess.""" """Get a new coefficient matrix to be used as a guess."""
# check if we have enough data points to perform an extrapolation
count = self.descriptors.count
if self.options["allow_partially_filled"]: if self.options["allow_partially_filled"]:
n = min(self.options["nsteps"], self.descriptors.count) if count == 0:
raise ValueError("Not enough data loaded in the extrapolator")
n = min(self.options["nsteps"], count)
else: else:
n = self.options["nsteps"] n = self.options["nsteps"]
if count < n:
raise ValueError("Not enough data loaded in the extrapolator")
if overlap is None and not self.options["store_overlap"]:
raise ValueError("Guessing without overlap requires `store_overlap` true.")
# use the descriptors to find the fitting coefficients
prev_descriptors = self.descriptors.get(n) prev_descriptors = self.descriptors.get(n)
descriptor = self._compute_descriptor(coords) descriptor = self._compute_descriptor(coords)
fit_coefficients = self.fitting_calculator.compute(prev_descriptors, descriptor) fit_coefficients = self._fit(prev_descriptors, descriptor)
# use the fitting coefficients and the previous gammas to
# extrapolate a new gamma
gammas = self.gammas.get(n) gammas = self.gammas.get(n)
gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients) gamma = self.fitting_calculator.linear_combination(gammas, fit_coefficients)
if self.options["verbose"]:
fit_descriptor = self.fitting_calculator.linear_combination( fit_descriptor = self.fitting_calculator.linear_combination(
prev_descriptors, fit_coefficients) prev_descriptors, fit_coefficients)
if self.options["verbose"]:
print("error on descriptor:", \ print("error on descriptor:", \
np.linalg.norm(fit_descriptor - descriptor, ord=np.inf)) np.linalg.norm(fit_descriptor - descriptor, ord=np.inf))
# if the overlap is not given, use the coefficients to fit
# a new overlap
if overlap is None: if overlap is None:
overlaps = self.overlaps.get(n) overlaps = self.overlaps.get(n)
overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients) overlap = self.fitting_calculator.linear_combination(overlaps, fit_coefficients)
...@@ -127,10 +153,9 @@ class Extrapolator: ...@@ -127,10 +153,9 @@ class Extrapolator:
else: else:
inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap) inverse_sqrt_overlap = self._inverse_sqrt_overlap(overlap)
# use the overlap and gamma to find a new set of coefficients
c_guess = self._grassmann_exp(gamma) c_guess = self._grassmann_exp(gamma)
c_guess = inverse_sqrt_overlap @ c_guess return inverse_sqrt_overlap @ c_guess
return c_guess
def _get_tangent(self) -> np.ndarray: def _get_tangent(self) -> np.ndarray:
"""Get the tangent point.""" """Get the tangent point."""
...@@ -176,3 +201,8 @@ class Extrapolator: ...@@ -176,3 +201,8 @@ class Extrapolator:
def _compute_descriptor(self, coords) -> np.ndarray: def _compute_descriptor(self, coords) -> np.ndarray:
"""Given a set of coordinates compute the corresponding descriptor.""" """Given a set of coordinates compute the corresponding descriptor."""
return self.descriptor_calculator.compute(coords) return self.descriptor_calculator.compute(coords)
def _fit(self, prev_descriptors, descriptor) -> np.ndarray:
"""Fit the current descriptor using previous descriptors and
the specified fitting scheme."""
return self.fitting_calculator.fit(prev_descriptors, descriptor)
...@@ -10,10 +10,12 @@ import gext.fitting ...@@ -10,10 +10,12 @@ import gext.fitting
import gext.grassmann import gext.grassmann
import utils import utils
SMALL = 1e-10 SMALL = 1e-8
THRESHOLD = 5e-2
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_descriptor_fitting(datafile): @pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05])
def test_least_square(datafile, regularization):
# load test data from json file # load test data from json file
data = utils.load_json(f"tests/{datafile}") data = utils.load_json(f"tests/{datafile}")
...@@ -23,33 +25,107 @@ def test_descriptor_fitting(datafile): ...@@ -23,33 +25,107 @@ def test_descriptor_fitting(datafile):
nframes = data["trajectory"].shape[0] nframes = data["trajectory"].shape[0]
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=nframes) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting_regularization=regularization,
fitting="leastsquare")
# load data in the extrapolator # load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"], for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]): data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap) extrapolator.load_data(coords, coeff, overlap)
# we check if the error goes down with a larger data set
errors = []
descriptors = extrapolator.descriptors.get(10) descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1] target = descriptors[-1]
fitting_calculator = gext.fitting.LeastSquare() fitting_calculator = extrapolator.fitting_calculator
# check if things are reasonable
for start in range(0, 9): for start in range(0, 9):
vectors = descriptors[start:-1] vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.compute(vectors, target) fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
errors.append(np.linalg.norm(target - fitted_target, ord=np.inf)) error = np.linalg.norm(target - fitted_target, ord=np.inf)
assert error < THRESHOLD
assert errors[0] < errors[-1] # if we put the target in the vectors used for the fitting,
# check that we get an error smaller than the regularization
# we check that we get a zero error if we put the target in the vectors
# used for the fitting
vectors = descriptors[:-1] vectors = descriptors[:-1]
vectors[0] = target vectors[0] = target
fit_coefficients = fitting_calculator.compute(vectors, target) fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients) fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
assert np.linalg.norm(target - fitted_target, ord=np.inf) < SMALL assert np.linalg.norm(target - fitted_target, ord=np.inf) < max(SMALL, regularization)
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
@pytest.mark.parametrize("regularization", [0.0, 0.01, 0.05])
def test_quasi_time_reversible(datafile, regularization):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting="qtr", fitting_regularization=regularization)
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
descriptors = extrapolator.descriptors.get(10)
target = descriptors[-1]
fitting_calculator = extrapolator.fitting_calculator
# check if things are reasonable
for start in range(0, 8):
vectors = descriptors[start:-1]
fit_coefficients = fitting_calculator.fit(vectors, target)
fitted_target = fitting_calculator.linear_combination(vectors, fit_coefficients)
error = np.linalg.norm(target - fitted_target, ord=np.inf)
assert error < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_time_reversibility(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=nframes, fitting="qtr")
# load data in the extrapolator
for (coords, coeff, overlap) in zip(data["trajectory"],
data["coefficients"], data["overlaps"]):
extrapolator.load_data(coords, coeff, overlap)
descriptors = extrapolator.descriptors.get(10)
# we symmetrize the future and past targets (remember that it is
# quasi time reversible, not exactly time reversible)
target = descriptors[0] + descriptors[-1]
descriptors[0] = target
descriptors[-1] = target
fitting_calculator = extrapolator.fitting_calculator
# fit the future target
fit_coefficients = fitting_calculator.fit(descriptors[:-1], descriptors[-1])
fitted_target = fitting_calculator.linear_combination(descriptors[:-1], fit_coefficients)
# fit the past target
reversed_descriptors = list(reversed(descriptors))
fit_coefficients = fitting_calculator.fit(reversed_descriptors[:-1], reversed_descriptors[-1])
fitted_target_reverse = fitting_calculator.linear_combination(reversed_descriptors[:-1], fit_coefficients)
# check the time reversibility
assert np.linalg.norm(fitted_target - fitted_target_reverse, ord=np.inf) < SMALL
...@@ -9,10 +9,13 @@ import gext.grassmann ...@@ -9,10 +9,13 @@ import gext.grassmann
import utils import utils
SMALL = 1e-10 SMALL = 1e-10
THRESHOLD = 1e-2 THRESHOLD = 5e-2
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"]) @pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_extrapolation(datafile): @pytest.mark.parametrize("fitting", ["leastsquare", "qtr"])
@pytest.mark.parametrize("regularization", [0.0, 1e-6, 5e-6])
@pytest.mark.parametrize("descriptor", ["distance", "coulomb"])
def test_extrapolation(datafile, fitting, regularization, descriptor):
# load test data from json file # load test data from json file
data = utils.load_json(f"tests/{datafile}") data = utils.load_json(f"tests/{datafile}")
...@@ -26,7 +29,9 @@ def test_extrapolation(datafile): ...@@ -26,7 +29,9 @@ def test_extrapolation(datafile):
assert n < nframes assert n < nframes
# initialize an extrapolator # initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n) extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms,
nsteps=n, fitting=fitting, fitting_regularization=regularization,
descriptor=descriptor)
# load data in the extrapolator up to index n - 1 # load data in the extrapolator up to index n - 1
for (coords, coeff, overlap) in zip(data["trajectory"][:n], for (coords, coeff, overlap) in zip(data["trajectory"][:n],
...@@ -106,3 +111,37 @@ def test_coefficient_extrapolation(datafile): ...@@ -106,3 +111,37 @@ def test_coefficient_extrapolation(datafile):
assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD assert np.linalg.norm(guessed_density - density, ord=np.inf) < THRESHOLD
assert np.linalg.norm(guessed_density - density, ord=np.inf) \ assert np.linalg.norm(guessed_density - density, ord=np.inf) \
/np.linalg.norm(density, ord=np.inf) < THRESHOLD /np.linalg.norm(density, ord=np.inf) < THRESHOLD
@pytest.mark.parametrize("datafile", ["urea.json", "glucose.json"])
def test_errors(datafile):
# load test data from json file
data = utils.load_json(f"tests/{datafile}")
nelectrons = data["nelectrons"]
natoms = data["trajectory"].shape[1]
nbasis = data["overlaps"].shape[1]
nframes = data["trajectory"].shape[0]
# amount of data we want to use for fitting
n = 9
assert n < nframes
# initialize an extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n)
with pytest.raises(ValueError):
extrapolator.guess(data["trajectory"][0])
# initialize a new extrapolator
extrapolator = gext.Extrapolator(nelectrons, nbasis, natoms, nsteps=n,
allow_partially_filled=False)
# load data in the extrapolator up to index m - 1
m = 4
for (coords, coeff, overlap) in zip(data["trajectory"][:m],
data["coefficients"][:m], data["overlaps"][:m]):
extrapolator.load_data(coords, coeff, overlap)
# check an extrapolation at index m
with pytest.raises(ValueError):
extrapolator.guess(data["trajectory"][m])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment