Skip to content

Commit 93202ec

Browse files
authored
Merge pull request #39 from predictionguard/jacob/audio_params
Adding new audio params
2 parents 7f97c57 + 426d80c commit 93202ec

File tree

3 files changed

+51
-22
lines changed

3 files changed

+51
-22
lines changed

predictionguard/src/audio.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,31 @@ def __init__(self, api_key, url):
4242
def create(
4343
self,
4444
model: str,
45-
file: str
45+
file: str,
46+
language: Optional[str] = "auto",
47+
temperature: Optional[float] = 0.0,
48+
prompt: Optional[str] = "",
4649
) -> Dict[str, Any]:
4750
"""
4851
Creates a audio transcription request to the Prediction Guard /audio/transcriptions API
4952
5053
:param model: The model to use
5154
:param file: Audio file to be transcribed
55+
:param language: The language of the audio file
56+
:param temperature: The temperature parameter for model transcription
57+
:param prompt: A prompt to assist in transcription styling
5258
:result: A dictionary containing the transcribed text.
5359
"""
5460

5561
# Create a list of tuples, each containing all the parameters for
5662
# a call to _transcribe_audio
57-
args = (model, file)
63+
args = (model, file, language, temperature, prompt)
5864

5965
# Run _transcribe_audio
6066
choices = self._transcribe_audio(*args)
6167
return choices
6268

63-
def _transcribe_audio(self, model, file):
69+
def _transcribe_audio(self, model, file, language, temperature, prompt):
6470
"""
6571
Function to transcribe an audio file.
6672
"""
@@ -72,7 +78,12 @@ def _transcribe_audio(self, model, file):
7278

7379
with open(file, "rb") as audio_file:
7480
files = {"file": (file, audio_file, "audio/wav")}
75-
data = {"model": model}
81+
data = {
82+
"model": model,
83+
"language": language,
84+
"temperature": temperature,
85+
"prompt": prompt,
86+
}
7687

7788
response = requests.request(
7889
"POST", self.url + "/audio/transcriptions", headers=headers, files=files, data=data

predictionguard/src/chat.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def create(
9595
str, Dict[
9696
str, Dict[str, str]
9797
]
98-
]] = "none",
98+
]] = None,
9999
tools: Optional[List[Dict[str, Union[str, Dict[str, str]]]]] = None,
100100
top_p: Optional[float] = 0.99,
101101
top_k: Optional[float] = 50,
@@ -296,22 +296,40 @@ def stream_generator(url, headers, payload, stream):
296296
elif entry["type"] == "text":
297297
continue
298298

299-
payload_dict = {
300-
"model": model,
301-
"messages": messages,
302-
"frequency_penalty": frequency_penalty,
303-
"logit_bias": logit_bias,
304-
"max_completion_tokens": max_completion_tokens,
305-
"parallel_tool_calls": parallel_tool_calls,
306-
"presence_penalty": presence_penalty,
307-
"stop": stop,
308-
"stream": stream,
309-
"temperature": temperature,
310-
"tool_choice": tool_choice,
311-
"tools": tools,
312-
"top_p": top_p,
313-
"top_k": top_k,
314-
}
299+
# TODO: Remove `tool_choice` check when null value available in API
300+
if tool_choice is None:
301+
payload_dict = {
302+
"model": model,
303+
"messages": messages,
304+
"frequency_penalty": frequency_penalty,
305+
"logit_bias": logit_bias,
306+
"max_completion_tokens": max_completion_tokens,
307+
"parallel_tool_calls": parallel_tool_calls,
308+
"presence_penalty": presence_penalty,
309+
"stop": stop,
310+
"stream": stream,
311+
"temperature": temperature,
312+
"tools": tools,
313+
"top_p": top_p,
314+
"top_k": top_k,
315+
}
316+
else:
317+
payload_dict = {
318+
"model": model,
319+
"messages": messages,
320+
"frequency_penalty": frequency_penalty,
321+
"logit_bias": logit_bias,
322+
"max_completion_tokens": max_completion_tokens,
323+
"parallel_tool_calls": parallel_tool_calls,
324+
"presence_penalty": presence_penalty,
325+
"stop": stop,
326+
"stream": stream,
327+
"temperature": temperature,
328+
"tool_choice": tool_choice,
329+
"tools": tools,
330+
"top_p": top_p,
331+
"top_k": top_k,
332+
}
315333

316334
if input:
317335
payload_dict["input"] = input

predictionguard/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# Setting the package version
2-
__version__ = "2.8.0"
2+
__version__ = "2.8.1"

0 commit comments

Comments
 (0)