Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,12 @@ Pipelines available for audio tasks include the following.
- __call__
- all

### ZeroShotAudioClassificationPipeline

[[autodoc]] ZeroShotAudioClassificationPipeline
- __call__
- all

## Computer vision

Pipelines available for computer vision tasks include the following.
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@
"TranslationPipeline",
"VideoClassificationPipeline",
"VisualQuestionAnsweringPipeline",
"ZeroShotAudioClassificationPipeline",
"ZeroShotClassificationPipeline",
"ZeroShotImageClassificationPipeline",
"ZeroShotObjectDetectionPipeline",
Expand Down Expand Up @@ -4007,6 +4008,7 @@
TranslationPipeline,
VideoClassificationPipeline,
VisualQuestionAnsweringPipeline,
ZeroShotAudioClassificationPipeline,
ZeroShotClassificationPipeline,
ZeroShotImageClassificationPipeline,
ZeroShotObjectDetectionPipeline,
Expand Down
13 changes: 13 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
)
from .video_classification import VideoClassificationPipeline
from .visual_question_answering import VisualQuestionAnsweringPipeline
from .zero_shot_audio_classification import ZeroShotAudioClassificationPipeline
from .zero_shot_classification import ZeroShotClassificationArgumentHandler, ZeroShotClassificationPipeline
from .zero_shot_image_classification import ZeroShotImageClassificationPipeline
from .zero_shot_object_detection import ZeroShotObjectDetectionPipeline
Expand Down Expand Up @@ -299,6 +300,17 @@
},
"type": "multimodal",
},
"zero-shot-audio-classification": {
"impl": ZeroShotAudioClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("laion/clap-htsat-fused", "f39917b"),
}
},
"type": "multimodal",
},
"conversational": {
"impl": ConversationalPipeline,
"tf": (TFAutoModelForSeq2SeqLM, TFAutoModelForCausalLM) if is_tf_available() else (),
Expand Down Expand Up @@ -534,6 +546,7 @@ def pipeline(
- `"visual-question-answering"`: will return a [`VisualQuestionAnsweringPipeline`].
- `"zero-shot-classification"`: will return a [`ZeroShotClassificationPipeline`].
- `"zero-shot-image-classification"`: will return a [`ZeroShotImageClassificationPipeline`].
- `"zero-shot-audio-classification"`: will return a [`ZeroShotAudioClassificationPipeline`].
- `"zero-shot-object-detection"`: will return a [`ZeroShotObjectDetectionPipeline`].

model (`str` or [`PreTrainedModel`] or [`TFPreTrainedModel`], *optional*):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _pad(items, key, padding_value, padding_side):
# Others include `attention_mask` etc...
shape = items[0][key].shape
dim = len(shape)
if key in ["pixel_values", "image"]:
if key in ["pixel_values", "image", "input_features"]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Failing test might be related ot this! Will investigate

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, fixing this then will be good to go! 🔥

# This is probable image so padding shouldn't be necessary
# B, C, H, W
return torch.cat([item[key] for item in items], dim=0)
Expand Down
153 changes: 153 additions & 0 deletions src/transformers/pipelines/zero_shot_audio_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from typing import Union

import numpy as np
import requests

from ..utils import (
add_end_docstrings,
is_torch_available,
logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, ChunkPipeline


if is_torch_available():
import torch


logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotAudioClassificationPipeline(ChunkPipeline):
"""
Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
provide an audio and a set of `candidate_labels`.
Example:
```python
>>> from transformers import pipeline
>>> from datasets import load_dataset
>>> dataset = load_dataset("ashraq/esc50")
>>> audio = next(iter(dataset["train"]["audio"]))["array"]
>>> classifier = pipeline(task="zero-shot-audio-classification", model="laion-ai/clap-hsat-tiny")
>>> classifier(
... audio,
... candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"],
... )
[{'score': 0.999727189540863, 'label': 'Sound of a dog'}, {'score': 0.0002727957325987518, 'label': 'Sound of vaccum cleaner'}]
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial) This audio
classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
`"zero-shot-audio-classification"`. See the list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification).
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)

if self.framework != "pt":
raise ValueError(f"The {self.__class__} is only available in PyTorch.")
# No specific FOR_XXX available yet
# self.check_model_type(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING)

def __call__(
self,
audios: Union[np.ndarray, bytes, str],
**kwargs,
):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this fits in one line.

"""
Args:
Assign labels to the audio(s) passed as inputs.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The args should be below the description.

audios (`str`, `List[str]`, `np.array` or `List[np.array]`):
The pipeline handles three types of inputs:
- A string containing a http link pointing to an audio
- A string containing a local path to an audio
- An audio loaded in numpy
candidate_labels (`List[str]`):
The candidate labels for this audio
hypothesis_template (`str`, *optional*, defaults to `"This is a sound of {}"`):
The sentence used in cunjunction with *candidate_labels* to attempt the audio classification by
replacing the placeholder with the candidate_labels. Then likelihood is estimated by using
logits_per_audio
Return:
A list of dictionaries containing result, one dictionary per proposed label. The dictionaries contain the
following keys:
- **label** (`str`) -- The label identified by the model. It is one of the suggested `candidate_label`.
- **score** (`float`) -- The score attributed by the model for that label (between 0 and 1).
"""
return super().__call__(audios, **kwargs)

def _sanitize_parameters(self, **kwargs):
preprocess_params = {}
if "candidate_labels" in kwargs:
preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
if "hypothesis_template" in kwargs:
preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

return preprocess_params, {}, {}

def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is a sound of {}."):
if isinstance(audio, str):
if audio.startswith("http://") or audio.startswith("https://"):
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
# like http_huggingface_co.png
audio = requests.get(audio).content
else:
with open(audio, "rb") as f:
audio = f.read()

if isinstance(audio, bytes):
audio = ffmpeg_read(audio, self.feature_extractor.sampling_rate)

if not isinstance(audio, np.ndarray):
raise ValueError("We expect a numpy ndarray as input")
if len(audio.shape) != 1:
raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")

n = len(candidate_labels)
for i, candidate_label in enumerate(candidate_labels):
audios = self.feature_extractor(
audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
)
sequence = hypothesis_template.format(candidate_label)
inputs = self.tokenizer(sequence, return_tensors=self.framework)
inputs["input_features"] = audios.input_features
yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}

def _forward(self, model_inputs):
is_last = model_inputs.pop("is_last")
candidate_label = model_inputs.pop("candidate_label")
outputs = self.model(**model_inputs)

# Clap does crossproduct scoring by default, so we're only
# interested in the results where audio and text and in the same
# batch position.
diag = torch.diagonal
logits_per_audio = diag(outputs.logits_per_audio)

model_outputs = {
"is_last": is_last,
"candidate_label": candidate_label,
"logits_per_audio": logits_per_audio,
}
return model_outputs

def postprocess(self, model_outputs):
candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
if self.framework == "pt":
logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
probs = logits.softmax(dim=0)
scores = probs.tolist()
else:
raise ValueError("`tf` framework not supported.")

result = [
{"score": score, "label": candidate_label}
for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
]
return result
99 changes: 99 additions & 0 deletions tests/pipelines/test_pipelines_zero_shot_audio_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers.pipelines import pipeline
from transformers.testing_utils import nested_simplify, require_torch, slow

from .test_pipelines_common import PipelineTestCaseMeta


@require_torch
class ZeroShotAudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
# Deactivating auto tests since we don't have a good MODEL_FOR_XX mapping,
# and only CLAP would be there for now.
# model_mapping = {CLAPConfig: CLAPModel}

@require_torch
def test_small_model_pt(self):
audio_classifier = pipeline(
task="zero-shot-audio-classification",
model="hf-testing-internal/clap-htsat-unfused",
)
dataset = load_dataset("ashraq/esc50")
audio = dataset["train"]["audio"][-1]["array"]
output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
self.assertEqual(
nested_simplify(output),
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
)

@unittest.skip("No models are available in TF")
def test_small_model_tf(self):
pass

@slow
@require_torch
def test_large_model_pt(self):
audio_classifier = pipeline(
task="zero-shot-audio-classification",
model="laion/clap-htsat-unfused",
)
# This is an audio of a dog
dataset = load_dataset("ashraq/esc50")
audio = dataset["train"]["audio"][-1]["array"]
output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])

self.assertEqual(
nested_simplify(output),
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
)

output = audio_classifier([audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
self.assertEqual(
nested_simplify(output),
[
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
]
* 5,
)
output = audio_classifier(
[audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"], batch_size=5
)
self.assertEqual(
nested_simplify(output),
[
[
{"score": 0.999, "label": "Sound of a dog"},
{"score": 0.001, "label": "Sound of vaccum cleaner"},
],
]
* 5,
)

@unittest.skip("No models are available in TF")
def test_large_model_tf(self):
pass