Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from openfold3.core.data.pipelines.preprocessing.template import TemplatePreprocessor
from openfold3.core.data.tools.colabfold_msa_server import (
MsaComputationSettings,
augment_main_msa_with_query_sequence,
preprocess_colabfold_msas,
)
from openfold3.core.utils.tensor_utils import dict_multimap
Expand Down Expand Up @@ -521,6 +522,11 @@ def prepare_data(self) -> None:
inference_query_set=self.inference_config.query_set,
compute_settings=self.msa_computation_settings,
)
else:
self.inference_config.query_set = augment_main_msa_with_query_sequence(
inference_query_set=self.inference_config.query_set,
compute_settings=self.msa_computation_settings,
)

if self.use_templates:
template_preprocessor = TemplatePreprocessor(
Expand Down
36 changes: 36 additions & 0 deletions openfold3/core/data/tools/colabfold_msa_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,39 @@ def preprocess_colabfold_msas(
)

return inference_query_set


def augment_main_msa_with_query_sequence(
inference_query_set: InferenceQuerySet,
compute_settings: MsaComputationSettings,
) -> InferenceQuerySet:
output_directory = compute_settings.msa_output_directory
for query_name, query in inference_query_set.queries.items():
for chain in query.chains:
if (
chain.molecule_type == MoleculeType.PROTEIN
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of hard-coding the moltypes, could you make them dependent on MSASettings.moltypes (should be in projects.of3_all_atom.config.dataset_config_components), which we use elsewhere to determine which molecule types are expected to have MSAs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, moltypes is attribute of MSASettings not MSAComputationSettings.
So I think it's not accessible yet in the DataModule prepare_data level.
Is there any way to cleanly address this?
But I think it's okay to compute & save for both. Later if moltypes don't include RNA or PROTEIN, this field will be ignored. How do you think?

or chain.molecule_type == MoleculeType.RNA
) and chain.main_msa_file_paths is None:
dummy_msa_file_path = (
output_directory
/ "dummy"
/ f"{get_sequence_hash(chain.sequence)}.npz"
)
dummy_msa_file_path.parent.mkdir(exist_ok=True, parents=True)
dummy_msa = ">query\n" + chain.sequence
msas_preparsed = {"dummy": parse_a3m(dummy_msa).to_dict()}
np.savez_compressed(dummy_msa_file_path, **msas_preparsed)
chain.main_msa_file_paths = [dummy_msa_file_path]

chain_ids = ",".join(chain.chain_ids)
warnings.warn(
(
f"Expected MSA file for chain {chain_ids} of "
f"type {chain.molecule_type.name} in query "
f"{query_name}, but no MSA files found. Query sequence "
"will be used as dummy MSA for this chain."
),
stacklevel=2,
)

return inference_query_set
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class MSASettings(BaseModel):
"nt_hits",
"concat_cfdb_uniref100_filtered",
"colabfold_main",
"dummy",
]
paired_msa_order: list = ["colabfold_paired"]

Expand Down
Loading