Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/source/Inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,15 @@ run_openfold predict \

(323-inference-without-msas)=
#### 3.2.3 🚫 Inference Without MSAs
To run OpenFold3 without MSA features, you need to provide a "dummy" MSA file that only contains the query sequence. See the {ref}`MSA formatting section <1-precomputed-msa-files>` in our precomputed MSA documentation for how to prepare these files. Note that currently, if no MSAs are provided at all, the input sequence will not be propagated to the MSA embedder and prediction quality will be significantly reduced, so a completely MSA-free inference mode is currently discouraged if the goal is to obtain decent quality structures without aligned sequences. We are working on automatic dummy MSA generation to make it the default no-MSA behavior.
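You can run OpenFold3 without MSAs by disabling the MSA server and not supplying precomputed MSA files; a dummy MSA containing only the query sequence is then generated automatically for each protein and RNA chain. Prediction quality may be worse than for predictions that use MSAs.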

```bash
run_openfold predict \
--query_json /path/to/query.json \
--use_msa_server=False \
--output_dir /path/to/output/ \
--runner_yaml /path/to/inference.yml
```

(33-customized-inference-settings-using-runneryml)=
### 3.3 Customized Inference Settings Using `runner.yml`
Expand Down
2 changes: 1 addition & 1 deletion environments/development.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pytest
pytest==8.4.2 # pinned until pytest v9.0.1 is available, due to a unittest bug: pytest-dev/pytest#13895
ruff
awscli
aria2
6 changes: 6 additions & 0 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from openfold3.core.data.pipelines.preprocessing.template import TemplatePreprocessor
from openfold3.core.data.tools.colabfold_msa_server import (
MsaComputationSettings,
augment_main_msa_with_query_sequence,
preprocess_colabfold_msas,
)
from openfold3.core.utils.tensor_utils import dict_multimap
Expand Down Expand Up @@ -521,6 +522,11 @@ def prepare_data(self) -> None:
inference_query_set=self.inference_config.query_set,
compute_settings=self.msa_computation_settings,
)
else:
self.inference_config.query_set = augment_main_msa_with_query_sequence(
inference_query_set=self.inference_config.query_set,
compute_settings=self.msa_computation_settings,
)

if self.use_templates:
template_preprocessor = TemplatePreprocessor(
Expand Down
45 changes: 45 additions & 0 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,48 @@ def preprocess_colabfold_msas(
)

return inference_query_set


def augment_main_msa_with_query_sequence(
inference_query_set: InferenceQuerySet,
compute_settings: MsaComputationSettings,
) -> InferenceQuerySet:
output_directory = compute_settings.msa_output_directory
for query_name, query in inference_query_set.queries.items():
for chain in query.chains:
if (
chain.molecule_type == MoleculeType.PROTEIN
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hard-coding the moltypes, could you make them dependent on MSASettings.moltypes (should be in projects.of3_all_atom.config.dataset_config_components), which we use elsewhere to determine which molecule types are expected to have MSAs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, moltypes is an attribute of MSASettings, not MsaComputationSettings.
So I think it's not accessible yet at the DataModule prepare_data level.
Is there a clean way to address this?
But I think it's okay to compute and save the dummy MSA for both molecule types. Later, if moltypes doesn't include RNA or PROTEIN, this field will simply be ignored. What do you think?

or chain.molecule_type == MoleculeType.RNA
) and chain.main_msa_file_paths is None:
dummy_rep_dir = (
output_directory / "dummy" / get_sequence_hash(chain.sequence)
)
dummy_aln = ">query\n" + chain.sequence

# If save as a3m...
if compute_settings and "a3m" in compute_settings.msa_file_format:
dummy_rep_dir.mkdir(parents=True, exist_ok=True)
a3m_file = dummy_rep_dir / "colabfold_main.a3m"
with open(a3m_file, "w") as f:
f.write(dummy_aln)
chain.main_msa_file_paths = [a3m_file]
# If save as npz...
else:
npz_file = Path(f"{dummy_rep_dir}.npz")
npz_file.parent.mkdir(exist_ok=True, parents=True)
msas_preparsed = {"dummy": parse_a3m(dummy_aln).to_dict()}
np.savez_compressed(npz_file, **msas_preparsed)
chain.main_msa_file_paths = [npz_file]

chain_ids = ",".join(chain.chain_ids)
warnings.warn(
(
f"Expected MSA file for chain {chain_ids} of type "
f"{chain.molecule_type.name} in query {query_name}, "
"but no MSA files found. A dummy MSA with only "
"the query sequence will be used for this chain."
),
stacklevel=2,
)

return inference_query_set
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class MSASettings(BaseModel):
"nt_hits",
"concat_cfdb_uniref100_filtered",
"colabfold_main",
"dummy", # aln containing only query; used for MSA-free inference
]
paired_msa_order: list = ["colabfold_paired"]

Expand Down
36 changes: 36 additions & 0 deletions openfold3/tests/test_colabfold_msa.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ColabFoldQueryRunner,
ComplexGroup,
MsaComputationSettings,
augment_main_msa_with_query_sequence,
collect_colabfold_msa_data,
get_sequence_hash,
preprocess_colabfold_msas,
Expand Down Expand Up @@ -215,6 +216,41 @@ def test_msa_generation_on_multiple_queries_with_same_name(
f"Expected file {f} not found in main directory"
)

@pytest.mark.parametrize(
"msa_file_format", ["a3m", "npz"], ids=lambda fmt: f"format={fmt}"
)
def test_augment_main_msa_with_query_sequence(
self,
tmp_path,
msa_file_format,
):
sequence = "TEST"
msa_compute_settings = MsaComputationSettings(
msa_file_format=msa_file_format,
server_user_agent="test-agent",
server_url="https://dummy.url",
save_mappings=True,
msa_output_directory=tmp_path,
cleanup_msa_dir=False,
)

query = self._construct_monomer_query(sequence)
augmented = augment_main_msa_with_query_sequence(query, msa_compute_settings)
match msa_file_format:
case "a3m":
f = f"{get_sequence_hash(sequence)}/colabfold_main.a3m"
case "npz":
f = f"{get_sequence_hash(sequence)}.npz"

expected_file = tmp_path / "dummy" / f
assert expected_file.exists(), f"Expected file {f} not found in main directory"

paths_in_augmented = augmented.queries["query1"].chains[0].main_msa_file_paths
assert len(paths_in_augmented) == 1
assert expected_file == paths_in_augmented[0], (
f"Unexpected MSA path in augmented query set: {paths_in_augmented[0]}"
)

@patch(
"openfold3.core.data.tools.colabfold_msa_server.query_colabfold_msa_server",
side_effect=_construct_dummy_a3m,
Expand Down
2 changes: 2 additions & 0 deletions openfold3/tests/test_dataset_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from openfold3.core.data.pipelines.preprocessing.template import (
TemplatePreprocessorSettings,
)
from openfold3.core.data.tools.colabfold_msa_server import MsaComputationSettings
from openfold3.projects.of3_all_atom.config.dataset_configs import (
InferenceDatasetSpec,
InferenceJobConfig,
Expand Down Expand Up @@ -308,6 +309,7 @@ def test_inference_config_loading(self, tmp_path):
data_config,
use_msa_server=False,
use_templates=False,
msa_computation_settings=MsaComputationSettings(),
)

data_module.prepare_data()
Expand Down
Loading