diff --git a/src/atomate2/forcefields/utils.py b/src/atomate2/forcefields/utils.py index ec65b5d282..80ddee6042 100644 --- a/src/atomate2/forcefields/utils.py +++ b/src/atomate2/forcefields/utils.py @@ -421,7 +421,8 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N import matgl from matgl.ext.ase import PESCalculator - potential = matgl.load_model("M3GNet-MP-2021.2.8-PES") + path = kwargs.get("path", "M3GNet-MP-2021.2.8-PES") + potential = matgl.load_model(path) calculator = PESCalculator(potential, **kwargs) elif calculator_name == MLFF.MACE: diff --git a/tests/forcefields/test_utils.py b/tests/forcefields/test_utils.py index 9b4b6417f5..e8d3d53f9f 100644 --- a/tests/forcefields/test_utils.py +++ b/tests/forcefields/test_utils.py @@ -163,3 +163,25 @@ def test_fix_symmetry(fix_symmetry): assert symmetry_init["number"] == symmetry_final["number"] == 229 else: assert symmetry_init["number"] != symmetry_final["number"] == 99 + + +def test_m3gnet_pot(): + import matgl + from matgl.ext.ase import PESCalculator + + kwargs_calc = {"path": "M3GNet-MP-2021.2.8-DIRECT-PES", "stress_weight": 2.0} + kwargs_default = {"stress_weight": 2.0} + + m3gnet_calculator = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_calc) + + # uses "M3GNet-MP-2021.2.8-PES" per default + m3gnet_default = ase_calculator(calculator_meta="MLFF.M3GNet", **kwargs_default) + + potential = matgl.load_model("M3GNet-MP-2021.2.8-DIRECT-PES") + m3gnet_pes_calc = PESCalculator(potential=potential, stress_weight=2.0) + + assert str(m3gnet_pes_calc.potential) == str(m3gnet_calculator.potential) + # casting necessary because can't be compared + assert str(m3gnet_pes_calc.potential) != str(m3gnet_default.potential) + assert m3gnet_pes_calc.stress_weight == m3gnet_calculator.stress_weight + assert m3gnet_pes_calc.stress_weight == m3gnet_default.stress_weight