From 4579ef17a5684779496942da07163d238514eda2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 26 Jun 2024 07:50:40 -0400 Subject: [PATCH 01/14] enable low-precision pipeline --- .../pipelines/automatic_speech_recognition.py | 6 +- .../pipelines/image_classification.py | 6 +- .../pipelines/token_classification.py | 7 ++- .../pipelines/zero_shot_classification.py | 2 +- ..._pipelines_automatic_speech_recognition.py | 42 ++++++++++++++ .../test_pipelines_token_classification.py | 31 +++++++++++ tests/pipelines/test_pipelines_zero_shot.py | 55 ++++++++++++++++++- 7 files changed, 142 insertions(+), 7 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 01faab6d74ad..47a438f9d6c1 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -557,7 +557,11 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - items = outputs[key].numpy() + if self.framework == "pt": + # To enable using fp16 and bf16 + outputs = outputs[key].float().numpy() + else: + outputs = outputs[key].numpy() stride = outputs.get("stride", None) if stride is not None and self.type in {"ctc", "ctc_with_lm"}: total_n, left, right = stride diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index bfa005f06bab..be5ae702abe6 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -23,7 +23,6 @@ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): - import torch from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES @@ -182,8 +181,9 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5): top_k = self.model.config.num_labels outputs = model_outputs["logits"][0] - if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16): - outputs = outputs.to(torch.float32).numpy() + if self.framework == "pt": + # To enable using fp16 and bf16 + outputs = outputs.float().numpy() else: outputs = outputs.numpy() diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index e1d763eafa8b..12703267fe3e 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -299,7 +299,12 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE ignore_labels = ["O"] all_entities = [] for model_outputs in all_outputs: - logits = model_outputs["logits"][0].numpy() + if self.framework == "pt": + # To enable using fp16 and bf16 + logits = model_outputs["logits"][0].float().numpy() + else: + logits = model_outputs["logits"][0].numpy() + sentence = all_outputs[0]["sentence"] input_ids = model_outputs["input_ids"][0] offset_mapping = ( diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 9a600bc8ad0f..7a4f43e7ec8b 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -239,7 +239,7 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = 
[outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 73376ff2189c..c37d31a8aca9 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -169,6 +169,48 @@ def test_small_model_pt(self): ): _ = speech_recognizer(waveform, return_timestamps="char") + @require_torch + def test_small_model_pt_fp16(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/s2t-small-mustc-en-fr-st", + tokenizer="facebook/s2t-small-mustc-en-fr-st", + framework="pt", + torch_dtype=torch.float16, + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + output = speech_recognizer(waveform) + self.assertEqual(output, {"text": "(Applaudissements)"}) + output = speech_recognizer(waveform, chunk_length_s=10) + self.assertEqual(output, {"text": "(Applaudissements)"}) + + # Non CTC models cannot use return_timestamps + with self.assertRaisesRegex( + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" + ): + _ = speech_recognizer(waveform, return_timestamps="char") + + @require_torch + def test_small_model_pt_bf16(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/s2t-small-mustc-en-fr-st", + tokenizer="facebook/s2t-small-mustc-en-fr-st", + framework="pt", + torch_dtype=torch.bfloat16, + ) + waveform = np.tile(np.arange(1000, dtype=np.float32), 34) + output = speech_recognizer(waveform) + self.assertEqual(output, {"text": "(Applaudissements)"}) + output = speech_recognizer(waveform, chunk_length_s=10) + self.assertEqual(output, {"text": "(Applaudissements)"}) + + # Non CTC models cannot use return_timestamps + with self.assertRaisesRegex( + ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$" + ): + _ = speech_recognizer(waveform, return_timestamps="char") + @slow @require_torch_accelerator def test_whisper_fp16(self): diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index eda9ac014bf7..63614479134b 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -27,6 +27,7 @@ from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler from transformers.testing_utils import ( is_pipeline_test, + is_torch_available, nested_simplify, require_tf, require_torch, @@ -38,6 +39,10 @@ from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]] # These 2 model types require different inputs than those of the usual text models. 
@@ -841,6 +846,32 @@ def test_small_model_pt(self): ], ) + @require_torch + def test_small_model_pt_fp16(self): + model_name = "hf-internal-testing/tiny-bert-for-token-classification" + token_classifier = pipeline(task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16) + outputs = token_classifier("This is a test !") + self.assertEqual( + nested_simplify(outputs), + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + ) + + @require_torch + def test_small_model_pt_bf16(self): + model_name = "hf-internal-testing/tiny-bert-for-token-classification" + token_classifier = pipeline(task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16) + outputs = token_classifier("This is a test !") + self.assertEqual( + nested_simplify(outputs), + [ + {"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4}, + {"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7}, + ], + ) + @require_torch def test_pt_ignore_subwords_slow_tokenizer_raises(self): model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" diff --git a/tests/pipelines/test_pipelines_zero_shot.py b/tests/pipelines/test_pipelines_zero_shot.py index 2e61d97c1dc8..a80216127089 100644 --- a/tests/pipelines/test_pipelines_zero_shot.py +++ b/tests/pipelines/test_pipelines_zero_shot.py @@ -21,11 +21,22 @@ ZeroShotClassificationPipeline, pipeline, ) -from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow +from transformers.testing_utils import ( + is_pipeline_test, + is_torch_available, + nested_simplify, + require_tf, + require_torch, + slow, +) from .test_pipelines_common import ANY +if is_torch_available(): + import torch + + # These 2 model types require different inputs than those of the usual text models. 
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"} @@ -176,6 +187,48 @@ def test_small_model_pt(self): }, ) + @require_torch + def test_small_model_pt_fp16(self): + zero_shot_classifier = pipeline( + "zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + framework="pt", + torch_dtype=torch.float16, + ) + outputs = zero_shot_classifier( + "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] + ) + + self.assertEqual( + nested_simplify(outputs), + { + "sequence": "Who are you voting for in 2020?", + "labels": ["science", "public health", "politics"], + "scores": [0.333, 0.333, 0.333], + }, + ) + + @require_torch + def test_small_model_pt_bf16(self): + zero_shot_classifier = pipeline( + "zero-shot-classification", + model="sshleifer/tiny-distilbert-base-cased-distilled-squad", + framework="pt", + torch_dtype=torch.bfloat16, + ) + outputs = zero_shot_classifier( + "Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"] + ) + + self.assertEqual( + nested_simplify(outputs), + { + "sequence": "Who are you voting for in 2020?", + "labels": ["science", "public health", "politics"], + "scores": [0.333, 0.333, 0.333], + }, + ) + @require_tf def test_small_model_tf(self): zero_shot_classifier = pipeline( From c3ac276221712f8e48e3f67e99ce7791c4977173 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 26 Jun 2024 07:57:25 -0400 Subject: [PATCH 02/14] fix parameter for ASR --- src/transformers/pipelines/automatic_speech_recognition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 47a438f9d6c1..f54eecc6b743 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -559,9 +559,9 @@ def postprocess( for outputs in model_outputs: if self.framework == "pt": # To enable using fp16 and bf16 - outputs = outputs[key].float().numpy() + items = outputs[key].float().numpy() else: - outputs = outputs[key].numpy() + items = outputs[key].numpy() stride = outputs.get("stride", None) if stride is not None and self.type in {"ctc", "ctc_with_lm"}: total_n, left, right = stride From 9105ad468eb0bce79a30f32ce2838d43f6083ca7 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 26 Jun 2024 08:00:28 -0400 Subject: [PATCH 03/14] reformat --- src/transformers/pipelines/image_classification.py | 1 - tests/pipelines/test_pipelines_token_classification.py | 8 ++++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index be5ae702abe6..d0c4f343e933 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -23,7 +23,6 @@ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): - from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES logger = logging.get_logger(__name__) diff --git a/tests/pipelines/test_pipelines_token_classification.py b/tests/pipelines/test_pipelines_token_classification.py index 63614479134b..97d46cd194cb 100644 --- a/tests/pipelines/test_pipelines_token_classification.py +++ b/tests/pipelines/test_pipelines_token_classification.py @@ -849,7 +849,9 @@ def test_small_model_pt(self): @require_torch def 
test_small_model_pt_fp16(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" - token_classifier = pipeline(task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16) + token_classifier = pipeline( + task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16 + ) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), @@ -862,7 +864,9 @@ def test_small_model_pt_fp16(self): @require_torch def test_small_model_pt_bf16(self): model_name = "hf-internal-testing/tiny-bert-for-token-classification" - token_classifier = pipeline(task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16) + token_classifier = pipeline( + task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16 + ) outputs = token_classifier("This is a test !") self.assertEqual( nested_simplify(outputs), From 3fab613297791a921f25dc2685ac076e513f77fb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 26 Jun 2024 08:07:33 -0400 Subject: [PATCH 04/14] fix asr bug --- src/transformers/pipelines/automatic_speech_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index f54eecc6b743..432e8eb82693 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -557,7 +557,7 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - if self.framework == "pt": + if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16): # To enable using fp16 and bf16 items = outputs[key].float().numpy() else: From 159bb2d7c6663de7e234580e4e54bab4f46cd959 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 26 Jun 2024 08:14:33 -0400 Subject: [PATCH 05/14] fix bug for zero-shot --- src/transformers/pipelines/zero_shot_classification.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 7a4f43e7ec8b..f4aee3341e30 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -239,7 +239,10 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) + if self.framework == "pt": + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) + else: + logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n From 9d12a016710e481739ce1d3e388cc64c978a2227 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 4 Jul 2024 22:23:18 -0400 Subject: [PATCH 06/14] add dtype check --- .../pipelines/automatic_speech_recognition.py | 2 +- .../pipelines/image_classification.py | 4 +++- .../pipelines/token_classification.py | 4 +++- .../pipelines/zero_shot_classification.py | 17 ++++++++++++----- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git 
a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 432e8eb82693..b8c346999ffb 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -557,7 +557,7 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16): + if self.framework == "pt" and outputs[key].dtype == torch.bfloat16: # To enable using fp16 and bf16 items = outputs[key].float().numpy() else: diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index d0c4f343e933..c52a33a40753 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -23,6 +23,8 @@ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): + import torch + from ..models.auto.modeling_auto import MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES logger = logging.get_logger(__name__) @@ -180,7 +182,7 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5): top_k = self.model.config.num_labels outputs = model_outputs["logits"][0] - if self.framework == "pt": + if self.framework == "pt" and outputs.dtype == torch.bfloat16: # To enable using fp16 and bf16 outputs = outputs.float().numpy() else: diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 12703267fe3e..1173b7f5a3c3 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -19,6 +19,8 @@ from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES if is_torch_available(): + import torch + from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES @@ -299,7 +301,7 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE ignore_labels = ["O"] all_entities = [] for model_outputs in all_outputs: - if self.framework == "pt": + if self.framework == "pt" and model_outputs["logits"][0].dtype == torch.bfloat16: # To enable using fp16 and bf16 logits = model_outputs["logits"][0].float().numpy() else: diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index f4aee3341e30..8d26c4ff10b5 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -4,10 +4,14 @@ import numpy as np from ..tokenization_utils import TruncationStrategy -from ..utils import add_end_docstrings, logging +from ..utils import add_end_docstrings, is_torch_available, logging from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args +if is_torch_available(): + import torch + + logger = logging.get_logger(__name__) @@ -239,10 +243,13 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - if self.framework == "pt": - logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) - else: - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + logits = [] + for 
output in model_outputs: + if self.framework == "pt" and output["logits"].dtype == torch.bfloat16: + logits.append(output["logits"].float().numpy()) + else: + logits.append(output["logits"].numpy()) + logits = np.concatenate(logits) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n From 9bb1d45da1a72d2ce92bc8f7c5c37b620d2dd190 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 4 Jul 2024 22:24:41 -0400 Subject: [PATCH 07/14] rm useless comments --- src/transformers/pipelines/automatic_speech_recognition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index b8c346999ffb..53bf052ee515 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -558,7 +558,6 @@ def postprocess( stride = None for outputs in model_outputs: if self.framework == "pt" and outputs[key].dtype == torch.bfloat16: - # To enable using fp16 and bf16 items = outputs[key].float().numpy() else: items = outputs[key].numpy() From a3a5be59a6658b8beee8f6eb6debcd4d3535fce9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 4 Jul 2024 22:44:33 -0400 Subject: [PATCH 08/14] add np.float16 check --- src/transformers/testing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 8dda057f1b9d..ac40bf99b686 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -2083,7 +2083,7 @@ def nested_simplify(obj, decimals=3): return nested_simplify(obj.numpy().tolist()) elif isinstance(obj, float): return round(obj, decimals) - elif isinstance(obj, (np.int32, np.float32)): + elif isinstance(obj, (np.int32, np.float32, np.float16)): return nested_simplify(obj.item(), decimals) else: raise Exception(f"Not supported: {type(obj)}") From 73e88794896502fc4bf7bf7267adfa211e49b771 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Tue, 9 Jul 2024 08:58:14 +0800 Subject: [PATCH 09/14] Update src/transformers/pipelines/image_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/pipelines/image_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index c52a33a40753..15f10339a9d0 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -183,7 +183,7 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5): outputs = model_outputs["logits"][0] if self.framework == "pt" and outputs.dtype == torch.bfloat16: - # To enable using fp16 and bf16 + # To enable using bf16 outputs = outputs.float().numpy() else: outputs = outputs.numpy() From 5480f415cd3d884c375be47b96605d055ba5f3d6 Mon Sep 17 00:00:00 2001 From: jiqing-feng <107918818+jiqing-feng@users.noreply.github.com> Date: Tue, 9 Jul 2024 08:58:22 +0800 Subject: [PATCH 10/14] Update src/transformers/pipelines/token_classification.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/pipelines/token_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index 
1173b7f5a3c3..a25d696cbf1d 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -302,7 +302,7 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE all_entities = [] for model_outputs in all_outputs: if self.framework == "pt" and model_outputs["logits"][0].dtype == torch.bfloat16: - # To enable using fp16 and bf16 + # To enable using bf16 logits = model_outputs["logits"][0].float().numpy() else: logits = model_outputs["logits"][0].numpy() From 76e91dc05063184bd0576727ecece83bba1499db Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 18 Sep 2024 05:35:43 -0400 Subject: [PATCH 11/14] fix comments --- src/transformers/pipelines/automatic_speech_recognition.py | 4 ++-- src/transformers/pipelines/image_classification.py | 5 ++--- src/transformers/pipelines/token_classification.py | 5 ++--- src/transformers/pipelines/zero_shot_classification.py | 4 ++-- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 9e3754477b69..31a6b49c104a 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -557,8 +557,8 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - if self.framework == "pt" and outputs[key].dtype == torch.bfloat16: - items = outputs[key].float().numpy() + if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16): + items = outputs[key].to(torch.float32).numpy() else: items = outputs[key].numpy() stride = outputs.get("stride", None) diff --git a/src/transformers/pipelines/image_classification.py b/src/transformers/pipelines/image_classification.py index e47accc68fdd..8aaa66e6c458 100644 --- a/src/transformers/pipelines/image_classification.py +++ b/src/transformers/pipelines/image_classification.py @@ -184,9 +184,8 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=5): top_k = self.model.config.num_labels outputs = model_outputs["logits"][0] - if self.framework == "pt" and outputs.dtype == torch.bfloat16: - # To enable using bf16 - outputs = outputs.float().numpy() + if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16): + outputs = outputs.to(torch.float32).numpy() else: outputs = outputs.numpy() diff --git a/src/transformers/pipelines/token_classification.py b/src/transformers/pipelines/token_classification.py index a25d696cbf1d..9256f2381484 100644 --- a/src/transformers/pipelines/token_classification.py +++ b/src/transformers/pipelines/token_classification.py @@ -301,9 +301,8 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE ignore_labels = ["O"] all_entities = [] for model_outputs in all_outputs: - if self.framework == "pt" and model_outputs["logits"][0].dtype == torch.bfloat16: - # To enable using bf16 - logits = model_outputs["logits"][0].float().numpy() + if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16): + logits = model_outputs["logits"][0].to(torch.float32).numpy() else: logits = model_outputs["logits"][0].numpy() diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 8d26c4ff10b5..f57634c97f00 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ 
b/src/transformers/pipelines/zero_shot_classification.py @@ -245,8 +245,8 @@ def postprocess(self, model_outputs, multi_label=False): sequences = [outputs["sequence"] for outputs in model_outputs] logits = [] for output in model_outputs: - if self.framework == "pt" and output["logits"].dtype == torch.bfloat16: - logits.append(output["logits"].float().numpy()) + if self.framework == "pt" and output["logits"].dtype in (torch.bfloat16, torch.float16): + logits.append(output["logits"].to(torch.float32).numpy()) else: logits.append(output["logits"].numpy()) logits = np.concatenate(logits) From 94cc8b3edf9a12782295ff6c627fa21f28e74802 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 18 Sep 2024 06:52:14 -0400 Subject: [PATCH 12/14] fix asr check --- src/transformers/pipelines/automatic_speech_recognition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/automatic_speech_recognition.py b/src/transformers/pipelines/automatic_speech_recognition.py index 31a6b49c104a..7230ecba2a1f 100644 --- a/src/transformers/pipelines/automatic_speech_recognition.py +++ b/src/transformers/pipelines/automatic_speech_recognition.py @@ -557,7 +557,7 @@ def postprocess( key = "logits" if self.type == "ctc_with_lm" else "tokens" stride = None for outputs in model_outputs: - if self.framework == "pt" and outputs.dtype in (torch.bfloat16, torch.float16): + if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16): items = outputs[key].to(torch.float32).numpy() else: items = outputs[key].numpy() From 0a2ba5b5b08bda33db615128001ed4f5d27ecb13 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 18 Sep 2024 16:07:15 +0100 Subject: [PATCH 13/14] make fixup --- src/transformers/pipelines/zero_shot_classification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 54e56ca198ff..9a30bed1e3ce 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -9,7 +9,7 @@ if is_torch_available(): - import torch + pass logger = logging.get_logger(__name__) From 98959f34bc060fa6e98e7d3de759bb7d8da2ad65 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 18 Sep 2024 16:16:47 +0100 Subject: [PATCH 14/14] No more need for is_torch_available() --- src/transformers/pipelines/zero_shot_classification.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 9a30bed1e3ce..f4aee3341e30 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -4,14 +4,10 @@ import numpy as np from ..tokenization_utils import TruncationStrategy -from ..utils import add_end_docstrings, is_torch_available, logging +from ..utils import add_end_docstrings, logging from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args -if is_torch_available(): - pass - - logger = logging.get_logger(__name__)
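Taken together, the series converges on one postprocessing rule for PyTorch pipelines: upcast torch.bfloat16 and torch.float16 tensors to float32 before calling .numpy(). The bfloat16 case is mandatory, since NumPy has no bfloat16 dtype and Tensor.numpy() raises a TypeError for it; float16 is upcast alongside it so the NumPy math that follows matches the fp32 pipelines. A sketch of that rule as a standalone helper -- hypothetical, since the patches inline the check in each pipeline's postprocess rather than sharing a function:

import torch

def tensor_to_numpy(tensor):
    # NumPy cannot represent bfloat16, and fp16 is upcast as well so that
    # the softmax/argmax done in NumPy afterwards behaves as in fp32.
    if tensor.dtype in (torch.bfloat16, torch.float16):
        return tensor.to(torch.float32).numpy()
    return tensor.numpy()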
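The user-facing behavior the new tests exercise, as a minimal sketch (the checkpoint name is the one used in the token-classification tests above; downloading it from the Hub is assumed):

import torch
from transformers import pipeline

# Load the pipeline with half-precision weights; the patched postprocess
# upcasts the logits to float32 before converting them to NumPy.
token_classifier = pipeline(
    task="token-classification",
    model="hf-internal-testing/tiny-bert-for-token-classification",
    framework="pt",
    torch_dtype=torch.bfloat16,
)
print(token_classifier("This is a test !"))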