diff --git a/speech_recognition/recognizers/whisper_api/base.py b/speech_recognition/recognizers/whisper_api/base.py
index c435ef59..7cc15f7f 100644
--- a/speech_recognition/recognizers/whisper_api/base.py
+++ b/speech_recognition/recognizers/whisper_api/base.py
@@ -1,7 +1,10 @@
+import logging
 from io import BytesIO
 
 from speech_recognition.audio import AudioData
 
+logger = logging.getLogger(__name__)
+
 
 class OpenAICompatibleRecognizer:
     def __init__(self, client) -> None:
@@ -16,7 +19,10 @@ def recognize(self, audio_data: "AudioData", model: str, **kwargs) -> str:
         wav_data = BytesIO(audio_data.get_wav_data())
         wav_data.name = "SpeechRecognition_audio.wav"
 
+        parameters = {"model": model, **kwargs}
+        logger.debug(parameters)
+
         transcript = self.client.audio.transcriptions.create(
-            file=wav_data, model=model, **kwargs
+            file=wav_data, **parameters
         )
         return transcript.text
diff --git a/speech_recognition/recognizers/whisper_api/openai.py b/speech_recognition/recognizers/whisper_api/openai.py
index bf407380..7f66b56c 100644
--- a/speech_recognition/recognizers/whisper_api/openai.py
+++ b/speech_recognition/recognizers/whisper_api/openai.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import logging
 from typing import Literal
 
 from typing_extensions import Unpack
@@ -65,16 +66,31 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("audio_file")
     parser.add_argument(
-        "--model", choices=get_args(WhisperModel), default="whisper-1"
+        "-m", "--model", choices=get_args(WhisperModel), default="whisper-1"
    )
     parser.add_argument("-l", "--language")
+    parser.add_argument("-p", "--prompt")
+    parser.add_argument("-v", "--verbose", action="store_true")
     args = parser.parse_args()
 
+    if args.verbose:
+        speech_recognition_logger = logging.getLogger("speech_recognition")
+        speech_recognition_logger.setLevel(logging.DEBUG)
+
+        console_handler = logging.StreamHandler()
+        console_formatter = logging.Formatter(
+            "%(asctime)s | %(levelname)s | %(name)s:%(funcName)s:%(lineno)d - %(message)s"
+        )
+        console_handler.setFormatter(console_formatter)
+        speech_recognition_logger.addHandler(console_handler)
+
     audio_data = sr.AudioData.from_file(args.audio_file)
+
+    recognize_args = {"model": args.model}
     if args.language:
-        transcription = recognize(
-            None, audio_data, model=args.model, language=args.language
-        )
-    else:
-        transcription = recognize(None, audio_data, model=args.model)
+        recognize_args["language"] = args.language
+    if args.prompt:
+        recognize_args["prompt"] = args.prompt
+
+    transcription = recognize(None, audio_data, **recognize_args)
     print(transcription)
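
For reference, a minimal usage sketch of the recognizer after this patch. The audio file name and prompt text are illustrative, and valid OpenAI credentials are assumed to be configured in the environment:

    import speech_recognition as sr
    from speech_recognition.recognizers.whisper_api.openai import recognize

    audio_data = sr.AudioData.from_file("speech.wav")  # illustrative file name
    # Extra keyword arguments such as prompt now flow through **kwargs into the
    # merged `parameters` dict and on to client.audio.transcriptions.create(),
    # where they are logged at DEBUG level before the request is sent.
    text = recognize(None, audio_data, model="whisper-1", language="en", prompt="technical jargon")
    print(text)

On the command line, the same request maps onto the new flags, e.g. "python -m speech_recognition.recognizers.whisper_api.openai speech.wav -m whisper-1 -l en -p 'technical jargon' -v" (module invocation assumed); -v attaches the DEBUG console handler so the merged request parameters are printed before the API call.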