Skip to content

Commit 8dc1cc9

Browse files
committed
Add a friendly deprecation warning for pae_enabled model preset
1 parent e346bbb commit 8dc1cc9

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

openfold3/projects/of3_all_atom/project_entry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,32 @@
1818
from dataclasses import dataclass
1919

2020
from ml_collections import ConfigDict
21-
from pydantic import BaseModel
21+
from pydantic import BaseModel, field_validator
2222
from pydantic import ConfigDict as PydanticConfigDict
2323

2424
from openfold3.core.config.config_utils import load_yaml
2525
from openfold3.projects.of3_all_atom.config.model_config import model_config
2626
from openfold3.projects.of3_all_atom.runner import OpenFold3AllAtom
2727

28+
logger = logging.getLogger(__name__)
29+
2830

2931
class ModelUpdate(BaseModel):
3032
model_config = PydanticConfigDict(extra="forbid")
3133
presets: list[str] = []
3234
custom: dict = {}
3335

36+
@field_validator("presets", mode="before")
37+
def warn_if_pae_enabled_included(cls, v):
38+
if "pae_enabled" in v:
39+
logger.warning(
40+
"The `pae_enabled` model preset is deprecated and will be removed."
41+
" from the ModelUpdate.presets list. Please remove `pae_enabled`"
42+
" from your model update section"
43+
)
44+
v.remove("pae_enabled")
45+
return v
46+
3447

3548
@dataclass
3649
class OF3ProjectEntry:

openfold3/tests/test_entry_points.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,14 @@ def test_low_mem_model_config_preset(self, tmp_path, dummy_ckpt_file):
316316
# test existing setting in experiment runner is not overwritten
317317
assert not model_cfg.settings.memory.eval.use_lma
318318

319+
def test_model_update_with_pae_enabled_triggers_warning(self):
320+
with patch(
321+
"openfold3.projects.of3_all_atom.project_entry.logger"
322+
) as mock_logger:
323+
ModelUpdate.model_validate({"presets": ["predict", "pae_enabled"]})
324+
warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list]
325+
assert any("model preset is deprecated" in msg for msg in warning_messages)
326+
319327

320328
class DummyWandbExperiment:
321329
def __init__(self, directory):

0 commit comments

Comments
 (0)