Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 73 additions & 65 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def cli() -> None:
return


@cli.command()
@cli.command(context_settings={'show_default': True})
@click.argument("data", type=click.Path(exists=True))
@click.option(
"--out_dir",
Expand All @@ -764,37 +764,37 @@ def cli() -> None:
@click.option(
"--devices",
type=int,
help="The number of devices to use for prediction. Default is 1.",
help="The number of devices to use for prediction.",
default=1,
)
@click.option(
"--accelerator",
type=click.Choice(["gpu", "cpu", "tpu"]),
help="The accelerator to use for prediction. Default is gpu.",
help="The accelerator to use for prediction.",
default="gpu",
)
@click.option(
"--recycling_steps",
type=int,
help="The number of recycling steps to use for prediction. Default is 3.",
help="The number of recycling steps to use for prediction.",
default=3,
)
@click.option(
"--sampling_steps",
type=int,
help="The number of sampling steps to use for prediction. Default is 200.",
help="The number of sampling steps to use for prediction.",
default=200,
)
@click.option(
"--diffusion_samples",
type=int,
help="The number of diffusion samples to use for prediction. Default is 1.",
help="The number of diffusion samples to use for prediction.",
default=1,
)
@click.option(
"--max_parallel_samples",
type=int,
help="The maximum number of samples to predict in parallel. Default is None.",
help="The maximum number of samples to predict in parallel.",
default=5,
)
@click.option(
Expand All @@ -811,44 +811,48 @@ def cli() -> None:
default=None,
)
@click.option(
"--write_full_pae",
"--write_full_pae/--no_write_full_pae",
type=bool,
is_flag=True,
help="Whether to dump the pae into a npz file. Default is True.",
help="Whether to dump the pae into a npz file.",
default=False,
)
@click.option(
"--write_full_pde",
"--write_full_pde/--no_write_full_pde",
type=bool,
is_flag=True,
help="Whether to dump the pde into a npz file. Default is False.",
help="Whether to dump the pde into a npz file.",
default=False,
)
@click.option(
"--output_format",
type=click.Choice(["pdb", "mmcif"]),
help="The output format to use for the predictions. Default is mmcif.",
help="The output format to use for the predictions.",
default="mmcif",
)
@click.option(
"--num_workers",
type=int,
help="The number of dataloader workers to use for prediction. Default is 2.",
help="The number of dataloader workers to use for prediction.",
default=2,
)
@click.option(
"--override",
"--override/--no_override",
is_flag=True,
help="Whether to override existing found predictions. Default is False.",
help="Override existing found predictions.",
default=False,
)
@click.option(
"--seed",
type=int,
help="Seed to use for random number generator. Default is None (no seeding).",
help="Seed to use for random number generator.",
default=None,
)
@click.option(
"--use_msa_server",
"--use_msa_server/--no_use_msa_server",
is_flag=True,
help="Whether to use the MMSeqs2 server for MSA generation. Default is False.",
help="Whether to use the MMSeqs2 server for MSA generation.",
default=False,
)
@click.option(
"--msa_server_url",
Expand All @@ -858,52 +862,54 @@ def cli() -> None:
)
@click.option(
"--msa_pairing_strategy",
type=str,
type=click.Choice(["greedy", "complete"]),
help=(
"Pairing strategy to use. Used only if --use_msa_server is set. "
"Options are 'greedy' and 'complete'"
),
default="greedy",
)
@click.option(
"--use_potentials",
"--use_potentials/--no_use_potentials",
is_flag=True,
help="Whether to not use potentials for steering. Default is False.",
help="Whether to use potentials for steering.",
default=False,
)
@click.option(
"--model",
default="boltz2",
type=click.Choice(["boltz1", "boltz2"]),
help="The model to use for prediction. Default is boltz2.",
help="The model to use for prediction.",
default="boltz2",
)
@click.option(
"--method",
type=str,
help="The method to use for prediction. Default is None.",
help="The method to use for prediction.",
type=click.Choice(list(const.method_types_ids.keys()) + [None,]),
default=None,
)
@click.option(
"--preprocessing-threads",
type=int,
help="The number of threads to use for preprocessing. Default is 1.",
help="The number of threads to use for preprocessing.",
default=multiprocessing.cpu_count(),
)
@click.option(
"--affinity_mw_correction",
"--affinity_mw_correction/--no_affinity_mw_correction",
is_flag=True,
type=bool,
help="Whether to add the Molecular Weight correction to the affinity value head.",
default=False,
)
@click.option(
"--sampling_steps_affinity",
type=int,
help="The number of sampling steps to use for affinity prediction. Default is 200.",
help="The number of sampling steps to use for affinity prediction.",
default=200,
)
@click.option(
"--diffusion_samples_affinity",
type=int,
help="The number of diffusion samples to use for affinity prediction. Default is 5.",
help="The number of diffusion samples to use for affinity prediction.",
default=5,
)
@click.option(
Expand All @@ -915,58 +921,60 @@ def cli() -> None:
@click.option(
"--max_msa_seqs",
type=int,
help="The maximum number of MSA sequences to use for prediction. Default is 8192.",
help="The maximum number of MSA sequences to use for prediction.",
default=8192,
)
@click.option(
"--subsample_msa",
"--subsample_msa/--no_subsample_msa",
is_flag=True,
help="Whether to subsample the MSA. Default is True.",
help="Whether to subsample the MSA.",
default=False,
)
@click.option(
"--num_subsampled_msa",
type=int,
help="The number of MSA sequences to subsample. Default is 1024.",
help="The number of MSA sequences to subsample.",
default=1024,
)
@click.option(
"--no_kernels",
"--no_kernels/--kernels",
is_flag=True,
help="Whether to disable the kernels. Default False",
help="Whether to disable the kernels.",
default=False
)
def predict( # noqa: C901, PLR0915, PLR0912
data: str,
out_dir: str,
cache: str = "~/.boltz",
checkpoint: Optional[str] = None,
affinity_checkpoint: Optional[str] = None,
devices: int = 1,
accelerator: str = "gpu",
recycling_steps: int = 3,
sampling_steps: int = 200,
diffusion_samples: int = 1,
sampling_steps_affinity: int = 200,
diffusion_samples_affinity: int = 3,
max_parallel_samples: Optional[int] = None,
step_scale: Optional[float] = None,
write_full_pae: bool = False,
write_full_pde: bool = False,
output_format: Literal["pdb", "mmcif"] = "mmcif",
num_workers: int = 2,
override: bool = False,
seed: Optional[int] = None,
use_msa_server: bool = False,
msa_server_url: str = "https://api.colabfold.com",
msa_pairing_strategy: str = "greedy",
use_potentials: bool = False,
model: Literal["boltz1", "boltz2"] = "boltz2",
method: Optional[str] = None,
affinity_mw_correction: Optional[bool] = False,
preprocessing_threads: int = 1,
max_msa_seqs: int = 8192,
subsample_msa: bool = True,
num_subsampled_msa: int = 1024,
no_kernels: bool = False,
cache: str,
checkpoint: Optional[str],
affinity_checkpoint: Optional[str],
devices: int,
accelerator: str,
recycling_steps: int,
sampling_steps: int,
diffusion_samples: int,
sampling_steps_affinity: int,
diffusion_samples_affinity: int,
max_parallel_samples: Optional[int],
step_scale: Optional[float],
write_full_pae: bool,
write_full_pde: bool,
output_format: Literal["pdb", "mmcif"],
num_workers: int,
override: bool,
seed: Optional[int],
use_msa_server: bool,
msa_server_url: str,
msa_pairing_strategy: str,
use_potentials: bool,
model: Literal["boltz1", "boltz2"],
method: Optional[str],
affinity_mw_correction: Optional[bool],
preprocessing_threads: int,
max_msa_seqs: int,
subsample_msa: bool,
num_subsampled_msa: int,
no_kernels: bool,
) -> None:
"""Run predictions with Boltz."""
# If cpu, write a friendly warning
Expand Down