Skip to content

Commit 49a0bef

Browse files
jiqing-feng, SunMarc, Rocketknight1
authored
enable low-precision pipeline (#31625)
* enable low-precision pipeline
* fix parameter for ASR
* reformat
* fix asr bug
* fix bug for zero-shot
* add dtype check
* rm useless comments
* add np.float16 check
* Update src/transformers/pipelines/image_classification.py
  Co-authored-by: Marc Sun <[email protected]>
* Update src/transformers/pipelines/token_classification.py
  Co-authored-by: Marc Sun <[email protected]>
* fix comments
* fix asr check
* make fixup
* No more need for is_torch_available()

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Matt <[email protected]>
Co-authored-by: Matt <[email protected]>
1 parent 7b2b536 commit 49a0bef

File tree

6 files changed

+143
-4
lines changed

6 files changed

+143
-4
lines changed

src/transformers/pipelines/automatic_speech_recognition.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -565,7 +565,10 @@ def postprocess(
565565
key = "logits" if self.type == "ctc_with_lm" else "tokens"
566566
stride = None
567567
for outputs in model_outputs:
568-
items = outputs[key].numpy()
568+
if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
569+
items = outputs[key].to(torch.float32).numpy()
570+
else:
571+
items = outputs[key].numpy()
569572
stride = outputs.get("stride", None)
570573
if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
571574
total_n, left, right = stride

src/transformers/pipelines/token_classification.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919

2020
from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
2121
if is_torch_available():
22+
import torch
23+
2224
from ..models.auto.modeling_auto import MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
2325

2426

@@ -299,7 +301,11 @@ def postprocess(self, all_outputs, aggregation_strategy=AggregationStrategy.NONE
299301
ignore_labels = ["O"]
300302
all_entities = []
301303
for model_outputs in all_outputs:
302-
logits = model_outputs["logits"][0].numpy()
304+
if self.framework == "pt" and model_outputs["logits"][0].dtype in (torch.bfloat16, torch.float16):
305+
logits = model_outputs["logits"][0].to(torch.float32).numpy()
306+
else:
307+
logits = model_outputs["logits"][0].numpy()
308+
303309
sentence = all_outputs[0]["sentence"]
304310
input_ids = model_outputs["input_ids"][0]
305311
offset_mapping = (

src/transformers/testing_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2143,7 +2143,7 @@ def nested_simplify(obj, decimals=3):
21432143
return nested_simplify(obj.numpy().tolist())
21442144
elif isinstance(obj, float):
21452145
return round(obj, decimals)
2146-
elif isinstance(obj, (np.int32, np.float32)):
2146+
elif isinstance(obj, (np.int32, np.float32, np.float16)):
21472147
return nested_simplify(obj.item(), decimals)
21482148
else:
21492149
raise Exception(f"Not supported: {type(obj)}")

tests/pipelines/test_pipelines_automatic_speech_recognition.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -167,6 +167,48 @@ def test_small_model_pt(self):
167167
):
168168
_ = speech_recognizer(waveform, return_timestamps="char")
169169

170+
@require_torch
171+
def test_small_model_pt_fp16(self):
172+
speech_recognizer = pipeline(
173+
task="automatic-speech-recognition",
174+
model="facebook/s2t-small-mustc-en-fr-st",
175+
tokenizer="facebook/s2t-small-mustc-en-fr-st",
176+
framework="pt",
177+
torch_dtype=torch.float16,
178+
)
179+
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
180+
output = speech_recognizer(waveform)
181+
self.assertEqual(output, {"text": "(Applaudissements)"})
182+
output = speech_recognizer(waveform, chunk_length_s=10)
183+
self.assertEqual(output, {"text": "(Applaudissements)"})
184+
185+
# Non CTC models cannot use return_timestamps
186+
with self.assertRaisesRegex(
187+
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
188+
):
189+
_ = speech_recognizer(waveform, return_timestamps="char")
190+
191+
@require_torch
192+
def test_small_model_pt_bf16(self):
193+
speech_recognizer = pipeline(
194+
task="automatic-speech-recognition",
195+
model="facebook/s2t-small-mustc-en-fr-st",
196+
tokenizer="facebook/s2t-small-mustc-en-fr-st",
197+
framework="pt",
198+
torch_dtype=torch.bfloat16,
199+
)
200+
waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
201+
output = speech_recognizer(waveform)
202+
self.assertEqual(output, {"text": "(Applaudissements)"})
203+
output = speech_recognizer(waveform, chunk_length_s=10)
204+
self.assertEqual(output, {"text": "(Applaudissements)"})
205+
206+
# Non CTC models cannot use return_timestamps
207+
with self.assertRaisesRegex(
208+
ValueError, "^We cannot return_timestamps yet on non-CTC models apart from Whisper!$"
209+
):
210+
_ = speech_recognizer(waveform, return_timestamps="char")
211+
170212
@slow
171213
@require_torch_accelerator
172214
def test_whisper_fp16(self):

tests/pipelines/test_pipelines_token_classification.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
from transformers.pipelines import AggregationStrategy, TokenClassificationArgumentHandler
2828
from transformers.testing_utils import (
2929
is_pipeline_test,
30+
is_torch_available,
3031
nested_simplify,
3132
require_tf,
3233
require_torch,
@@ -38,6 +39,10 @@
3839
from .test_pipelines_common import ANY
3940

4041

42+
if is_torch_available():
43+
import torch
44+
45+
4146
VALID_INPUTS = ["A simple string", ["list of strings", "A simple string that is quite a bit longer"]]
4247

4348
# These 2 model types require different inputs than those of the usual text models.
@@ -841,6 +846,36 @@ def test_small_model_pt(self):
841846
],
842847
)
843848

849+
@require_torch
850+
def test_small_model_pt_fp16(self):
851+
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
852+
token_classifier = pipeline(
853+
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.float16
854+
)
855+
outputs = token_classifier("This is a test !")
856+
self.assertEqual(
857+
nested_simplify(outputs),
858+
[
859+
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
860+
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
861+
],
862+
)
863+
864+
@require_torch
865+
def test_small_model_pt_bf16(self):
866+
model_name = "hf-internal-testing/tiny-bert-for-token-classification"
867+
token_classifier = pipeline(
868+
task="token-classification", model=model_name, framework="pt", torch_dtype=torch.bfloat16
869+
)
870+
outputs = token_classifier("This is a test !")
871+
self.assertEqual(
872+
nested_simplify(outputs),
873+
[
874+
{"entity": "I-MISC", "score": 0.115, "index": 1, "word": "this", "start": 0, "end": 4},
875+
{"entity": "I-MISC", "score": 0.115, "index": 2, "word": "is", "start": 5, "end": 7},
876+
],
877+
)
878+
844879
@require_torch
845880
def test_pt_ignore_subwords_slow_tokenizer_raises(self):
846881
model_name = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"

tests/pipelines/test_pipelines_zero_shot.py

Lines changed: 54 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,11 +21,22 @@
2121
ZeroShotClassificationPipeline,
2222
pipeline,
2323
)
24-
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
24+
from transformers.testing_utils import (
25+
is_pipeline_test,
26+
is_torch_available,
27+
nested_simplify,
28+
require_tf,
29+
require_torch,
30+
slow,
31+
)
2532

2633
from .test_pipelines_common import ANY
2734

2835

36+
if is_torch_available():
37+
import torch
38+
39+
2940
# These 2 model types require different inputs than those of the usual text models.
3041
_TO_SKIP = {"LayoutLMv2Config", "LayoutLMv3Config"}
3142

@@ -176,6 +187,48 @@ def test_small_model_pt(self):
176187
},
177188
)
178189

190+
@require_torch
191+
def test_small_model_pt_fp16(self):
192+
zero_shot_classifier = pipeline(
193+
"zero-shot-classification",
194+
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
195+
framework="pt",
196+
torch_dtype=torch.float16,
197+
)
198+
outputs = zero_shot_classifier(
199+
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
200+
)
201+
202+
self.assertEqual(
203+
nested_simplify(outputs),
204+
{
205+
"sequence": "Who are you voting for in 2020?",
206+
"labels": ["science", "public health", "politics"],
207+
"scores": [0.333, 0.333, 0.333],
208+
},
209+
)
210+
211+
@require_torch
212+
def test_small_model_pt_bf16(self):
213+
zero_shot_classifier = pipeline(
214+
"zero-shot-classification",
215+
model="sshleifer/tiny-distilbert-base-cased-distilled-squad",
216+
framework="pt",
217+
torch_dtype=torch.bfloat16,
218+
)
219+
outputs = zero_shot_classifier(
220+
"Who are you voting for in 2020?", candidate_labels=["politics", "public health", "science"]
221+
)
222+
223+
self.assertEqual(
224+
nested_simplify(outputs),
225+
{
226+
"sequence": "Who are you voting for in 2020?",
227+
"labels": ["science", "public health", "politics"],
228+
"scores": [0.333, 0.333, 0.333],
229+
},
230+
)
231+
179232
@require_tf
180233
def test_small_model_tf(self):
181234
zero_shot_classifier = pipeline(

0 commit comments

Comments
 (0)