Skip to content

Commit a4f8851

Browse files
committed
pipeline support for device="mps" (or any other string)
1 parent 5cd4032 commit a4f8851

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/transformers/pipelines/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def pipeline(
420420
revision: Optional[str] = None,
421421
use_fast: bool = True,
422422
use_auth_token: Optional[Union[str, bool]] = None,
423+
device: Optional[Union[int, str, "torch.device"]] = None,
423424
device_map=None,
424425
torch_dtype=None,
425426
trust_remote_code: Optional[bool] = None,
@@ -506,6 +507,9 @@ def pipeline(
506507
use_auth_token (`str` or *bool*, *optional*):
507508
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
508509
when running `transformers-cli login` (stored in `~/.huggingface`).
510+
device (`int` or `str` or `torch.device`):
511+
Sent directly as `model_kwargs` (just a simpler shortcut). Defines the device (*e.g.*, `"cpu"`, `"cuda:1"`,
512+
`"mps"`, or a GPU ordinal rank like `1`) on which this pipeline will be allocated.
509513
device_map (`str` or `Dict[str, Union[int, str, torch.device]`, *optional*):
510514
Sent directly as `model_kwargs` (just a simpler shortcut). When `accelerate` library is present, set
511515
`device_map="auto"` to compute the most optimized `device_map` automatically. [More
@@ -800,4 +804,4 @@ def pipeline(
800804
if feature_extractor is not None:
801805
kwargs["feature_extractor"] = feature_extractor
802806

803-
return pipeline_class(model=model, framework=framework, task=task, **kwargs)
807+
return pipeline_class(model=model, framework=framework, task=task, device=device, **kwargs)

src/transformers/pipelines/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def predict(self, X):
704704
Reference to the object in charge of parsing supplied pipeline parameters.
705705
device (`int`, *optional*, defaults to -1):
706706
Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, a positive will run the model on
707-
the associated CUDA device id. You can pass native `torch.device` too.
707+
the associated CUDA device id. You can pass native `torch.device` or a `str` too.
708708
binary_output (`bool`, *optional*, defaults to `False`):
709709
Flag indicating if the output the pipeline should happen in a binary format (i.e., pickle) or as raw text.
710710
"""
@@ -747,7 +747,7 @@ def __init__(
747747
framework: Optional[str] = None,
748748
task: str = "",
749749
args_parser: ArgumentHandler = None,
750-
device: int = -1,
750+
device: Union[int, str, "torch.device"] = -1,
751751
binary_output: bool = False,
752752
**kwargs,
753753
):
@@ -763,11 +763,15 @@ def __init__(
763763
if is_torch_available() and isinstance(device, torch.device):
764764
self.device = device
765765
else:
766-
self.device = device if framework == "tf" else torch.device("cpu" if device < 0 else f"cuda:{device}")
766+
self.device = (
767+
device
768+
if framework == "tf"
769+
else torch.device(device if type(device) == str else "cpu" if device < 0 else f"cuda:{device}")
770+
)
767771
self.binary_output = binary_output
768772

769773
# Special handling
770-
if self.framework == "pt" and self.device.type == "cuda":
774+
if self.framework == "pt" and self.device.type != "cpu":
771775
self.model = self.model.to(self.device)
772776

773777
# Update config with task specific parameters

0 commit comments

Comments
 (0)