diff --git a/lambench/metrics/utils.py b/lambench/metrics/utils.py
index 52ce8770..baaa803d 100644
--- a/lambench/metrics/utils.py
+++ b/lambench/metrics/utils.py
@@ -4,7 +4,7 @@ import lambench
 from pathlib import Path
 from collections import defaultdict
 
-from lambench.workflow.entrypoint import gather_models
+from lambench.workflow.entrypoint import gather_model_params, gather_model
 from datetime import datetime
 
 #############################
@@ -13,7 +13,7 @@
 
 
 def get_leaderboard_models(timestamp: Optional[datetime] = None) -> list:
-    models = gather_models()
+    models = [gather_model(param, "") for param in gather_model_params()]
     if timestamp is not None:
         models = [
             model for model in models if model.model_metadata.date_added <= timestamp
diff --git a/lambench/models/ase_models.py b/lambench/models/ase_models.py
index 15b1b765..a7f11480 100644
--- a/lambench/models/ase_models.py
+++ b/lambench/models/ase_models.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 import logging
-from functools import cached_property
 from pathlib import Path
 from typing import Callable, Literal, Optional
 
@@ -17,6 +16,7 @@ from ase.filters import FrechetCellFilter
 from ase.io import write
 from ase.optimize import FIRE
+from ase.calculators.emt import EMT
 from dftd3.ase import DFTD3
 from tqdm import tqdm
 
@@ -79,9 +79,10 @@ class ASEModel(BaseLargeAtomModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self._calc = None
 
-    @cached_property
-    def calc(self, head=None) -> Calculator:
+    @property
+    def calc(self) -> Calculator:
         """ASE Calculator with the model loaded."""
         calculator_dispatch = {
             "MACE": self._init_mace_calculator,
@@ -96,9 +97,19 @@ def calc(self, head=None) -> Calculator:
         }
 
         if self.model_family not in calculator_dispatch:
-            raise ValueError(f"Model {self.model_name} is not supported by ASEModel")
+            logging.warning(
+                f"Model {self.model_name} is not supported by ASEModel, using EMT as the default calculator."
+            )
+            self._calc = EMT()
+
+        else:
+            self._calc = calculator_dispatch[self.model_family]()
+        return self._calc
 
-        return calculator_dispatch[self.model_family]()
+    @calc.setter
+    def calc(self, value: Calculator):
+        logging.warning("Overriding the default calculator.")
+        self._calc = value
 
     def _init_mace_calculator(self) -> Calculator:
         from mace.calculators import mace_mp
@@ -139,7 +150,10 @@ def _init_uma_calculator(self) -> Calculator:
         from fairchem.core import FAIRChemCalculator
 
         predictor = load_predict_unit(self.model_path, device="cuda")
-        return FAIRChemCalculator(predictor, task_name="omat")
+        if self.model_domain == "molecules":
+            return FAIRChemCalculator(predictor, task_name="omol")
+        else:
+            return FAIRChemCalculator(predictor, task_name="omat")
 
     def _init_mattersim_calculator(self) -> Calculator:
         from mattersim.forcefield import MatterSimCalculator
@@ -149,10 +163,16 @@ def _init_dp_calculator(self) -> Calculator:
         from deepmd.calculator import DP
 
-        return DP(
-            model=self.model_path,
-            head="MP_traj_v024_alldata_mixu",
-        )
+        if self.supports_omol and self.model_domain == "molecules":
+            return DP(
+                model=self.model_path,
+                head="OMol25",
+            )
+        else:
+            return DP(
+                model=self.model_path,
+                head="MP_traj_v024_alldata_mixu",
+            )
 
     def _init_grace_calculator(self) -> Calculator:
         from tensorpotential.calculator import grace_fm
@@ -181,7 +201,16 @@ def evaluate(
             import torch
 
             torch.set_default_dtype(torch.float32)
-            return self.run_ase_dptest(self, task.test_data, task.dispersion_correction)
+            # Use the corresponding DFT labels for models supporting OMol25 on molecule tasks
+            if isinstance(task.test_data, dict):
+                if self.supports_omol and self.model_domain == "molecules":
+                    data_path = task.test_data["wB97"]
+                else:
+                    data_path = task.test_data["PBE"]
+            else:
+                data_path = task.test_data
+
+            return self.run_ase_dptest(self, data_path, task.dispersion_correction)
         elif isinstance(task, CalculatorTask):
             if task.task_name == "nve_md":
                 from lambench.tasks.calculator.nve_md.nve_md import (
diff --git a/lambench/models/basemodel.py b/lambench/models/basemodel.py
index 782edcf1..4c6064df 100644
--- a/lambench/models/basemodel.py
+++ b/lambench/models/basemodel.py
@@ -42,6 +42,8 @@ class BaseLargeAtomModel(BaseModel):
         show_finetune_task (bool): Flag indicating if the finetune task should be displayed or executed. Default is False.
         show_calculator_task (bool): Flag indicating if the calculator task should be displayed or executed. Default is False.
         skip_tasks (list[SkipTaskType]): List of task types that should be skipped during evaluation.
+        supports_omol (bool): Flag indicating whether the model was trained on the OMol25 dataset.
+        model_domain (Optional[str]): The model head or task_name to use for models with multiple domains. Default is None, which refers to the head used for `materials`, typically MPTrj.
 
     Methods:
         evaluate(task) -> dict[str, float]: Abstract method for evaluating the model on a given task.
            Implementations should return
@@ -58,6 +60,8 @@ class BaseLargeAtomModel(BaseModel):
     show_finetune_task: bool = False
     show_calculator_task: bool = False
     skip_tasks: list[SkipTaskType] = []
+    supports_omol: bool = False
+    model_domain: Optional[str] = None
 
     @abstractmethod
     def evaluate(self, task) -> dict[str, float]:
diff --git a/lambench/tasks/base_task.py b/lambench/tasks/base_task.py
index 08c0c33e..8d4c8020 100644
--- a/lambench/tasks/base_task.py
+++ b/lambench/tasks/base_task.py
@@ -26,7 +26,7 @@ class BaseTask(BaseModel):
     """
 
     task_name: str
-    test_data: Path
+    test_data: Path | dict[str, Path]
    task_config: ClassVar[Path]
     model_config = ConfigDict(extra="allow")
     workdir: Path = Path(tempfile.gettempdir()) / "lambench"
diff --git a/lambench/tasks/direct/direct_tasks.yml b/lambench/tasks/direct/direct_tasks.yml
index 9d940c89..eb3f44da 100644
--- a/lambench/tasks/direct/direct_tasks.yml
+++ b/lambench/tasks/direct/direct_tasks.yml
@@ -1,13 +1,9 @@
-ANI:
-  test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
 HEA25_S:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
 HEA25_bulk:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
 MoS2:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
-MD22:
-  test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
 REANN_CO2_Ni100:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
 NequIP_NC_2022:
@@ -24,6 +20,10 @@ HPt_NC_2022:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
 Ca_batteries_CM2021:
   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
+AQM:
+  test_data:
+    PBE: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000"
+    wB97: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000_OMol-wb97mv-def2tzvpd-ORCA600"
 ## DEPRECATED
 # Collision:
 #   test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
@@ -39,3 +39,7 @@ Ca_batteries_CM2021:
 #   test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
 # Torsionnet500:
 #   test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"
+# ANI:
+#   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
+# MD22:
+#   test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
diff --git a/lambench/workflow/entrypoint.py b/lambench/workflow/entrypoint.py
index 8213edca..23075d4d 100644
--- a/lambench/workflow/entrypoint.py
+++ b/lambench/workflow/entrypoint.py
@@ -16,35 +16,40 @@ MODELS = Path(lambench.__file__).parent / "models/models_config.yml"
 
 
-def gather_models(
+def gather_model_params(
     model_names: Optional[list[str]] = None,
-) -> list[BaseLargeAtomModel]:
+) -> list[dict]:
     """
-    Gather models from the models_config.yml file.
+    Gather model parameters from the models_config.yml file for the selected models.
     """
-    models = []
+    model_params = []
     with open(MODELS, "r") as f:
         model_config: list[dict] = yaml.safe_load(f)
     for model_param in model_config:
         if model_names and model_param["model_name"] not in model_names:
             continue
-        if model_param["model_type"] == "DP":
-            models.append(DPModel(**model_param))
-        elif model_param["model_type"] == "ASE":
-            models.append(ASEModel(**model_param))
-        else:
-            raise ValueError(
-                f"Model type {model_param['model_type']} is not supported."
-            )
-    return models
+        model_params.append(model_param)
+
+    return model_params
+
+
+def gather_model(model_param: dict, model_domain: str) -> BaseLargeAtomModel:
+    model_param = model_param.copy()
+    model_param["model_domain"] = model_domain
+    if model_param["model_type"] == "DP":
+        return DPModel(**model_param)
+    elif model_param["model_type"] == "ASE":
+        return ASEModel(**model_param)
+    else:
+        raise ValueError(f"Model type {model_param['model_type']} is not supported.")
 
 
 job_list: TypeAlias = list[tuple[BaseTask, BaseLargeAtomModel]]
 
 
 def gather_task_type(
-    models: list[BaseLargeAtomModel],
+    model_params: list[dict],
     task_class: Type[BaseTask],
     task_names: Optional[list[str]] = None,
 ) -> job_list:
@@ -54,18 +59,26 @@ def gather_task_type(
     tasks = []
     with open(task_class.task_config, "r") as f:
         task_configs: dict[str, dict] = yaml.safe_load(f)
-    for model in models:
-        if not hasattr(model, "_finetune") and issubclass(
+    for model_param in model_params:
+        if model_param["model_type"] != "DP" and issubclass(
             task_class, PropertyFinetuneTask
         ):
            continue  # Regular ASEModel does not support PropertyFinetuneTask
         for task_name, task_params in task_configs.items():
             if (task_names and task_name not in task_names) or task_class.__name__ in (
-                model.skip_tasks
+                model_param.get("skip_tasks", [])
             ):
                 continue
             task = task_class(task_name=task_name, **task_params)
-            if not task.exist(model.model_name):
+            if not task.exist(model_param["model_name"]):
+                # model_domain = task.domain if task.domain else ""  # in the future, tasks may declare a specific domain.
+
+                # Currently we only need to distinguish molecule from material direct tasks because of the OMol25 training set.
+                if task_name in []:  # to be added in a separate PR.
+                    model_domain = "molecules"
+                else:
+                    model_domain = "materials"
+                model = gather_model(model_param, model_domain)
                 tasks.append((task, model))
     return tasks
@@ -77,18 +90,18 @@ def gather_jobs(
 ) -> job_list:
     jobs: job_list = []
 
-    models = gather_models(model_names)
-    if not models:
+    model_params = gather_model_params(model_names)
+    if not model_params:
         logging.warning("No models found, skipping task gathering.")
         return jobs
-    logging.info(f"Found {len(models)} models, gathering tasks.")
+    logging.info(f"Found {len(model_params)} models, gathering tasks.")
 
     for task_class in BaseTask.__subclasses__():
         if task_types and task_class.__name__ not in task_types:
             continue
         jobs.extend(
             gather_task_type(
-                models=models, task_class=task_class, task_names=task_names
+                model_params=model_params, task_class=task_class, task_names=task_names
             )
         )
diff --git a/tests/tasks/calculator/test_nve_md.py b/tests/tasks/calculator/test_nve_md.py
index fe1ac59e..ceb598f1 100644
--- a/tests/tasks/calculator/test_nve_md.py
+++ b/tests/tasks/calculator/test_nve_md.py
@@ -5,7 +5,6 @@ from lambench.metrics.utils import aggregated_nve_md_results
 import pytest
 from ase import Atoms
-from ase.calculators.emt import EMT
 from lambench.models.ase_models import ASEModel
 import numpy as np
 
@@ -19,13 +18,7 @@ def setup_testing_data():
 
 
 @pytest.fixture
-def setup_calculator():
-    """Fixture to provide an ASE calculator (EMT)."""
-    return EMT()
-
-
-@pytest.fixture
-def setup_model(setup_calculator):
+def setup_model():
     """Fixture to provide an ASE model."""
     ase_models = ASEModel(
         model_family="TEST",
@@ -39,15 +32,14 @@ def setup_model(setup_calculator):
         },
         virtualenv="",
     )
-    ase_models.calc = setup_calculator
     return ase_models
 
 
-def test_nve_simulation_metrics(setup_testing_data, setup_calculator):
+def test_nve_simulation_metrics(setup_testing_data, setup_model):
     """Test NVE simulation metrics for std, and steps."""
     result = nve_simulation_single(
         setup_testing_data,
-        setup_calculator,
+        setup_model.calc,
         timestep=1.0,
         num_steps=100,
         temperature_K=300,
@@ -58,7 +50,7 @@
     assert isinstance(result["slope"], float), "Slope should be a float."
 
 
-def test_nve_simulation_crash_handling(setup_testing_data, setup_calculator):
+def test_nve_simulation_crash_handling(setup_testing_data):
     """Test crash handling by simulating an intentional crash."""
 
     atoms = setup_testing_data
diff --git a/tests/workflow/test_entrypoint.py b/tests/workflow/test_entrypoint.py
index 1d6adb9b..d6c98a83 100644
--- a/tests/workflow/test_entrypoint.py
+++ b/tests/workflow/test_entrypoint.py
@@ -1,4 +1,3 @@
-from lambench.models.dp_models import DPModel
 from lambench.tasks import PropertyFinetuneTask
 import pytest
 from lambench.workflow.entrypoint import gather_task_type
@@ -6,21 +5,21 @@
 
 
 def _create_dp_model(skip_tasks=[]):
-    return DPModel(
-        model_name="test_model",
-        model_family="test_family",
-        model_type="DP",
-        model_path="test_path",
-        virtualenv="test_env",
-        model_metadata={
+    return {
+        "model_name": "test_model",
+        "model_family": "test_family",
+        "model_type": "DP",
+        "model_path": "test_path",
+        "virtualenv": "test_env",
+        "model_metadata": {
             "pretty_name": "test",
             "date_added": "2023-10-01",
             "extra_content": "test",
             "num_parameters": 1000,
             "packages": {"torch": "2.0.0"},
         },
-        skip_tasks=skip_tasks,
-    )
+        "skip_tasks": skip_tasks,
+    }
 
 
 @pytest.fixture