Skip to content
Merged
15 changes: 12 additions & 3 deletions lambench/models/ase_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from functools import cached_property
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, Literal, Optional

import dpdata
import numpy as np
Expand All @@ -12,10 +12,12 @@
)
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
from ase.constraints import FixSymmetry
from ase.filters import FrechetCellFilter
from ase.io import write
from ase.optimize import FIRE
from dftd3.ase import DFTD3
from tqdm import tqdm

from lambench.models.basemodel import BaseLargeAtomModel
Expand Down Expand Up @@ -179,7 +181,7 @@ def evaluate(
import torch

torch.set_default_dtype(torch.float32)
return self.run_ase_dptest(self, task.test_data)
return self.run_ase_dptest(self, task.test_data, task.damping)
elif isinstance(task, CalculatorTask):
if task.task_name == "nve_md":
from lambench.tasks.calculator.nve_md.nve_md import (
Expand Down Expand Up @@ -265,7 +267,12 @@ def evaluate(
)

@staticmethod
def run_ase_dptest(model: ASEModel, test_data: Path) -> dict:
def run_ase_dptest(
model: ASEModel,
test_data: Path,
damping: Literal["d3bj", "d3zero"] | None = "d3bj",
# check all supported levels at dftd3.qcschema._available_levels
) -> dict:
# Add fparam for charge and spin multiplicity if needed
datatype = DataType(
"fparam",
Expand All @@ -277,6 +284,8 @@ def run_ase_dptest(model: ASEModel, test_data: Path) -> dict:
dpdata.LabeledSystem.register_data_type(datatype)

calc = model.calc
if damping:
calc = SumCalculator([calc, DFTD3(method="PBE", damping=damping)])

energy_err = []
energy_pre = []
Expand Down
5 changes: 3 additions & 2 deletions lambench/tasks/direct/direct_tasks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import ClassVar
from typing import ClassVar, Literal
from lambench.tasks.base_task import BaseTask
from lambench.databases.direct_predict_table import DirectPredictRecord

Expand All @@ -12,6 +12,7 @@ class DirectPredictTask(BaseTask):

record_type: ClassVar = DirectPredictRecord
task_config: ClassVar = Path(__file__).parent / "direct_tasks.yml"

damping: Literal["d3bj", "d3zero"] | None = None
def __init__(self, task_name: str, **kwargs):
super().__init__(task_name=task_name, test_data=kwargs["test_data"])
self.damping = kwargs.get("damping")
Loading