feat(inference): Add support for custom residue numbering (resolves #58) #69
chodec wants to merge 1 commit into aqlaboratory:main
jnwei
left a comment
Hi @chodec,
Thank you for working on this PR. The custom numbering feature is a tricky feature to get right, but as #58 indicates, it would be very useful for researchers.
I took a first pass over this PR today, and I have a few suggestions / questions:
- First, a check of my understanding: the current implementation adds a new method for creating a custom residue ID list (get_custom_residue_ids) in the InferenceDataset class. It looks like the intention is for the new residue IDs to be read by the output writer, but I don't see where the new residue IDs are added to the batch?
- I would recommend that the custom residue_id list be created upon construction of the Chain class in inference_query_format.py rather than being generated in the InferenceDataset. This way, the logic around parsing the residue IDs can be kept in one place, rather than adding extra logic to the InferenceDataset.
  - For an example of how to use pydantic validators to create the residue_id list given an input that is either a full list or an int, you might be able to borrow the logic used in InferenceExperimentSettings to generate random seeds from a list or an initial integer seed here
  - The InferenceDataset can then be used to create batch features of the custom residue list if it is provided in a chain.
- Could you please add unit tests to test the creation of the custom residue numbering? I think it could be helpful to have two tests:
  - One test for generating the optional residue_id list in the Chain class, perhaps added here
  - One test for writing the outputs, which could be added here
- I would guess that some of the examples you used for manual validation of the numbering might be suitable test cases.
Also assigning @ljarosch to review, as he has more experience working with biotite and renumbering chains and may have additional suggestions regarding the organization.
Please let us know if you have any questions, and thank you again for your work on this issue!
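As a rough illustration of the "full list or integer start" pattern suggested above, here is a minimal sketch. It uses a plain dataclass as a stand-in for the pydantic model, and ChainSketch with its validation rules is hypothetical, not the project's Chain class. It assumes an explicit list wins over the start number, a length mismatch is an error, and residue_ids stays None when neither input is given.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChainSketch:
    """Hypothetical stand-in for Chain; not the project's class."""

    sequence: str
    residue_ids: Optional[List[str]] = None
    starting_residue_number: Optional[int] = None

    def __post_init__(self) -> None:
        n = len(self.sequence)
        if self.residue_ids is not None:
            # Assumption: an explicit list takes precedence over the start number.
            if len(self.residue_ids) != n:
                raise ValueError(
                    f"expected {n} residue ids, got {len(self.residue_ids)}"
                )
            self.residue_ids = [str(r) for r in self.residue_ids]
        elif self.starting_residue_number is not None:
            # Build a sequential list from the offset.
            start = int(self.starting_residue_number)
            self.residue_ids = [str(start + i) for i in range(n)]
        # Otherwise residue_ids stays None, so no renumbering is requested.


print(ChainSketch("AAA", starting_residue_number=100).residue_ids)
# ['100', '101', '102']
```

In a pydantic model the same branching would presumably live in a validator, so the dataset code only has to check whether residue_ids is set.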
openfold3/core/runners/writer.py
Outdated
    elif out_fmt == "npz":
        np.savez_compressed(out_file_full, **full_confidence_scores)

    # openfold3/core/runners/writer.py
Please remove this stray comment
Hello, any development on this front?
Hi @lucajovine really sorry for the late reply. Between the holidays, a work crunch, and now my university exams, I completely lost track of this. I’m just heading out for a vacation now, but I’ll jump back into it as soon as I’m back. Thanks for your patience
No worries and thanks again!
…qlaboratory#58)
Implements custom residue numbering feature with support for:
- Explicit residue_ids list (e.g., ['1A', '2', '3B'])
- Starting residue number offset
- Insertion codes in PDB format
Changes:
- Add residue_ids and starting_residue_number fields to Chain model
- Add validation and generation logic in Chain._validate_and_generate_residue_ids
- Add custom_residue_ids to batch features in InferenceDataset
- Add _renumber_atom_array method in OF3OutputWriter
- Add unit tests for Chain validation and writer renumbering
Resolves: aqlaboratory#58
eabfc13 to 43bc59b
I've finally implemented all the requested changes, sorry for the long delay :D
Hopefully it will be helpful now.
jnwei
left a comment
@chodec Thank you very much for your work on this PR and for refactoring the residue id construction into the Chain definition. I also really appreciate the additional tests and attention to detail, especially for getting the parsing of the insertion codes right.
I can help run an end-to-end test for this PR some time next week. But for now, I wanted to provide some suggestions for organization of the tests and default behavior.
@ljarosch Please take a quick look, especially at the _renumber_atom_array logic in OF3OutputWriter.
    from pydantic import model_validator, Field

    from pydantic import (
        BaseModel,
        BeforeValidator,
        DirectoryPath,
        FilePath,
        field_serializer,
        field_serializer
nit: The imports seem to have been separated into two groups from pydantic. Can these import statements be merged together?
A linting tool should be able to fix the import statements. For this project, we use ruff with the settings in pyproject.toml.
Please revert these changes to tests/__init__.py in the final submission
Please revert these changes to openfold3/__init__.py
    else:
        # Default to numbering starting from 1
        res_ids = [str(1 + i) for i in range(sequence_length)]
    data["residue_ids"] = res_ids
I think it would be preferable to leave residue_ids as None if there is no custom residue numbering / start number provided by the user.
I see that later, this custom residue_ids field is passed to the batch, which is then used to trigger the renumber_residue_ids later. If residue_ids is left blank, then default behavior for inference would remain the same (i.e. it would not perform the renumbering step).
Would you mind making a new copy of query_multimer.json with the desired residue_ids? Perhaps you could save it as examples/example_inference_inputs/query_multimer_custom_numbering.json
    )

    @staticmethod
    def _renumber_atom_array(
@ljarosch Do you have any comments about this method? I assume we have other parts of the data pipeline that would also require renumbering / reannotation of atom arrays. Would it make sense to add this function as one of the pipelines?
    )
    def test_structure_from_query(query: Query, ground_truth_file: Path):
        """Tests that the generated structure and reference molecules matches gt."""
    def test_structure_with_ref_mols_from_query(query, ground_truth_file):
The changes to test_structure_with_ref_mols_from_query seem unrelated to this PR. Could we revert these changes?
    chain_default = Chain.model_validate(base_params)
    assert chain_default.residue_ids == ['1', '2', '3']

    params_start = base_params.copy()
I would recommend creating a separate test for each test case, rather than updating the base_params. This way, if a test fails, it is easy to see which specific use case has failed.
Perhaps an organization that is something like this (pseudocode, not tested):
    # Chain and MoleculeType come from the project's own modules.
    class TestCustomResidueIDGeneration:
        def base_params(self):  # could potentially be a pytest.fixture instead
            return {
                "molecule_type": MoleculeType.PROTEIN,
                "chain_ids": ["A"],
                "sequence": "AAA",
            }

        def test_base_definition(self):
            chain_default = Chain.model_validate(self.base_params())
            assert chain_default.residue_ids == ['1', '2', '3']

        def test_residue_id_starting_number(self):
            # dict.update() returns None, so build the modified dict explicitly.
            params_start = {**self.base_params(), "starting_residue_number": "100"}
            chain_start = Chain.model_validate(params_start)
            assert chain_start.residue_ids == ['100', '101', '102']
    assert writer.failed_count == 1
    assert writer.success_count == 0

    def test_renumber_atom_array_with_insertion_codes(self):
Would it be possible to explicitly write out the atom array? Similar to the dummy_array created here? https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/tests/conftest.py#L10-L24
I find that it is much easier to check the input example when large inputs such as atom_arrays are written explicitly rather than constructed piece by piece.
* readme for how to build the docs
* consistent capitalization
* minimal training.md
* review: comment from Etowah
* specify the bucket
* Update PDB s3 path
---------
Co-authored-by: Vinay Swamy
This Pull Request implements full support for custom residue numbering in the inference output. This feature allows users to define specific residue numbering in the input JSON, resolving Issue #58.
Summary of Changes
The primary goal was to allow users to define specific residue numbering in the input JSON, rather than relying on the default numbering starting from 1. This includes support for non-sequential numbers and PDB-style insertion codes (e.g., '103A').
The implementation required coordinated changes across three key modules:
- inference_query_format.py: adds two fields to the Chain class: starting_residue_number (for simple offset) and residue_ids (for explicit lists).
- inference.py: residue_ids takes precedence over starting_residue_number. If a valid explicit list is provided, it is used; otherwise, a sequential list is generated based on the start number. The final list is stored in the data batch.
- writer.py: adds OF3OutputWriter._renumber_atom_array. This method executes after model inference but before writing the PDB/mmCIF file.
  - It uses regular expressions (re) to safely parse string IDs (e.g., separating '103A' into the integer ID 103 and the insertion code 'A').
  - It updates the AtomArray's res_id and ins_code annotations. This ensures the output structure reflects the desired numbering without affecting core model calculations.
Related Issues
Resolves: #58
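The ID parsing described in the summary can be sketched roughly as follows. This is an illustrative sketch, not the code in writer.py: parse_residue_id, per_atom_annotations, and the exact pattern are hypothetical, and it assumes an ID is an optionally negative integer followed by at most one letter. Assigning the resulting lists to an AtomArray's res_id and ins_code annotations is not shown.

```python
import re
from typing import List, Sequence, Tuple

# Assumed ID shape: optional minus sign, digits, then at most one letter.
_RESIDUE_ID = re.compile(r"(-?\d+)([A-Za-z]?)")


def parse_residue_id(residue_id: str) -> Tuple[int, str]:
    """Split an ID such as '103A' into (103, 'A'); '42' gives (42, '')."""
    match = _RESIDUE_ID.fullmatch(residue_id.strip())
    if match is None:
        raise ValueError(f"cannot parse residue id {residue_id!r}")
    number, ins_code = match.groups()
    return int(number), ins_code


def per_atom_annotations(
    residue_ids: Sequence[str], atom_to_residue: Sequence[int]
) -> Tuple[List[int], List[str]]:
    """Expand per-residue IDs to per-atom res_id and ins_code lists.

    atom_to_residue holds, for each atom, the 0-based index of its residue.
    """
    parsed = [parse_residue_id(r) for r in residue_ids]
    res_id = [parsed[i][0] for i in atom_to_residue]
    ins_code = [parsed[i][1] for i in atom_to_residue]
    return res_id, ins_code


print(parse_residue_id("103A"))  # (103, 'A')
print(per_atom_annotations(["1A", "2"], [0, 0, 1]))  # ([1, 1, 2], ['A', 'A', ''])
```

Using fullmatch rather than match means malformed IDs such as 'A103' or '10AB' raise an error instead of being silently truncated.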
Testing and Validation
Note on Testing: Due to local environment configuration issues (missing model checkpoints), an end-to-end test run could not be performed.
However, the logic has been manually validated to ensure:
- The input handling correctly prioritizes residue_ids and handles sequence length mismatch by defaulting to standard numbering (1, 2, 3...).
- The regular-expression parsing in writer.py correctly extracts insertion codes, which is critical for PDB compliance.
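To illustrate why the insertion code matters for PDB output: in fixed-column PDB ATOM records, the residue sequence number occupies columns 23-26 and the insertion code occupies column 27. The helper below is a hypothetical sketch of formatting that five-character field from an already-parsed ID; it is not part of the PR or of any library.

```python
def pdb_residue_field(res_id: int, ins_code: str = "") -> str:
    """Return the 5 characters for PDB columns 23-27 (resSeq + iCode)."""
    if not -999 <= res_id <= 9999:
        raise ValueError(f"residue number {res_id} does not fit in 4 columns")
    if len(ins_code) > 1:
        raise ValueError("insertion code must be a single character or empty")
    # Right-justify the number in 4 columns; a blank marks "no insertion code".
    return f"{res_id:>4d}{ins_code or ' '}"


print(repr(pdb_residue_field(103, "A")))  # ' 103A'
print(repr(pdb_residue_field(7)))         # '   7 '
```

If '103A' were written into the numeric field unsplit, it would not fit the four-column integer slot, which is why the ID has to be separated into a number and an insertion code first.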