From 2545f128d21ea7e5da3e4b34a85ef6ed98a60996 Mon Sep 17 00:00:00 2001 From: Yi-Shu Tu Date: Thu, 26 Jun 2025 13:51:02 +0000 Subject: [PATCH 1/5] make device argument more flexible --- src/boltz/main.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/boltz/main.py b/src/boltz/main.py index 605f1abe8..09bae2841 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -10,6 +10,7 @@ from multiprocessing import Pool from pathlib import Path from typing import Literal, Optional +from ast import literal_eval import click import torch @@ -731,6 +732,21 @@ def cli() -> None: """Boltz.""" return +def devices_option_convert(value: str) -> int | list[int] | Literal["auto"]: + """Convert the devices option value to an int or list of ints. raise ValueError if not convertible.""" + if value == "auto" or value == "0": + return "auto" + elif value.isdigit(): + return int(value) + else: + try: + value_list = literal_eval(value) + if isinstance(value_list, list) and all(isinstance(i, int) for i in value_list): + return value_list + else: + raise ValueError("Invalid format for devices option.") + except (ValueError, SyntaxError): + raise ValueError("Invalid format for devices option. Use an int, list of ints, or 'auto'.") @cli.command() @click.argument("data", type=click.Path(exists=True)) @@ -757,7 +773,7 @@ def cli() -> None: ) @click.option( "--devices", - type=int, + type=devices_option_convert, help="The number of devices to use for prediction. Default is 1.", default=1, ) @@ -934,7 +950,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 cache: str = "~/.boltz", checkpoint: Optional[str] = None, affinity_checkpoint: Optional[str] = None, - devices: int = 1, + devices: int | list[int] | Literal["auto"] = 1, accelerator: str = "gpu", recycling_steps: int = 3, sampling_steps: int = 200, @@ -1076,7 +1092,7 @@ def predict( # noqa: C901, PLR0915, PLR0912 ): start_method = "fork" if platform.system() != "win32" else "spawn" strategy = DDPStrategy(start_method=start_method) - if len(filtered_manifest.records) < devices: + if len(filtered_manifest.records) < (devices if isinstance(devices, int) else len(devices)): msg = ( "Number of requested devices is greater " "than the number of predictions, taking the minimum." From 658afd58852be1f6bdf2586384d44e0747add97b Mon Sep 17 00:00:00 2001 From: Yi-Shu Tu Date: Wed, 16 Jul 2025 18:33:39 +0000 Subject: [PATCH 2/5] adding type check --- src/boltz/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/boltz/main.py b/src/boltz/main.py index 09bae2841..04ace4dbc 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -736,7 +736,7 @@ def devices_option_convert(value: str) -> int | list[int] | Literal["auto"]: """Convert the devices option value to an int or list of ints. raise ValueError if not convertible.""" if value == "auto" or value == "0": return "auto" - elif value.isdigit(): + elif isinstance(value,int) or value.isdigit(): return int(value) else: try: From 52dc883f0fd4dee2cdb5a31d4e175629363cbf7e Mon Sep 17 00:00:00 2001 From: Yi-Shu Tu Date: Thu, 17 Jul 2025 15:23:26 +0000 Subject: [PATCH 3/5] typehint modification --- src/boltz/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/boltz/main.py b/src/boltz/main.py index 156e9be93..64fa0a51e 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -813,7 +813,7 @@ def cli() -> None: """Boltz.""" return -def devices_option_convert(value: str) -> int | list[int] | Literal["auto"]: +def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto"]: """Convert the devices option value to an int or list of ints. raise ValueError if not convertible.""" if value == "auto" or value == "0": return "auto" From e8797397c920486367f9462d0110267b6c64423c Mon Sep 17 00:00:00 2001 From: Yi-Shu Tu Date: Thu, 17 Jul 2025 15:46:59 +0000 Subject: [PATCH 4/5] Add descriptions --- src/boltz/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/boltz/main.py b/src/boltz/main.py index 64fa0a51e..2ca799c16 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -817,6 +817,8 @@ def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto" """Convert the devices option value to an int or list of ints. raise ValueError if not convertible.""" if value == "auto" or value == "0": return "auto" + elif value == "-1": + return -1 elif isinstance(value,int) or value.isdigit(): return int(value) else: @@ -855,7 +857,11 @@ def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto" @click.option( "--devices", type=devices_option_convert, - help="The number of devices to use for prediction. Default is 1.", + help=( + "The number of devices, the list of devices, or 'auto' to use for prediction." + "Examples: 1, [0, 1, 2], [2], auto\n" + "Default is 1." + ), default=1, ) @click.option( From 002208528e6d5fb32e0979cdb2f6f9e9ec474510 Mon Sep 17 00:00:00 2001 From: Yi-Shu Tu Date: Thu, 17 Jul 2025 15:57:18 +0000 Subject: [PATCH 5/5] reject device=-1 --- src/boltz/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/boltz/main.py b/src/boltz/main.py index 2ca799c16..44ba7a85f 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -818,7 +818,7 @@ def devices_option_convert(value: str | int) -> int | list[int] | Literal["auto" if value == "auto" or value == "0": return "auto" elif value == "-1": - return -1 + raise ValueError("Using -1 would cause issues, please use 'auto' instead.") elif isinstance(value,int) or value.isdigit(): return int(value) else: