[Pipeline] Add zero shot audio classification pipeline #21600
Merged: ArthurZucker merged 17 commits into huggingface:main from ArthurZucker:add-zero-shot-ac-pipeline on Feb 27, 2023.
Changes shown are from 8 of the 17 commits (all by ArthurZucker):

- ce0da1e add pipeline
- 16ee94b update init
- c3c2033 add zero shot to init
- 332c690 update inits and correct checkpoints
- d322929 update base to support input features
- 8970fa9 add tests
- 73ca9df Update src/transformers/pipelines/zero_shot_audio_classification.py
- e0dcc6a Update src/transformers/pipelines/zero_shot_audio_classification.py
- 410bfd7 update pieline code
- 1f40868 Merge branch 'main' of https://github.com/huggingface/transformers in…
- 222b3ca use tiny checkpoint
- aae29ff nits and expected value with tiny model
- d60bef1 style
- af21844 last nit on tests values
- 5915102 fix styling
- d76c6c2 fix collate fn that was casting t float
- 220f82d update
src/transformers/pipelines/zero_shot_audio_classification.py (152 additions, 0 deletions):

````python
from typing import Union

import numpy as np
import requests

from ..utils import (
    add_end_docstrings,
    is_torch_available,
    logging,
)
from .audio_classification import ffmpeg_read
from .base import PIPELINE_INIT_ARGS, ChunkPipeline


if is_torch_available():
    import torch


logger = logging.get_logger(__name__)


@add_end_docstrings(PIPELINE_INIT_ARGS)
class ZeroShotAudioClassificationPipeline(ChunkPipeline):
    """
    Zero shot audio classification pipeline using `ClapModel`. This pipeline predicts the class of an audio when you
    provide an audio and a set of `candidate_labels`.

    Example:
    ```python
    >>> from transformers import pipeline
    >>> from datasets import load_dataset

    >>> dataset = load_dataset("ashraq/esc50")
    >>> audio = next(iter(dataset["train"]["audio"]))["array"]
    >>> classifier = pipeline(task="zero-shot-audio-classification", model="laion-ai/clap-htsat-tiny")
    >>> classifier(
    ...     audio,
    ...     candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"],
    ... )
    [{'score': 0.999727189540863, 'label': 'Sound of a dog'}, {'score': 0.0002727957325987518, 'label': 'Sound of vaccum cleaner'}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). This audio
    classification pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-audio-classification"`. See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-audio-classification).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        if self.framework != "pt":
            raise ValueError(f"The {self.__class__} is only available in PyTorch.")
        # No specific FOR_XXX available yet

    def __call__(
        self,
        audios: Union[np.ndarray, bytes, str],
        **kwargs,
    ):
        """
        Assign labels to the audio(s) passed as inputs.

        Args:
            audios (`str`, `List[str]`, `np.array` or `List[np.array]`):
                The pipeline handles three types of inputs:
                - A string containing an http link pointing to an audio
                - A string containing a local path to an audio
                - An audio loaded in numpy
            candidate_labels (`List[str]`):
                The candidate labels for this audio.
            hypothesis_template (`str`, *optional*, defaults to `"This is a sound of {}."`):
                The sentence used in conjunction with `candidate_labels` to attempt the audio classification by
                replacing the placeholder with the candidate labels. The likelihood is then estimated using
                `logits_per_audio`.

        Return:
            A list of dictionaries containing one result per proposed label. The dictionaries contain the
            following keys:
            - **label** (`str`) -- One of the suggested `candidate_labels` identified by the model.
            - **score** (`float`) -- The score attributed by the model to that label (between 0 and 1).
        """
        return super().__call__(audios, **kwargs)

    def _sanitize_parameters(self, **kwargs):
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = kwargs["candidate_labels"]
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        return preprocess_params, {}, {}

    def preprocess(self, audio, candidate_labels=None, hypothesis_template="This is a sound of {}."):
        if isinstance(audio, str):
            if audio.startswith("http://") or audio.startswith("https://"):
                # We need to actually check for a real protocol, otherwise it's impossible to use a local file
                # like http_huggingface_co.png
                audio = requests.get(audio).content
            else:
                with open(audio, "rb") as f:
                    audio = f.read()

        if isinstance(audio, bytes):
            audio = ffmpeg_read(audio, self.feature_extractor.sampling_rate)

        if not isinstance(audio, np.ndarray):
            raise ValueError("We expect a numpy ndarray as input")
        if len(audio.shape) != 1:
            raise ValueError("We expect a single channel audio input for ZeroShotAudioClassificationPipeline")

        n = len(candidate_labels)
        for i, candidate_label in enumerate(candidate_labels):
            audios = self.feature_extractor(
                audio, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="pt"
            )
            sequence = hypothesis_template.format(candidate_label)
            inputs = self.tokenizer(sequence, return_tensors=self.framework)
            inputs["input_features"] = audios.input_features
            yield {"is_last": i == n - 1, "candidate_label": candidate_label, **inputs}

    def _forward(self, model_inputs):
        is_last = model_inputs.pop("is_last")
        candidate_label = model_inputs.pop("candidate_label")
        outputs = self.model(**model_inputs)

        # Clap does cross-product scoring by default, so we're only
        # interested in the results where audio and text are in the same
        # batch position.
        diag = torch.diagonal
        logits_per_audio = diag(outputs.logits_per_audio)

        model_outputs = {
            "is_last": is_last,
            "candidate_label": candidate_label,
            "logits_per_audio": logits_per_audio,
        }
        return model_outputs

    def postprocess(self, model_outputs):
        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
        if self.framework == "pt":
            logits = torch.cat([output["logits_per_audio"] for output in model_outputs])
            probs = logits.softmax(dim=0)
            scores = probs.tolist()
        else:
            raise ValueError("`tf` framework not supported.")

        result = [
            {"score": score, "label": candidate_label}
            for score, candidate_label in sorted(zip(scores, candidate_labels), key=lambda x: -x[0])
        ]
        return result
````
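To make the scoring step concrete, here is a small standalone sketch (not part of the PR; it uses NumPy and a made-up 3x3 similarity matrix) of why `_forward` takes the diagonal of CLAP's cross-product `logits_per_audio` and why `postprocess` applies a softmax over the per-label logits:

```python
import numpy as np

# Hypothetical cross-product similarity matrix from a CLAP-style model:
# rows are audio inputs, columns are text prompts. In the pipeline, row i
# is the same audio paired with the i-th candidate label, so only the
# matched (i, i) entries are meaningful.
logits_per_audio = np.array(
    [[5.0, 1.0, 0.5],
     [1.0, 4.0, 0.2],
     [0.3, 0.1, 2.0]]
)

# Keep only matched audio/text pairs, as `_forward` does with torch.diagonal.
matched = np.diagonal(logits_per_audio)  # [5.0, 4.0, 2.0]

# Numerically stable softmax over the candidate-label axis, as in `postprocess`.
exps = np.exp(matched - matched.max())
scores = exps / exps.sum()

print(scores.argmax())  # index of the best-matching candidate label: 0
```

Off-diagonal entries (audio i scored against prompt j) are discarded because each chunk only pairs the audio with a single candidate label; the softmax then turns the per-label logits into a probability distribution over the candidates.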
tests/pipelines/test_pipelines_zero_shot_audio_classification.py (99 additions, 0 deletions):

```python
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from datasets import load_dataset

from transformers.pipelines import pipeline
from transformers.testing_utils import nested_simplify, require_torch, slow

from .test_pipelines_common import PipelineTestCaseMeta


@require_torch
class ZeroShotAudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
    # Deactivating auto tests since we don't have a good MODEL_FOR_XX mapping,
    # and only CLAP would be there for now.
    # model_mapping = {CLAPConfig: CLAPModel}

    @require_torch
    def test_small_model_pt(self):
        audio_classifier = pipeline(
            task="zero-shot-audio-classification",
            model="hf-testing-internal/clap-htsat-unfused",
        )
        dataset = load_dataset("ashraq/esc50")
        audio = dataset["train"]["audio"][-1]["array"]
        output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
        self.assertEqual(
            nested_simplify(output),
            [
                {"score": 0.999, "label": "Sound of a dog"},
                {"score": 0.001, "label": "Sound of vaccum cleaner"},
            ],
        )

    @unittest.skip("No models are available in TF")
    def test_small_model_tf(self):
        pass

    @slow
    @require_torch
    def test_large_model_pt(self):
        audio_classifier = pipeline(
            task="zero-shot-audio-classification",
            model="laion/clap-htsat-unfused",
        )
        # This is an audio of a dog
        dataset = load_dataset("ashraq/esc50")
        audio = dataset["train"]["audio"][-1]["array"]
        output = audio_classifier(audio, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])

        self.assertEqual(
            nested_simplify(output),
            [
                {"score": 0.999, "label": "Sound of a dog"},
                {"score": 0.001, "label": "Sound of vaccum cleaner"},
            ],
        )

        output = audio_classifier([audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"])
        self.assertEqual(
            nested_simplify(output),
            [
                [
                    {"score": 0.999, "label": "Sound of a dog"},
                    {"score": 0.001, "label": "Sound of vaccum cleaner"},
                ],
            ]
            * 5,
        )
        output = audio_classifier(
            [audio] * 5, candidate_labels=["Sound of a dog", "Sound of vaccum cleaner"], batch_size=5
        )
        self.assertEqual(
            nested_simplify(output),
            [
                [
                    {"score": 0.999, "label": "Sound of a dog"},
                    {"score": 0.001, "label": "Sound of vaccum cleaner"},
                ],
            ]
            * 5,
        )

    @unittest.skip("No models are available in TF")
    def test_large_model_tf(self):
        pass
```
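The tests above exercise transformers' `ChunkPipeline` flow end to end. As a rough, framework-free illustration of the generator protocol the new pipeline relies on (all function names and the toy length-based "model" below are made up for this sketch, not transformers APIs): `preprocess` yields one model input per candidate label, tagging the final one with `is_last`, and postprocessing accumulates outputs until that flag is seen before normalizing and sorting.

```python
import math

def preprocess(audio, candidate_labels):
    # One model input per candidate label; the last one carries is_last=True.
    n = len(candidate_labels)
    for i, label in enumerate(candidate_labels):
        yield {"is_last": i == n - 1, "candidate_label": label, "audio": audio}

def fake_forward(model_input):
    # Stand-in for the CLAP forward pass: scores the pair with an arbitrary
    # toy rule (label length) instead of real audio/text similarity.
    return {
        "is_last": model_input["is_last"],
        "candidate_label": model_input["candidate_label"],
        "logit": float(len(model_input["candidate_label"])),
    }

def run(audio, candidate_labels):
    # Accumulate per-label outputs until is_last, then softmax and sort,
    # mirroring the pipeline's postprocess step.
    outputs = []
    for model_input in preprocess(audio, candidate_labels):
        outputs.append(fake_forward(model_input))
        if outputs[-1]["is_last"]:
            break
    logits = [o["logit"] for o in outputs]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return sorted(
        ({"score": e / total, "label": o["candidate_label"]} for e, o in zip(exps, outputs)),
        key=lambda d: -d["score"],
    )

result = run("dummy-audio", ["Sound of a dog", "Sound of vaccum cleaner"])
```

In the real pipeline, `fake_forward` corresponds to the CLAP model call whose logits come from `logits_per_audio`, and the chunk accumulation is handled by `ChunkPipeline` itself rather than an explicit loop.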
Review comments:
- "Failing test might be related to this! Will investigate"
- "Indeed, fixing this then will be good to go! 🔥"