Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from multiprocessing import Pool
from pathlib import Path
from typing import Literal, Optional
from ast import literal_eval

import click
import torch
Expand Down Expand Up @@ -812,6 +813,23 @@ def cli() -> None:
"""Boltz."""
return

def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto"]:
"""Convert the devices option value to an int or list of ints. raise ValueError if not convertible."""
if value == "auto" or value == "0":
return "auto"
elif value == "-1":
raise ValueError("Using -1 would cause issues, please use 'auto' instead.")
elif isinstance(value,int) or value.isdigit():
return int(value)
else:
try:
value_list = literal_eval(value)
if isinstance(value_list, list) and all(isinstance(i, int) for i in value_list):
return value_list
else:
raise ValueError("Invalid format for devices option.")
except (ValueError, SyntaxError):
raise ValueError("Invalid format for devices option. Use an int, list of ints, or 'auto'.")

@cli.command()
@click.argument("data", type=click.Path(exists=True))
Expand All @@ -838,8 +856,12 @@ def cli() -> None:
)
@click.option(
"--devices",
type=int,
help="The number of devices to use for prediction. Default is 1.",
type=devices_option_convert,
help=(
"The number of devices, the list of devices, or 'auto' to use for prediction."
"Examples: 1, [0, 1, 2], [2], auto\n"
"Default is 1."
),
default=1,
)
@click.option(
Expand Down Expand Up @@ -1044,7 +1066,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
cache: str = "~/.boltz",
checkpoint: Optional[str] = None,
affinity_checkpoint: Optional[str] = None,
devices: int = 1,
devices: int | list[int] | Literal["auto"] = 1,
accelerator: str = "gpu",
recycling_steps: int = 3,
sampling_steps: int = 200,
Expand Down Expand Up @@ -1213,7 +1235,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
):
start_method = "fork" if platform.system() != "win32" and platform.system() != "Windows" else "spawn"
strategy = DDPStrategy(start_method=start_method)
if len(filtered_manifest.records) < devices:
if len(filtered_manifest.records) < (devices if isinstance(devices, int) else len(devices)):
msg = (
"Number of requested devices is greater "
"than the number of predictions, taking the minimum."
Expand Down