Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions pkg/docker/cog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import string
import random
import signal
import requests
from io import BytesIO
Expand Down Expand Up @@ -222,18 +220,19 @@ def __init__(
redis_host: str,
redis_port: int,
input_queue: str,
consumer_id: str,
upload_url: str,
redis_db: int = 0,
):
self.model = model
self.redis_host = redis_host
self.redis_port = redis_port
self.input_queue = input_queue
self.consumer_id = consumer_id
self.upload_url = upload_url
self.redis_db = redis_db
# TODO: respect max_processing_time in message handling
self.max_processing_time = 10 * 60 # timeout after 10 minutes
self.redis_consumer_id = random_string()
self.redis = redis.Redis(
host=self.redis_host, port=self.redis_port, db=self.redis_db
)
Expand All @@ -253,7 +252,7 @@ def receive_message(self):
"XAUTOCLAIM",
self.input_queue,
self.input_queue,
self.redis_consumer_id,
self.consumer_id,
str(self.max_processing_time * 1000),
"0-0",
"COUNT",
Expand All @@ -268,7 +267,7 @@ def receive_message(self):
# if no old messages exist, get message from main queue
raw_messages = self.redis.xreadgroup(
groupname=self.input_queue,
consumername=self.redis_consumer_id,
consumername=self.consumer_id,
streams={self.input_queue: ">"},
count=1,
block=1000,
Expand Down Expand Up @@ -299,11 +298,13 @@ def start(self):
try:
self.handle_message(response_queue, message, cleanup_functions)
self.redis.xack(self.input_queue, self.input_queue, message_id)
self.redis.xdel(self.input_queue, message_id) # xdel to be able to get stream size
except Exception as e:
tb = traceback.format_exc()
sys.stderr.write(f"Failed to handle message: {tb}\n")
self.push_error(response_queue, e)
self.redis.xack(self.input_queue, self.input_queue, message_id)
self.redis.xdel(self.input_queue, message_id)
finally:
for cleanup_function in cleanup_functions:
try:
Expand Down Expand Up @@ -557,7 +558,3 @@ def _abort400(message):
resp = jsonify({"message": message})
resp.status_code = 400
return resp


def random_string(length=20):
return "".join(random.choices(string.ascii_uppercase + string.digits, k=length))
2 changes: 1 addition & 1 deletion pkg/docker/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ import os
os.chdir("` + g.getWorkdir() + `")
sys.path.append("` + g.getWorkdir() + `")
from ` + module + ` import ` + class + `
cog.RedisQueueWorker(` + class + `(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()`
cog.RedisQueueWorker(` + class + `(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4], consumer_id=sys.argv[5]).start()`
scriptString := strings.ReplaceAll(script, "\n", "\\n")
return `
RUN echo '` + scriptString + `' > ` + scriptPath + `
Expand Down
2 changes: 1 addition & 1 deletion pkg/docker/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,6 @@ RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nimport os\nos.chdir("/c
RUN chmod +x /usr/bin/cog-http-server
RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nimport os\nos.chdir("/code")\nsys.path.append("/code")\nfrom infer import Model\ncog.AIPlatformPredictionServer(Model()).start_server()' > /usr/bin/cog-ai-platform-prediction-server
RUN chmod +x /usr/bin/cog-ai-platform-prediction-server
RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nimport os\nos.chdir("/code")\nsys.path.append("/code")\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()' > /usr/bin/cog-redis-queue-worker
RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nimport os\nos.chdir("/code")\nsys.path.append("/code")\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4], consumer_id=sys.argv[5]).start()' > /usr/bin/cog-redis-queue-worker
RUN chmod +x /usr/bin/cog-redis-queue-worker`
}