Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ Examples of common options include:
| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. |
| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' |
| `--use_potentials` | `FLAG` | `False` | Whether to run the original Boltz-2 model using inference time potentials. |
| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. |
| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. |
| `--write_full_pae` <br> `--no_write_full_pae` | `FLAG` | `True` | Whether to save the full PAE matrix as a file. |
| `--write_full_pde` <br> `--no_write_full_pde` | `FLAG` | `True` | Whether to save the full PDE matrix as a file. |

## Output

Expand Down
10 changes: 6 additions & 4 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,15 +887,17 @@ def cli() -> None:
default=None,
)
@click.option(
"--write_full_pae",
"--write_full_pae/--no_write_full_pae",
type=bool,
is_flag=True,
default=True,
help="Whether to dump the pae into a npz file. Default is True.",
)
@click.option(
"--write_full_pde",
"--write_full_pde/--no_write_full_pde",
type=bool,
is_flag=True,
default=True,
help="Whether to dump the pde into a npz file. Default is False.",
)
@click.option(
Expand Down Expand Up @@ -1054,8 +1056,8 @@ def predict( # noqa: C901, PLR0915, PLR0912
diffusion_samples_affinity: int = 3,
max_parallel_samples: Optional[int] = None,
step_scale: Optional[float] = None,
write_full_pae: bool = False,
write_full_pde: bool = False,
write_full_pae: bool = True,
write_full_pde: bool = True,
output_format: Literal["pdb", "mmcif"] = "mmcif",
num_workers: int = 2,
override: bool = False,
Expand Down
6 changes: 4 additions & 2 deletions src/boltz/model/models/boltz2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d
pred_dict["coords"] = out["sample_atom_coords"]
if self.confidence_prediction:
# pred_dict["confidence"] = out.get("ablation_confidence", None)
pred_dict["pde"] = out["pde"]
if self.predict_args.get("write_full_pde", True):
pred_dict["pde"] = out["pde"]
pred_dict["plddt"] = out["plddt"]
pred_dict["confidence_score"] = (
4 * out["complex_plddt"]
Expand All @@ -1098,7 +1099,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d
pred_dict["complex_pde"] = out["complex_pde"]
pred_dict["complex_ipde"] = out["complex_ipde"]
if self.alpha_pae > 0:
pred_dict["pae"] = out["pae"]
if self.predict_args.get("write_full_pae", True):
pred_dict["pae"] = out["pae"]
pred_dict["ptm"] = out["ptm"]
pred_dict["iptm"] = out["iptm"]
pred_dict["ligand_iptm"] = out["ligand_iptm"]
Expand Down