Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/transformers/pipelines/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,10 @@ def postprocess(
key = "logits" if self.type == "ctc_with_lm" else "tokens"
stride = None
for outputs in model_outputs:
items = outputs[key].numpy()
if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
items = outputs[key].to(torch.float32).numpy()
else:
items = outputs[key].numpy()
stride = outputs.get("stride", None)
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
total_n, left, right = stride
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/pipelines/token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
if is_torch_available():
import torch

from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


Expand Down Expand Up @@ -299,7 +301,11 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
ignore_labels = ["O"]
all_entities = []
for model_outputs in all_outputs:
logits = model_outputs["logits"][0].numpy()
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
logits = model_outputs["logits"][0].to(torch.float32).numpy()
else:
logits = model_outputs["logits"][0].numpy()

sentence = all_outputs[0]["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = (
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,7 +2136,7 @@ def nested_simplify(obj, decimals=3):
return nested_simplify(obj.numpy().tolist())
elif isinstance(obj, float):
return round(obj, decimals)
elif isinstance(obj, (np.int32, np.float32)):
elif isinstance(obj, (np.int32, np.float32, np.float16)):
return nested_simplify(obj.item(), decimals)
else:
raise Exception(f"Not supported: {type(obj)}")
Expand Down
42 changes: 42 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,48 @@ def test_small_model_pt(self):
):
_ = speech_recognizer(waveform, return_timestamps="char")

@require_torch
def test_small_model_pt_fp16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
torch_dtype=torch.float16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
output = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(output, {"text": "(Applaudissements)"})

# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")

@require_torch
def test_small_model_pt_bf16(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="facebook/s2t-small-mustc-en-fr-st",
tokenizer="facebook/s2t-small-mustc-en-fr-st",
framework="pt",
torch_dtype=torch.bfloat16,
)
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
output = speech_recognizer(waveform, chunk_length_s=10)
self.assertEqual(output, {"text": "(Applaudissements)"})

# Non CTC models cannot use return_timestamps
with self.assertRaisesRegex(
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
):
_ = speech_recognizer(waveform, return_timestamps="char")

@slow
@require_torch_accelerator
def test_whisper_fp16(self):
Expand Down
35 changes: 35 additions & 0 deletions tests/pipelines/test_pipelines_token_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
Expand All @@ -38,6 +39,10 @@
from .test_pipelines_common import ANY


if is_torch_available():
import torch


VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]

# These 2 model types require different inputs than those of the usual text models.
Expand Down Expand Up @@ -841,6 +846,36 @@ def test_small_model_pt(self):
],
)

@require_torch
def test_small_model_pt_fp16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)

@require_torch
def test_small_model_pt_bf16(self):
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
token_classifier = pipeline(
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
)
outputs = token_classifier("This is a test !")
self.assertEqual(
nested_simplify(outputs),
[
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
],
)

@require_torch
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
Expand Down
55 changes: 54 additions & 1 deletion tests/pipelines/test_pipelines_zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,22 @@
ZeroShotClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
from transformers.testing_utils import (
is_pipeline_test,
is_torch_available,
nested_simplify,
require_tf,
require_torch,
slow,
)

from .test_pipelines_common import ANY


if is_torch_available():
import torch


# These 2 model types require different inputs than those of the usual text models.
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}

Expand Down Expand Up @@ -176,6 +187,48 @@ def test_small_model_pt(self):
},
)

@require_torch
def test_small_model_pt_fp16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.float16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)

self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)

@require_torch
def test_small_model_pt_bf16(self):
zero_shot_classifier = pipeline(
"zero-shot-classification",
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
framework="pt",
torch_dtype=torch.bfloat16,
)
outputs = zero_shot_classifier(
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
)

self.assertEqual(
nested_simplify(outputs),
{
"sequence": "Who are you voting for in 2020?",
"labels": ["science", "public health", "politics"],
"scores": [0.333, 0.333, 0.333],
},
)

@require_tf
def test_small_model_tf(self):
zero_shot_classifier = pipeline(
Expand Down