diff --git a/docs/prediction.md b/docs/prediction.md index b33d2d0e2..997f5cc3d 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -168,8 +168,8 @@ Examples of common options include: | `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. | | `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' | | `--use_potentials` | `FLAG` | `False` | Whether to run the original Boltz-2 model using inference time potentials. | -| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. | -| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. | +| `--write_full_pae`
`--no_write_full_pae` | `FLAG` | `True` | Whether to save the full PAE matrix as a file. | +| `--write_full_pde`
`--no_write_full_pde` | `FLAG` | `True` | Whether to save the full PDE matrix as a file. | ## Output diff --git a/src/boltz/main.py b/src/boltz/main.py index 4a3750fec..648bb7ecc 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -887,15 +887,17 @@ def cli() -> None: default=None, ) @click.option( - "--write_full_pae", + "--write_full_pae/--no_write_full_pae", type=bool, is_flag=True, + default=True, help="Whether to dump the pae into a npz file. Default is True.", ) @click.option( - "--write_full_pde", + "--write_full_pde/--no_write_full_pde", type=bool, is_flag=True, + default=True, help="Whether to dump the pde into a npz file. Default is False.", ) @click.option( @@ -1054,8 +1056,8 @@ def predict( # noqa: C901, PLR0915, PLR0912 diffusion_samples_affinity: int = 3, max_parallel_samples: Optional[int] = None, step_scale: Optional[float] = None, - write_full_pae: bool = False, - write_full_pde: bool = False, + write_full_pae: bool = True, + write_full_pde: bool = True, output_format: Literal["pdb", "mmcif"] = "mmcif", num_workers: int = 2, override: bool = False, diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py index d42f3400c..33bc3e71f 100644 --- a/src/boltz/model/models/boltz2.py +++ b/src/boltz/model/models/boltz2.py @@ -1080,7 +1080,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d pred_dict["coords"] = out["sample_atom_coords"] if self.confidence_prediction: # pred_dict["confidence"] = out.get("ablation_confidence", None) - pred_dict["pde"] = out["pde"] + if self.predict_args.get("write_full_pde", True): + pred_dict["pde"] = out["pde"] pred_dict["plddt"] = out["plddt"] pred_dict["confidence_score"] = ( 4 * out["complex_plddt"] @@ -1098,7 +1099,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d pred_dict["complex_pde"] = out["complex_pde"] pred_dict["complex_ipde"] = out["complex_ipde"] if self.alpha_pae > 0: - pred_dict["pae"] = out["pae"] + if self.predict_args.get("write_full_pae", True): + pred_dict["pae"] = out["pae"] pred_dict["ptm"] = out["ptm"] pred_dict["iptm"] = out["iptm"] pred_dict["ligand_iptm"] = out["ligand_iptm"]