Skip to content

Commit 7137c75

Browse files
committed
feat(backends): add moonshine backend for faster transcription
Signed-off-by: Ettore Di Giacinto <[email protected]>
1 parent 5a9698b commit 7137c75

9 files changed

Lines changed: 337 additions & 0 deletions

File tree

.github/workflows/test-extra.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,22 @@ jobs:
247247
run: |
248248
make --jobs=5 --output-sync=target -C backend/python/coqui
249249
make --jobs=5 --output-sync=target -C backend/python/coqui test
250+
tests-moonshine:
  runs-on: ubuntu-latest
  steps:
    - name: Clone
      uses: actions/checkout@v6
      with:
        submodules: true
    - name: Dependencies
      run: |
        sudo apt-get update
        # -y keeps apt-get non-interactive so the CI step cannot hang on a prompt
        sudo apt-get install -y build-essential ffmpeg
        sudo apt-get install -y ca-certificates cmake curl patch python3-pip
        # Install UV
        curl -LsSf https://astral.sh/uv/install.sh | sh
        pip install --user --no-cache-dir grpcio-tools==1.64.1
    - name: Test moonshine
      run: |
        make --jobs=5 --output-sync=target -C backend/python/moonshine
        make --jobs=5 --output-sync=target -C backend/python/moonshine test

backend/python/moonshine/Makefile

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.DEFAULT_GOAL := install

.PHONY: install
install:
	bash install.sh

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__

# Declared phony like the other targets: without this, a file named
# "test" in the directory would silently skip running the suite.
.PHONY: test
test: install
	bash test.sh
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#!/usr/bin/env python3
2+
"""
3+
This is an extra gRPC server of LocalAI for Moonshine transcription
4+
"""
5+
from concurrent import futures
6+
import time
7+
import argparse
8+
import signal
9+
import sys
10+
import os
11+
import backend_pb2
12+
import backend_pb2_grpc
13+
import moonshine_onnx
14+
15+
import grpc
16+
17+
18+
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
19+
20+
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
21+
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
22+
23+
# Implement the BackendServicer class with the service methods
24+
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    Implements the LocalAI gRPC backend service for Moonshine ONNX
    audio transcription.
    """

    # Name of the moonshine model (e.g. "moonshine/tiny"); set by LoadModel.
    # Initialized here so AudioTranscription can fail with a clear message
    # instead of an AttributeError when LoadModel was never called.
    model_name = None

    def Health(self, request, context):
        """Health probe: always replies with the bytes b'OK'."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Store the requested model name for later transcription calls.

        Moonshine loads its ONNX weights lazily inside transcribe(), so
        this method only records request.Model (e.g. "moonshine/tiny").
        """
        try:
            print("Preparing models, please wait", file=sys.stderr)
            self.model_name = request.Model
            print(f"Model name set to: {self.model_name}", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def AudioTranscription(self, request, context):
        """
        Transcribe the audio file at request.dst with the loaded model.

        Returns a TranscriptResult whose text joins all strings returned by
        moonshine_onnx.transcribe(), with one segment per string. Moonshine
        provides no timing information, so segment start/end are always 0.
        On any failure an empty TranscriptResult is returned and the error
        is logged to stderr.
        """
        segments = []
        text = ""
        try:
            if self.model_name is None:
                raise RuntimeError("No model loaded: call LoadModel first")
            # moonshine_onnx.transcribe returns a list of strings
            transcriptions = moonshine_onnx.transcribe(request.dst, self.model_name)
            if isinstance(transcriptions, list):
                text = " ".join(transcriptions)
                # One segment per returned string; seg_id avoids shadowing
                # the builtin id().
                for seg_id, chunk in enumerate(transcriptions):
                    segments.append(backend_pb2.TranscriptSegment(
                        id=seg_id,
                        start=0,
                        end=0,
                        text=chunk
                    ))
            else:
                # Defensive: transcribe() is expected to return a list,
                # but degrade gracefully if it does not.
                text = str(transcriptions)
                segments.append(backend_pb2.TranscriptSegment(
                    id=0,
                    start=0,
                    end=0,
                    text=text
                ))
        except Exception as err:
            # Log and return an empty result rather than crashing the server.
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.TranscriptResult(segments=[], text="")

        return backend_pb2.TranscriptResult(segments=segments, text=text)
76+
77+
def serve(address):
    """
    Start the gRPC backend server on *address* and block until terminated.

    Installs SIGINT/SIGTERM handlers for a clean shutdown and keeps the
    main thread alive with day-long sleeps.
    """
    grpc_options = [
        ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
        ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
        ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.server(worker_pool, options=grpc_options)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Stop the server and exit on SIGINT/SIGTERM.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
104+
105+
if __name__ == "__main__":
    # Parse the bind address from the CLI and run the server until stopped.
    arg_parser = argparse.ArgumentParser(description="Run the gRPC server.")
    arg_parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    cli_args = arg_parser.parse_args()
    serve(cli_args.addr)
113+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
set -e

# Locate the shared backend helper library relative to this script,
# supporting both the in-tree and the standalone layout.
# Paths are quoted so directories containing spaces do not break sourcing.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

installRequirements
12+
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
set -e

# Locate the shared backend helper library relative to this script,
# supporting both the in-tree and the standalone layout.
# Paths are quoted so directories containing spaces do not break sourcing.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Generate the Python gRPC stubs (backend_pb2.py / backend_pb2_grpc.py)
# from the shared backend.proto two directories up.
python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto
12+
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
grpcio==1.71.0
2+
protobuf
3+
grpcio-tools
4+
useful-moonshine-onnx@git+https://[email protected]/moonshine-ai/moonshine.git#subdirectory=moonshine-onnx

backend/python/moonshine/run.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
# Fail fast like the sibling scripts (install.sh / test.sh).
set -e

# Locate the shared backend helper library relative to this script.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# Forward all CLI arguments; quoting preserves arguments containing spaces.
startBackend "$@"
10+

backend/python/moonshine/test.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""
2+
A test script to test the gRPC service for Moonshine transcription
3+
"""
4+
import unittest
5+
import subprocess
6+
import time
7+
import os
8+
import tempfile
9+
import shutil
10+
import backend_pb2
11+
import backend_pb2_grpc
12+
13+
import grpc
14+
15+
16+
class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the Moonshine transcription gRPC service.

    unittest invokes setUp/tearDown automatically around every test, so a
    fresh server subprocess is started and stopped per test. The tests must
    NOT call these hooks manually: doing so spawns a second server process
    and overwrites self.service, leaking the first subprocess.
    """

    def setUp(self):
        """Start the backend gRPC server and give it time to come up."""
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """Terminate the backend gRPC server subprocess and reap it."""
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """The Health endpoint should answer with b'OK'."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """LoadModel should succeed for the moonshine/tiny model."""
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_audio_transcription(self):
        """AudioTranscription should transcribe a known audio sample."""
        # Create a temporary directory for the downloaded audio file.
        temp_dir = tempfile.mkdtemp()
        audio_file = os.path.join(temp_dir, 'audio.wav')

        try:
            # Download the reference sample used by the assertion below.
            print(f"Downloading audio file to {audio_file}...")
            url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
            result = subprocess.run(
                ["wget", "-q", url, "-O", audio_file],
                capture_output=True,
                text=True
            )
            if result.returncode != 0:
                self.fail(f"Failed to download audio file: {result.stderr}")

            # Verify the file was downloaded
            if not os.path.exists(audio_file):
                self.fail(f"Audio file was not downloaded to {audio_file}")

            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # The model must be loaded before transcription.
                load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
                self.assertTrue(load_response.success)

                # Perform transcription
                transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
                transcript_response = stub.AudioTranscription(transcript_request)

                # Print the transcribed text for debugging
                print(f"Transcribed text: {transcript_response.text}")
                print(f"Number of segments: {len(transcript_response.segments)}")

                # Verify response structure; segments is a protobuf repeated
                # field (a sequence, not a list).
                self.assertIsNotNone(transcript_response)
                self.assertIsNotNone(transcript_response.text)
                self.assertIsNotNone(transcript_response.segments)
                self.assertGreaterEqual(len(transcript_response.segments), 0)

                # Verify the transcription contains the expected text
                expected_text = "This is the micro machine man presenting the most midget miniature"
                self.assertIn(
                    expected_text.lower(),
                    transcript_response.text.lower(),
                    f"Expected text '{expected_text}' not found in transcription: '{transcript_response.text}'"
                )

                # If we got segments, verify they have the expected structure
                if len(transcript_response.segments) > 0:
                    segment = transcript_response.segments[0]
                    self.assertIsNotNone(segment.text)
                    self.assertIsInstance(segment.id, int)
                else:
                    # Even if no segments, we should have text
                    self.assertIsNotNone(transcript_response.text)
                    self.assertGreater(len(transcript_response.text), 0)
        except Exception as err:
            print(err)
            self.fail("AudioTranscription service failed")
        finally:
            # Clean up the temporary directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
139+

backend/python/moonshine/test.sh

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
set -e

# Locate the shared backend helper library relative to this script,
# supporting both the in-tree and the standalone layout.
# Paths are quoted so directories containing spaces do not break sourcing.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

runUnittests
12+

0 commit comments

Comments
 (0)