Merged
Changes from 13 commits
5 changes: 4 additions & 1 deletion src/transformers/pipelines/automatic_speech_recognition.py
@@ -557,7 +557,10 @@ def postprocess(
         key = "logits" if self.type == "ctc_with_lm" else "tokens"
         stride = None
         for outputs in model_outputs:
-            items = outputs[key].numpy()
+            if self.framework == "pt" and outputs[key].dtype == torch.bfloat16:
+                items = outputs[key].float().numpy()
+            else:
+                items = outputs[key].numpy()
             stride = outputs.get("stride", None)
             if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
                 total_n, left, right = stride
5 changes: 3 additions & 2 deletions src/transformers/pipelines/image_classification.py
@@ -184,8 +184,9 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
             top_k = self.model.config.num_labels

         outputs = model_outputs["logits"][0]
-        if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16):
-            outputs = outputs.to(torch.float32).numpy()
+        if self.framework == "pt" and outputs.dtype == torch.bfloat16:
+            # To enable using bf16
+            outputs = outputs.float().numpy()
Review comment (Member):
This seems like a regression to me! The previous code worked with both torch.bfloat16 and torch.float16.

Reply (@jiqing-feng, Contributor Author, Sep 18, 2024):
I used torch 2.4 and numpy 2.1.1 on Python 3.10, and it failed:
[screenshot of the failure]

         else:
             outputs = outputs.numpy()
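Context for the thread above: NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError, while float16 converts cleanly because NumPy ships np.float16. A minimal sketch of the behavior (assuming only torch and numpy are installed; not part of the PR):

import torch

fp16 = torch.tensor([1.5, 2.5], dtype=torch.float16)
bf16 = torch.tensor([1.5, 2.5], dtype=torch.bfloat16)

# float16 has a NumPy counterpart, so this succeeds
print(fp16.numpy().dtype)  # float16

# bfloat16 has no NumPy counterpart, so .numpy() raises
try:
    bf16.numpy()
except TypeError as err:
    print(err)  # Got unsupported ScalarType BFloat16

# The workaround this PR applies: upcast to float32 first
print(bf16.float().numpy().dtype)  # float32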
9 changes: 8 additions & 1 deletion src/transformers/pipelines/token_classification.py
@@ -19,6 +19,8 @@

     from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES

 if is_torch_available():
+    import torch
+
     from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES


@@ -299,7 +301,12 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
             ignore_labels = ["O"]
         all_entities = []
         for model_outputs in all_outputs:
-            logits = model_outputs["logits"][0].numpy()
+            if self.framework == "pt" and model_outputs["logits"][0].dtype == torch.bfloat16:
+                # To enable using bf16
+                logits = model_outputs["logits"][0].float().numpy()
+            else:
+                logits = model_outputs["logits"][0].numpy()
+
             sentence = all_outputs[0]["sentence"]
             input_ids = model_outputs["input_ids"][0]
             offset_mapping = (
14 changes: 12 additions & 2 deletions src/transformers/pipelines/zero_shot_classification.py
@@ -4,10 +4,14 @@
 import numpy as np

 from ..tokenization_utils import TruncationStrategy
-from ..utils import add_end_docstrings, logging
+from ..utils import add_end_docstrings, is_torch_available, logging
 from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args


+if is_torch_available():
+    import torch
+
+
 logger = logging.get_logger(__name__)


@@ -239,7 +243,13 @@ def _forward(self, inputs):
     def postprocess(self, model_outputs, multi_label=False):
         candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
         sequences = [outputs["sequence"] for outputs in model_outputs]
-        logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
+        logits = []
+        for output in model_outputs:
+            if self.framework == "pt" and output["logits"].dtype == torch.bfloat16:
+                logits.append(output["logits"].float().numpy())
+            else:
+                logits.append(output["logits"].numpy())
+        logits = np.concatenate(logits)
         N = logits.shape[0]
         n = len(candidate_labels)
         num_sequences = N // n
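Taken together, the pipeline changes above are what make reduced-precision inference work end to end. A minimal usage sketch mirroring the new tests (the checkpoint is the tiny test model used below, chosen for illustration only):

import torch
from transformers import pipeline

# With the postprocess() guards above, bfloat16 logits are upcast to
# float32 before the .numpy() call instead of raising a TypeError.
classifier = pipeline(
    "zero-shot-classification",
    model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
    framework="pt",
    torch_dtype=torch.bfloat16,
)
result = classifier("Who are you voting for in 2020?", candidate_labels=["politics", "science"])
print(result["labels"], result["scores"])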
2 changes: 1 addition & 1 deletion src/transformers/testing_utils.py
@@ -2136,7 +2136,7 @@ def nested_simplify(obj, decimals=3):
         return nested_simplify(obj.numpy().tolist())
     elif isinstance(obj, float):
         return round(obj, decimals)
-    elif isinstance(obj, (np.int32, np.float32)):
+    elif isinstance(obj, (np.int32, np.float32, np.float16)):
         return nested_simplify(obj.item(), decimals)
     else:
         raise Exception(f"Not supported: {type(obj)}")
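A plausible reading of why np.float16 is added here: when a pipeline runs with torch_dtype=torch.float16, scores can reach nested_simplify as np.float16 scalars, which the old isinstance check rejected with the "Not supported" exception. A small sketch of the difference (assuming only numpy):

import numpy as np

score = np.float16(0.115)  # e.g. a score produced by a fp16 pipeline run

print(isinstance(score, (np.int32, np.float32)))              # False: old check falls through and raises
print(isinstance(score, (np.int32, np.float32, np.float16)))  # True: new check simplifies the value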
42 changes: 42 additions & 0 deletions tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -167,6 +167,48 @@ def test_small_model_pt(self):
         ):
             _ = speech_recognizer(waveform, return_timestamps="char")

+    @require_torch
+    def test_small_model_pt_fp16(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="facebook/s2t-small-mustc-en-fr-st",
+            tokenizer="facebook/s2t-small-mustc-en-fr-st",
+            framework="pt",
+            torch_dtype=torch.float16,
+        )
+        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
+        output = speech_recognizer(waveform)
+        self.assertEqual(output, {"text": "(Applaudissements)"})
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})
+
+        # Non CTC models cannot use return_timestamps
+        with self.assertRaisesRegex(
+            ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
+        ):
+            _ = speech_recognizer(waveform, return_timestamps="char")
+
+    @require_torch
+    def test_small_model_pt_bf16(self):
+        speech_recognizer = pipeline(
+            task="automatic-speech-recognition",
+            model="facebook/s2t-small-mustc-en-fr-st",
+            tokenizer="facebook/s2t-small-mustc-en-fr-st",
+            framework="pt",
+            torch_dtype=torch.bfloat16,
+        )
+        waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
+        output = speech_recognizer(waveform)
+        self.assertEqual(output, {"text": "(Applaudissements)"})
+        output = speech_recognizer(waveform, chunk_length_s=10)
+        self.assertEqual(output, {"text": "(Applaudissements)"})
+
+        # Non CTC models cannot use return_timestamps
+        with self.assertRaisesRegex(
+            ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
+        ):
+            _ = speech_recognizer(waveform, return_timestamps="char")
+
     @slow
     @require_torch_accelerator
     def test_whisper_fp16(self):
35 changes: 35 additions & 0 deletions tests/pipelines/test_pipelines_token_classification.py
@@ -27,6 +27,7 @@
 from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
 from transformers.testing_utils import (
     is_pipeline_test,
+    is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
@@ -38,6 +39,10 @@
 from .test_pipelines_common import ANY


+if is_torch_available():
+    import torch
+
+
 VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]

 # These 2 model types require different inputs than those of the usual text models.
@@ -841,6 +846,36 @@ def test_small_model_pt(self):
             ],
         )

+    @require_torch
+    def test_small_model_pt_fp16(self):
+        model_name = "hf-internal-testing/tiny-bert-for-token-classification"
+        token_classifier = pipeline(
+            task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
+        )
+        outputs = token_classifier("This is a test !")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
+                {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
+            ],
+        )
+
+    @require_torch
+    def test_small_model_pt_bf16(self):
+        model_name = "hf-internal-testing/tiny-bert-for-token-classification"
+        token_classifier = pipeline(
+            task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
+        )
+        outputs = token_classifier("This is a test !")
+        self.assertEqual(
+            nested_simplify(outputs),
+            [
+                {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
+                {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
+            ],
+        )
+
     @require_torch
     def test_pt_ignore_subwords_slow_tokenizer_raises(self):
         model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
55 changes: 54 additions & 1 deletion tests/pipelines/test_pipelines_zero_shot.py
@@ -21,11 +21,22 @@
     ZeroShotClassificationPipeline,
     pipeline,
 )
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import (
+    is_pipeline_test,
+    is_torch_available,
+    nested_simplify,
+    require_tf,
+    require_torch,
+    slow,
+)

 from .test_pipelines_common import ANY


+if is_torch_available():
+    import torch
+
+
 # These 2 model types require different inputs than those of the usual text models.
 _TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}

@@ -176,6 +187,48 @@ def test_small_model_pt(self):
             },
         )

+    @require_torch
+    def test_small_model_pt_fp16(self):
+        zero_shot_classifier = pipeline(
+            "zero-shot-classification",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            framework="pt",
+            torch_dtype=torch.float16,
+        )
+        outputs = zero_shot_classifier(
+            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
+        )
+
+        self.assertEqual(
+            nested_simplify(outputs),
+            {
+                "sequence": "Who are you voting for in 2020?",
+                "labels": ["science", "public health", "politics"],
+                "scores": [0.333, 0.333, 0.333],
+            },
+        )
+
+    @require_torch
+    def test_small_model_pt_bf16(self):
+        zero_shot_classifier = pipeline(
+            "zero-shot-classification",
+            model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
+            framework="pt",
+            torch_dtype=torch.bfloat16,
+        )
+        outputs = zero_shot_classifier(
+            "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
+        )
+
+        self.assertEqual(
+            nested_simplify(outputs),
+            {
+                "sequence": "Who are you voting for in 2020?",
+                "labels": ["science", "public health", "politics"],
+                "scores": [0.333, 0.333, 0.333],
+            },
+        )
+
     @require_tf
     def test_small_model_tf(self):
         zero_shot_classifier = pipeline(