diff --git a/gext/main.py b/gext/main.py index 18716ea08e979a20452a8ac915f8261bc4e0eae8..71de73a5488307de15cefb3360d7dc86c399f5fd 100644 --- a/gext/main.py +++ b/gext/main.py @@ -25,6 +25,9 @@ class Extrapolator: def __init__(self, nelectrons: int, nbasis: int, natoms: int, **kwargs): + if not (type(nelectrons) == int and type(nbasis) == int and type(natoms) == int): + raise ValueError("Dimensions are not integers") + self.nelectrons = nelectrons self.nbasis = nbasis self.natoms = natoms @@ -47,6 +50,7 @@ class Extrapolator: descriptor_options = {} fitting_options = {} + # set specified options for key, value in kwargs.items(): if key in self.supported_options: self.options[key] = value @@ -57,10 +61,13 @@ class Extrapolator: else: raise ValueError(f"Unsupported option: {key}") + # set unspecified options with defaults for option, default_value in self.supported_options.items(): if not option in self.options: self.options[option] = default_value + # do some check on the options, set things and pipe options + # to submodules if self.options["nsteps"] < 1 or self.options["nsteps"] >= 100: raise ValueError("Unsupported nsteps")