diff --git a/docs/prediction.md b/docs/prediction.md
index b33d2d0e2..997f5cc3d 100644
--- a/docs/prediction.md
+++ b/docs/prediction.md
@@ -168,8 +168,8 @@ Examples of common options include:
| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. |
| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' |
| `--use_potentials` | `FLAG` | `False` | Whether to run the original Boltz-2 model using inference time potentials. |
-| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. |
-| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. |
+| `--write_full_pae`
`--no_write_full_pae` | `FLAG` | `True` | Whether to save the full PAE matrix as a file. |
+| `--write_full_pde`
`--no_write_full_pde` | `FLAG` | `True` | Whether to save the full PDE matrix as a file. |
## Output
diff --git a/src/boltz/main.py b/src/boltz/main.py
index 4a3750fec..648bb7ecc 100644
--- a/src/boltz/main.py
+++ b/src/boltz/main.py
@@ -887,15 +887,17 @@ def cli() -> None:
default=None,
)
@click.option(
- "--write_full_pae",
+ "--write_full_pae/--no_write_full_pae",
type=bool,
is_flag=True,
+ default=True,
help="Whether to dump the pae into a npz file. Default is True.",
)
@click.option(
- "--write_full_pde",
+ "--write_full_pde/--no_write_full_pde",
type=bool,
is_flag=True,
+ default=True,
help="Whether to dump the pde into a npz file. Default is False.",
)
@click.option(
@@ -1054,8 +1056,8 @@ def predict( # noqa: C901, PLR0915, PLR0912
diffusion_samples_affinity: int = 3,
max_parallel_samples: Optional[int] = None,
step_scale: Optional[float] = None,
- write_full_pae: bool = False,
- write_full_pde: bool = False,
+ write_full_pae: bool = True,
+ write_full_pde: bool = True,
output_format: Literal["pdb", "mmcif"] = "mmcif",
num_workers: int = 2,
override: bool = False,
diff --git a/src/boltz/model/models/boltz2.py b/src/boltz/model/models/boltz2.py
index d42f3400c..33bc3e71f 100644
--- a/src/boltz/model/models/boltz2.py
+++ b/src/boltz/model/models/boltz2.py
@@ -1080,7 +1080,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d
pred_dict["coords"] = out["sample_atom_coords"]
if self.confidence_prediction:
# pred_dict["confidence"] = out.get("ablation_confidence", None)
- pred_dict["pde"] = out["pde"]
+ if self.predict_args.get("write_full_pde", True):
+ pred_dict["pde"] = out["pde"]
pred_dict["plddt"] = out["plddt"]
pred_dict["confidence_score"] = (
4 * out["complex_plddt"]
@@ -1098,7 +1099,8 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> d
pred_dict["complex_pde"] = out["complex_pde"]
pred_dict["complex_ipde"] = out["complex_ipde"]
if self.alpha_pae > 0:
- pred_dict["pae"] = out["pae"]
+ if self.predict_args.get("write_full_pae", True):
+ pred_dict["pae"] = out["pae"]
pred_dict["ptm"] = out["ptm"]
pred_dict["iptm"] = out["iptm"]
pred_dict["ligand_iptm"] = out["ligand_iptm"]