Commit 904a840

pipeline support for device="mps" (or any other string)
1 parent 6eb5145

File tree: 2 files changed (+13, -5 lines)

src/transformers/pipelines/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -422,6 +422,7 @@ def pipeline(
     revision: Optional[str] = None,
     use_fast: bool = True,
     use_auth_token: Optional[Union[str, bool]] = None,
+    device: Optional[Union[int, str, "torch.device"]] = None,
     device_map=None,
     torch_dtype=None,
     trust_remote_code: Optional[bool] = None,
@@ -508,6 +509,9 @@ def pipeline(
         use_auth_token (`str` or *bool*, *optional*):
             The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
             when running `huggingface-cli login` (stored in `~/.huggingface`).
+        device (`int` or `str` or `torch.device`):
+            Sent directly as `model_kwargs` (just a simpler shortcut). Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`,
+            `"mps"`, or a GPU ordinal rank like `1`) on which this pipeline will be allocated.
         device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
             Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
             `device_map="auto"` to compute the most optimized `device_map` automatically. [More
@@ -802,4 +806,4 @@ def pipeline(
     if feature_extractor is not None:
         kwargs["feature_extractor"] = feature_extractor
 
-    return pipeline_class(model=model, framework=framework, task=task, **kwargs)
+    return pipeline_class(model=model, framework=framework, task=task, device=device, **kwargs)
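
With this change, `pipeline()` accepts the device specifier directly and forwards it to the pipeline class. A minimal usage sketch (assuming a transformers build that includes this commit and a PyTorch install with MPS support; the task and model name are only illustrative):

    from transformers import pipeline

    # "mps" (or any other device string) is now passed through unchanged;
    # an int ordinal or a torch.device still works as before.
    classifier = pipeline(
        "sentiment-analysis",
        model="distilbert-base-uncased-finetuned-sst-2-english",  # illustrative model
        device="mps",
    )
    print(classifier("Pipelines can now be allocated on Apple's MPS backend."))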

src/transformers/pipelines/base.py

Lines changed: 8 additions & 4 deletions
@@ -704,7 +704,7 @@ def predict(self, X):
         Reference to the object in charge of parsing supplied pipeline parameters.
     device (`int`, *optional*, defaults to -1):
         Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
-        the associated CUDA device id. You can pass native `torch.device` too.
+        the associated CUDA device id. You can pass native `torch.device` or a `str` too.
     binary_output (`bool`, *optional*, defaults to `False`):
         Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
 """
@@ -747,7 +747,7 @@ def __init__(
         framework: Optional[str] = None,
         task: str = "",
         args_parser: ArgumentHandler = None,
-        device: int = -1,
+        device: Union[int, str, "torch.device"] = -1,
         binary_output: bool = False,
         **kwargs,
     ):
@@ -763,11 +763,15 @@ def __init__(
         if is_torch_available() and isinstance(device, torch.device):
             self.device = device
         else:
-            self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
+            self.device = (
+                device
+                if framework == "tf"
+                else torch.device(device if type(device) == str else "cpu" if device < 0 else f"cuda:{device}")
+            )
         self.binary_output = binary_output
 
         # Special handling
-        if self.framework == "pt" and self.device.type == "cuda":
+        if self.framework == "pt" and self.device.type != "cpu":
             self.model = self.model.to(self.device)
 
         # Update config with task specific parameters
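
For reference, the device resolution above can be read as a small standalone function. This is a sketch that mirrors the new PyTorch-side branch (the `"tf"` framework keeps the value untouched; it assumes PyTorch >= 1.12 so that "mps" is a recognised device type), not the library's exact code:

    import torch

    def resolve_device(device):
        # torch.device instances pass through untouched (handled earlier in __init__).
        if isinstance(device, torch.device):
            return device
        # Strings such as "cpu", "cuda:1" or "mps" go straight to torch.device;
        # ints keep the old behaviour: -1 means CPU, anything else a CUDA ordinal.
        return torch.device(device if isinstance(device, str) else "cpu" if device < 0 else f"cuda:{device}")

    assert resolve_device(-1).type == "cpu"
    assert resolve_device(1) == torch.device("cuda:1")
    assert resolve_device("mps").type == "mps"

Because the resolved device may now be "mps" rather than only "cuda", the follow-up check that moves the model onto the device was widened from `self.device.type == "cuda"` to `self.device.type != "cpu"`.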
