Skip to content

Commit 7850ca4

Browse files
authored
Merge pull request #141 from aqlaboratory/hotfix/setup-download-all-parameters
Fix setup-openfold for download all parameters option
2 parents ab97d87 + d7d6875 commit 7850ca4

File tree

3 files changed

+47
-5
lines changed

3 files changed

+47
-5
lines changed

openfold3/entry_points/parameters.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,23 @@ class CheckpointEntry:
3636

3737

3838
OPENFOLD_MODEL_CHECKPOINT_REGISTRY = {
39-
"openfold3-p1": CheckpointEntry(
40-
file_name="of3_ft3_v1.pt", version_compatibility="<0.4"
41-
),
4239
"openfold3-p2-145k": CheckpointEntry(
4340
file_name="of3-p2-145k.pt", version_compatibility=">=0.4"
4441
),
4542
"openfold3-p2-155k": CheckpointEntry(
4643
file_name="of3-p2-155k.pt", version_compatibility=">=0.4"
4744
),
45+
"openfold3-p1": CheckpointEntry( # legacy
46+
file_name="of3_ft3_v1.pt", version_compatibility="<0.4"
47+
),
4848
}
4949

5050
DEFAULT_CHECKPOINT_NAME = "openfold3-p2-155k"
5151

52+
# These checkpoints are not supported for download and use in the current version,
53+
# but are left in the registry for record-keeping and compatibility checks.
54+
LEGACY_CHECKPOINTS = ["openfold3-p1"]
55+
5256

5357
def download_model_parameters(
5458
download_dir: Path,

openfold3/setup_openfold.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from openfold3.core.utils.s3 import download_s3_file, s3_file_matches_local
3030
from openfold3.entry_points.parameters import (
3131
DEFAULT_CHECKPOINT_NAME,
32+
LEGACY_CHECKPOINTS,
3233
OPENFOLD_MODEL_CHECKPOINT_REGISTRY,
3334
download_model_parameters,
3435
)
@@ -131,7 +132,12 @@ def setup_param_directory(
131132

132133
def download_parameters(param_dir) -> None:
133134
"""Perform the parameter download."""
134-
all_checkpoints = list(OPENFOLD_MODEL_CHECKPOINT_REGISTRY.keys())
135+
# Exclude incompatible checkpoints:
136+
all_checkpoints = [
137+
name
138+
for name in OPENFOLD_MODEL_CHECKPOINT_REGISTRY
139+
if name not in LEGACY_CHECKPOINTS
140+
]
135141

136142
logger.info("Select parameters to download:")
137143
logger.info(f"1) Download only the default checkpoint ({DEFAULT_CHECKPOINT_NAME})")

openfold3/tests/test_entry_points.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from openfold3.entry_points.parameters import (
3737
CHECKPOINT_ROOT_FILENAME,
3838
DEFAULT_CHECKPOINT_NAME,
39+
LEGACY_CHECKPOINTS,
3940
OPENFOLD_MODEL_CHECKPOINT_REGISTRY,
4041
CheckpointEntry,
4142
)
@@ -699,7 +700,7 @@ def test_remove_duplicates(self, dummy_ckpt_file, dummy_output_path):
699700

700701

701702
class TestSetupOpenFold:
702-
def test_fresh_parameter_download(self, tmp_path):
703+
def test_fresh_parameter_default_download(self, tmp_path):
703704
inputs = iter(
704705
[
705706
str(tmp_path), # Set cache directory
@@ -726,3 +727,34 @@ def test_fresh_parameter_download(self, tmp_path):
726727
tmp_path
727728
/ OPENFOLD_MODEL_CHECKPOINT_REGISTRY[DEFAULT_CHECKPOINT_NAME].file_name
728729
).exists()
730+
731+
def test_fresh_parameter_download_all(self, tmp_path):
732+
inputs = iter(
733+
[
734+
str(tmp_path), # Set cache directory
735+
"", # Use default (cache) directory for params directory
736+
"2", # download choice: all parameters
737+
"no", # skip integration tests
738+
]
739+
)
740+
741+
with (
742+
patch("builtins.input", side_effect=inputs),
743+
patch(
744+
"openfold3.setup_openfold.download_s3_file",
745+
side_effect=_fake_download_s3_file,
746+
),
747+
):
748+
setup_openfold.main()
749+
750+
# Check that the checkpoint root file exists and has the expected path
751+
assert (tmp_path / CHECKPOINT_ROOT_FILENAME).exists()
752+
assert (tmp_path / CHECKPOINT_ROOT_FILENAME).read_text() == str(tmp_path)
753+
754+
expected_checkpoints = list(
755+
set(OPENFOLD_MODEL_CHECKPOINT_REGISTRY.keys()) - set(LEGACY_CHECKPOINTS)
756+
)
757+
for ckpt_name in expected_checkpoints:
758+
assert (
759+
tmp_path / OPENFOLD_MODEL_CHECKPOINT_REGISTRY[ckpt_name].file_name
760+
).exists()

0 commit comments

Comments
 (0)