Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lambench/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import lambench
from pathlib import Path
from collections import defaultdict
from lambench.workflow.entrypoint import gather_models
from lambench.workflow.entrypoint import gather_model_params, gather_model
from datetime import datetime

#############################
Expand All @@ -13,7 +13,7 @@


def get_leaderboard_models(timestamp: Optional[datetime] = None) -> list:
models = gather_models()
models = [gather_model(param, "") for param in gather_model_params()]
if timestamp is not None:
models = [
model for model in models if model.model_metadata.date_added <= timestamp
Expand Down
51 changes: 40 additions & 11 deletions lambench/models/ase_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations
import logging
from functools import cached_property
from pathlib import Path
from typing import Callable, Literal, Optional

Expand All @@ -17,6 +16,7 @@
from ase.filters import FrechetCellFilter
from ase.io import write
from ase.optimize import FIRE
from ase.calculators.emt import EMT
from dftd3.ase import DFTD3
from tqdm import tqdm

Expand Down Expand Up @@ -79,9 +79,10 @@ class ASEModel(BaseLargeAtomModel):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._calc = None

@cached_property
def calc(self, head=None) -> Calculator:
@property
def calc(self) -> Calculator:
"""ASE Calculator with the model loaded."""
calculator_dispatch = {
"MACE": self._init_mace_calculator,
Expand All @@ -96,9 +97,19 @@ def calc(self, head=None) -> Calculator:
}

if self.model_family not in calculator_dispatch:
raise ValueError(f"Model {self.model_name} is not supported by ASEModel")
logging.warning(
f"Model {self.model_name} is not supported by ASEModel, using EMT as default calculator."
)
self._calc = EMT()

else:
self._calc = calculator_dispatch[self.model_family]()
return self._calc

return calculator_dispatch[self.model_family]()
@calc.setter
def calc(self, value: Calculator):
logging.warning("Overriding the default calculator.")
self._calc = value

def _init_mace_calculator(self) -> Calculator:
from mace.calculators import mace_mp
Expand Down Expand Up @@ -139,7 +150,10 @@ def _init_uma_calculator(self) -> Calculator:
from fairchem.core import FAIRChemCalculator

predictor = load_predict_unit(self.model_path, device="cuda")
return FAIRChemCalculator(predictor, task_name="omat")
if self.model_domain == "molecules":
return FAIRChemCalculator(predictor, task_name="omol")
else:
return FAIRChemCalculator(predictor, task_name="omat")

def _init_mattersim_calculator(self) -> Calculator:
from mattersim.forcefield import MatterSimCalculator
Expand All @@ -149,10 +163,16 @@ def _init_mattersim_calculator(self) -> Calculator:
def _init_dp_calculator(self) -> Calculator:
from deepmd.calculator import DP

return DP(
model=self.model_path,
head="MP_traj_v024_alldata_mixu",
)
if self.supports_omol and self.model_domain == "molecules":
return DP(
model=self.model_path,
head="OMol25",
)
else:
return DP(
model=self.model_path,
head="MP_traj_v024_alldata_mixu",
)

def _init_grace_calculator(self) -> Calculator:
from tensorpotential.calculator import grace_fm
Expand Down Expand Up @@ -181,7 +201,16 @@ def evaluate(
import torch

torch.set_default_dtype(torch.float32)
return self.run_ase_dptest(self, task.test_data, task.dispersion_correction)
# Use corresponding DFT label for models supporting OMol25 on Molecules tasks
if isinstance(task.test_data, dict):
if self.supports_omol and self.model_domain == "molecules":
data_path = task.test_data["wB97"]
else:
data_path = task.test_data["PBE"]
else:
data_path = task.test_data

return self.run_ase_dptest(self, data_path, task.dispersion_correction)
elif isinstance(task, CalculatorTask):
if task.task_name == "nve_md":
from lambench.tasks.calculator.nve_md.nve_md import (
Expand Down
4 changes: 4 additions & 0 deletions lambench/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class BaseLargeAtomModel(BaseModel):
show_finetune_task (bool): Flag indicating if the finetune task should be displayed or executed. Default is False.
show_calculator_task (bool): Flag indicating if the calculator task should be displayed or executed. Default is False.
skip_tasks (list[SkipTaskType]): List of task types that should be skipped during evaluation.
supports_omol (bool): Flag indicating whether the model was trained with OMol25. Default is False.
model_domain (Optional[str]): The model head or task_name to use for models that cover multiple domains. Default is None, which refers to the head used for `materials` (often MPTrj).
Methods:
evaluate(task) -> dict[str, float]:
Abstract method for evaluating the model on a given task. Implementations should return
Expand All @@ -58,6 +60,8 @@ class BaseLargeAtomModel(BaseModel):
show_finetune_task: bool = False
show_calculator_task: bool = False
skip_tasks: list[SkipTaskType] = []
supports_omol: bool = False
model_domain: Optional[str] = None

@abstractmethod
def evaluate(self, task) -> dict[str, float]:
Expand Down
2 changes: 1 addition & 1 deletion lambench/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class BaseTask(BaseModel):
"""

task_name: str
test_data: Path
test_data: Path | dict[str, Path]
task_config: ClassVar[Path]
model_config = ConfigDict(extra="allow")
workdir: Path = Path(tempfile.gettempdir()) / "lambench"
Expand Down
12 changes: 8 additions & 4 deletions lambench/tasks/direct/direct_tasks.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
ANI:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
HEA25_S:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
HEA25_bulk:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
MoS2:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
MD22:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
REANN_CO2_Ni100:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
NequIP_NC_2022:
Expand All @@ -24,6 +20,10 @@ HPt_NC_2022:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
Ca_batteries_CM2021:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
AQM:
test_data:
PBE: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000"
wB97: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000_OMol-wb97mv-def2tzvpd-ORCA600"
## DEPRECATED
# Collision:
# test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
Expand All @@ -39,3 +39,7 @@ Ca_batteries_CM2021:
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
# Torsionnet500:
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"
# ANI:
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
# MD22:
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
57 changes: 35 additions & 22 deletions lambench/workflow/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,40 @@
MODELS = Path(lambench.__file__).parent / "models/models_config.yml"


def gather_models(
def gather_model_params(
model_names: Optional[list[str]] = None,
) -> list[BaseLargeAtomModel]:
) -> list[dict]:
"""
Gather models from the models_config.yml file.
Gather model parameters from the models_config.yml file for selected models.
"""

models = []
model_params = []
with open(MODELS, "r") as f:
model_config: list[dict] = yaml.safe_load(f)
for model_param in model_config:
if model_names and model_param["model_name"] not in model_names:
continue
if model_param["model_type"] == "DP":
models.append(DPModel(**model_param))
elif model_param["model_type"] == "ASE":
models.append(ASEModel(**model_param))
else:
raise ValueError(
f"Model type {model_param['model_type']} is not supported."
)
return models
model_params.append(model_param)

return model_params


def gather_model(model_param: dict, model_domain: str) -> BaseLargeAtomModel:
    """
    Build a single model instance from one models_config.yml entry.

    The entry is shallow-copied before `model_domain` is injected, so the
    caller's dict is left untouched and can be reused for another domain.

    Args:
        model_param: Keyword arguments for the model; its `model_type` value
            selects the class ("DP" -> DPModel, "ASE" -> ASEModel).
        model_domain: Value stored under the `model_domain` key of the copied
            parameters before the model is constructed.

    Returns:
        The constructed DPModel or ASEModel.

    Raises:
        ValueError: If `model_type` is neither "DP" nor "ASE".
    """
    # Merge into a fresh dict rather than mutating the caller's entry.
    params = {**model_param, "model_domain": model_domain}
    model_type = params["model_type"]

    if model_type == "DP":
        return DPModel(**params)
    if model_type == "ASE":
        return ASEModel(**params)
    raise ValueError(f"Model type {model_type} is not supported.")


job_list: TypeAlias = list[tuple[BaseTask, BaseLargeAtomModel]]


def gather_task_type(
models: list[BaseLargeAtomModel],
model_params: list[dict],
task_class: Type[BaseTask],
task_names: Optional[list[str]] = None,
) -> job_list:
Expand All @@ -54,18 +59,26 @@ def gather_task_type(
tasks = []
with open(task_class.task_config, "r") as f:
task_configs: dict[str, dict] = yaml.safe_load(f)
for model in models:
if not hasattr(model, "_finetune") and issubclass(
for model_param in model_params:
if model_param["model_type"] != "DP" and issubclass(
task_class, PropertyFinetuneTask
):
continue # Regular ASEModel does not support PropertyFinetuneTask
for task_name, task_params in task_configs.items():
if (task_names and task_name not in task_names) or task_class.__name__ in (
model.skip_tasks
model_param.get("skip_tasks", [])
):
continue
task = task_class(task_name=task_name, **task_params)
if not task.exist(model.model_name):
if not task.exist(model_param["model_name"]):
# model_domain = task.domain if task.domain else "" # in the future we may have tasks with a specific domain.

# Currently we only need to distinguish between direct tasks for molecules and for materials, because of the OMol25 training set.
if task_name in []: # to be added in a separate PR.
model_domain = "molecules"
else:
model_domain = "materials"
model = gather_model(model_param, model_domain)
tasks.append((task, model))
return tasks

Expand All @@ -77,18 +90,18 @@ def gather_jobs(
) -> job_list:
jobs: job_list = []

models = gather_models(model_names)
if not models:
model_params = gather_model_params(model_names)
if not model_params:
logging.warning("No models found, skipping task gathering.")
return jobs

logging.info(f"Found {len(models)} models, gathering tasks.")
logging.info(f"Found {len(model_params)} models, gathering tasks.")
for task_class in BaseTask.__subclasses__():
if task_types and task_class.__name__ not in task_types:
continue
jobs.extend(
gather_task_type(
models=models, task_class=task_class, task_names=task_names
model_params=model_params, task_class=task_class, task_names=task_names
)
)

Expand Down
16 changes: 4 additions & 12 deletions tests/tasks/calculator/test_nve_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from lambench.metrics.utils import aggregated_nve_md_results
import pytest
from ase import Atoms
from ase.calculators.emt import EMT
from lambench.models.ase_models import ASEModel
import numpy as np

Expand All @@ -19,13 +18,7 @@ def setup_testing_data():


@pytest.fixture
def setup_calculator():
"""Fixture to provide an ASE calculator (EMT)."""
return EMT()


@pytest.fixture
def setup_model(setup_calculator):
def setup_model():
"""Fixture to provide an ASE model."""
ase_models = ASEModel(
model_family="TEST",
Expand All @@ -39,15 +32,14 @@ def setup_model(setup_calculator):
},
virtualenv="",
)
ase_models.calc = setup_calculator
return ase_models


def test_nve_simulation_metrics(setup_testing_data, setup_calculator):
def test_nve_simulation_metrics(setup_testing_data, setup_model):
"""Test NVE simulation metrics for std, and steps."""
result = nve_simulation_single(
setup_testing_data,
setup_calculator,
setup_model.calc,
timestep=1.0,
num_steps=100,
temperature_K=300,
Expand All @@ -58,7 +50,7 @@ def test_nve_simulation_metrics(setup_testing_data, setup_calculator):
assert isinstance(result["slope"], float), "Slope should be a float."


def test_nve_simulation_crash_handling(setup_testing_data, setup_calculator):
def test_nve_simulation_crash_handling(setup_testing_data):
"""Test crash handling by simulating an intentional crash."""
atoms = setup_testing_data

Expand Down
19 changes: 9 additions & 10 deletions tests/workflow/test_entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
from lambench.models.dp_models import DPModel
from lambench.tasks import PropertyFinetuneTask
import pytest
from lambench.workflow.entrypoint import gather_task_type
from unittest.mock import MagicMock


def _create_dp_model(skip_tasks=[]):
return DPModel(
model_name="test_model",
model_family="test_family",
model_type="DP",
model_path="test_path",
virtualenv="test_env",
model_metadata={
return {
"model_name": "test_model",
"model_family": "test_family",
"model_type": "DP",
"model_path": "test_path",
"virtualenv": "test_env",
"model_metadata": {
"pretty_name": "test",
"date_added": "2023-10-01",
"extra_content": "test",
"num_parameters": 1000,
"packages": {"torch": "2.0.0"},
},
skip_tasks=skip_tasks,
)
"skip_tasks": skip_tasks,
}


@pytest.fixture
Expand Down