Fix missing default for deep_call of EAM model
ralf-meyer committed Aug 26, 2024
1 parent edcebde · commit a9579b1
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions mleam/models/eam_models.py
@@ -90,16 +90,8 @@ def __init__(
         self.force_method = force_method

         inputs = {"types": tf.keras.Input(shape=(None, 1), ragged=True, dtype=tf.int32)}

         if self.preprocessed_input:
-            # Determine the maximum cutoff value to pass to DeepEAMPotential.
-            # Defaults to 7.5 if 'cut_b' is missing for one or all pair_types.
-            # The 'or' in the max function is used as a fallback in case the
-            # list comprehension returns an empty list
-            if cutoff is None:
-                cutoff = max(
-                    [params.get(key, 7.5) for key in params if key[0] == "cut_b"]
-                    or [7.5]
-                )
             inputs["pair_types"] = tf.keras.Input(
                 shape=(None, None, 1), ragged=True, dtype=inputs["types"].dtype
             )
@@ -111,6 +103,15 @@ def __init__(
             )
         else:
             self.cutoff = cutoff
+        if self.cutoff is None:
+            # Determine the maximum cutoff value to pass to DeepEAMPotential.
+            # Defaults to 7.5 if 'cut_b' is missing for one or all pair_types.
+            # The 'or' in the max function is used as a fallback in case the
+            # list comprehension returns an empty list
+            self.cutoff = max(
+                [params.get(key, 7.5) for key in params if key[0] == "cut_b"]
+                or [7.5]
+            )
         inputs["positions"] = tf.keras.Input(shape=(None, 3), ragged=True)

         self._set_inputs(inputs)
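
For context, the relocated block takes the largest 'cut_b' value found across all pair types and falls back to 7.5 when none are present; moving it after the else branch is what supplies the missing default named in the commit title, since self.cutoff is now filled in for the non-preprocessed path as well. A minimal standalone sketch of the max-with-fallback pattern follows; the tuple keys and numeric values are invented for illustration and are not taken from the repository:

    # Keys are tuples whose first element names the parameter.
    params = {("cut_b", "PtPt"): 5.0, ("cut_b", "NiPt"): 4.8, ("r0", "PtPt"): 2.77}

    # The 'or [7.5]' only applies when the comprehension yields an
    # empty list, so max() never sees an empty sequence.
    cutoff = max([v for k, v in params.items() if k[0] == "cut_b"] or [7.5])
    print(cutoff)  # 5.0

    # With no 'cut_b' entries at all, the fallback value is used:
    cutoff = max([v for k, v in {}.items() if k[0] == "cut_b"] or [7.5])
    print(cutoff)  # 7.5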
@@ -696,7 +697,7 @@ class NNEmbeddingModel(SMATB):
     def get_embedding(self, type):
         return NNSqrtEmbedding(
             layers=self.params.get(("F_layers", type), [20, 20]),
-            regularization=self.reg,
+            regularization=self.hyperparams.get("regularization", 1e-5),
             name="%s-Embedding" % type,
         )
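
The same one-line change recurs in the four hunks below: each layer factory now reads the regularization strength from the hyperparams dict with an explicit 1e-5 default rather than from a self.reg attribute, presumably so models built without that hyperparameter still get a usable value. A sketch of the dict.get default pattern, using a hypothetical stand-in class rather than the actual mleam classes:

    class LayerFactory:
        def __init__(self, hyperparams=None):
            self.hyperparams = hyperparams or {}

        def regularization(self):
            # dict.get returns the default instead of raising KeyError,
            # so an unconfigured model still gets a value.
            return self.hyperparams.get("regularization", 1e-5)

    print(LayerFactory().regularization())                          # 1e-05
    print(LayerFactory({"regularization": 1e-4}).regularization())  # 0.0001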

@@ -707,7 +708,7 @@ def get_embedding(self):
         # tuple is used.
         return NNSqrtEmbedding(
             layers=self.params.get(("F_layers",), [20, 20]),
-            regularization=self.reg,
+            regularization=self.hyperparams.get("regularization", 1e-5),
             name="Common-Embedding",
         )
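
Both embedding factories resolve their layer sizes through tuple keys: a per-element key like ("F_layers", type) in the element-resolved model, and a bare one-element tuple ("F_layers",) in the common-embedding variant above. A toy illustration of that lookup convention, with keys and layer sizes invented for the example:

    params = {("F_layers", "Pt"): [32, 32]}

    print(params.get(("F_layers", "Pt"), [20, 20]))  # [32, 32] per-element override
    print(params.get(("F_layers", "Ni"), [20, 20]))  # [20, 20]  default applies
    print(params.get(("F_layers",), [20, 20]))       # [20, 20]  common one-element key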

@@ -729,7 +730,7 @@ def get_rho(self, pair_type):
         return NNRho(
             pair_type,
             layers=self.params.get(("rho_layers", pair_type), [20, 20]),
-            regularization=self.reg,
+            regularization=self.hyperparams.get("regularization", 1e-5),
             name="Rho-%s" % pair_type,
         )

@@ -739,7 +740,7 @@ def get_rho(self, pair_type):
         return NNRhoExp(
             pair_type,
             layers=self.params.get(("rho_layers", pair_type), [20, 20]),
-            regularization=self.reg,
+            regularization=self.hyperparams.get("regularization", 1e-5),
             name="Rho-%s" % pair_type,
         )

@@ -769,7 +770,7 @@ def get_rho(self, pair_type):
         return NNRho(
             pair_type,
             layers=self.params.get(("rho_layers", pair_type), [20, 20]),
-            regularization=self.reg,
+            regularization=self.hyperparams.get("regularization", 1e-5),
             name="Rho-%s" % pair_type,
         )
