diff --git a/src/boltz/main.py b/src/boltz/main.py index 4fdf84ab0..44ba7a85f 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -10,6 +10,7 @@ from multiprocessing import Pool from pathlib import Path from typing import Literal, Optional +from ast import literal_eval import click import torch @@ -812,6 +813,23 @@ def cli() -> None: """Boltz.""" return +def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto"]: + """Convert the devices option value to an int or list of ints. raise ValueError if not convertible.""" + if value == "auto" or value == "0": + return "auto" + elif value == "-1": + raise ValueError("Using -1 would cause issues, please use 'auto' instead.") + elif isinstance(value,int) or value.isdigit(): + return int(value) + else: + try: + value_list = literal_eval(value) + if isinstance(value_list, list) and all(isinstance(i, int) for i in value_list): + return value_list + else: + raise ValueError("Invalid format for devices option.") + except (ValueError, SyntaxError): + raise ValueError("Invalid format for devices option. Use an int, list of ints, or 'auto'.") @cli.command() @click.argument("data", type=click.Path(exists=True)) @@ -838,8 +856,12 @@ def cli() -> None: ) @click.option( "--devices", - type=int, - help="The number of devices to use for prediction. Default is 1.", + type=devices_option_convert, + help=( + "The number of devices, the list of devices, or 'auto' to use for prediction." + "Examples: 1, [0, 1, 2], [2], auto\n" + "Default is 1." + ), default=1, ) @click.option( @@ -1044,7 +1066,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 cache: str = "~/.boltz", checkpoint: Optional[str] = None, affinity_checkpoint: Optional[str] = None, - devices: int = 1, + devices: int | list[int] | Literal["auto"] = 1, accelerator: str = "gpu", recycling_steps: int = 3, sampling_steps: int = 200, @@ -1213,7 +1235,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 ): start_method = "fork" if platform.system() != "win32" and platform.system() != "Windows" else "spawn" strategy = DDPStrategy(start_method=start_method) - if len(filtered_manifest.records) < devices: + if len(filtered_manifest.records) < (devices if isinstance(devices, int) else len(devices)): msg = ( "Number of requested devices is greater " "than the number of predictions, taking the minimum."