Skip to content

Commit 3d614cd

Browse files
authored
Merge pull request #44 from NavodPeiris/dev
added missing cuda support for custom and hf models
2 parents 4085aa7 + eaf195b commit 3d614cd

3 files changed

Lines changed: 12 additions & 7 deletions

File tree

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setup(
77
name="speechlib",
8-
version="1.1.4",
8+
version="1.1.5",
99
description="speechlib is a library that can do speaker diarization, transcription and speaker recognition on an audio file to create transcripts with actual speaker names. This library also contain audio preprocessor functions.",
1010
packages=find_packages(),
1111
long_description=long_description,

setup_instruction.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ for publishing:
99
pip install twine
1010

1111
to install locally for testing:
12-
pip install dist/speechlib-1.1.4-py3-none-any.whl
12+
pip install dist/speechlib-1.1.5-py3-none-any.whl
1313

1414
finally run:
1515
twine upload dist/*

speechlib/transcribe.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def transcribe(file, language, model_size, model_type, quantization, custom_mode
5353
print("model fodler: ", model_folder)
5454
try:
5555
if torch.cuda.is_available():
56-
model = whisper.load_model(custom_model_path, download_root=model_folder)
56+
model = whisper.load_model(custom_model_path, download_root=model_folder, device="cuda")
5757
result = model.transcribe(file, language=language, fp16=True)
5858
res = result["text"]
5959
else:
60-
model = whisper.load_model(custom_model_path, download_root=model_folder)
60+
model = whisper.load_model(custom_model_path, download_root=model_folder, device="cpu")
6161
result = model.transcribe(file, language=language, fp16=False)
6262
res = result["text"]
6363

@@ -66,9 +66,14 @@ def transcribe(file, language, model_size, model_type, quantization, custom_mode
6666
raise Exception(f"an error occured while transcribing: {err}")
6767
elif model_type == "huggingface":
6868
try:
69-
pipe = pipeline("automatic-speech-recognition", model=hf_model_path)
70-
result = pipe(file)
71-
res = result['text']
69+
if torch.cuda.is_available():
70+
pipe = pipeline("automatic-speech-recognition", model=hf_model_path, device="cuda")
71+
result = pipe(file)
72+
res = result['text']
73+
else:
74+
pipe = pipeline("automatic-speech-recognition", model=hf_model_path, device="cpu")
75+
result = pipe(file)
76+
res = result['text']
7277
return res
7378
except Exception as err:
7479
raise Exception(f"an error occured while transcribing: {err}")

0 commit comments

Comments
 (0)