diff --git a/go.mod b/go.mod
index eaae362254..490769e1dc 100644
--- a/go.mod
+++ b/go.mod
@@ -42,7 +42,7 @@ require (
 	google.golang.org/grpc v1.35.0 // indirect
 	gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
 	gopkg.in/yaml.v2 v2.4.0
-	gotest.tools/gotestsum v1.6.3 // indirect
+	gotest.tools/gotestsum v1.6.4 // indirect
 	gotest.tools/v3 v3.0.3 // indirect
 )
diff --git a/go.sum b/go.sum
index 856e6dc267..9093c6253f 100644
--- a/go.sum
+++ b/go.sum
@@ -479,6 +479,8 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gotest.tools/gotestsum v1.6.3 h1:E3wOF4wmxKA19BB5wTY7t0L1m+QNARtDcBX4yqG6DEc=
 gotest.tools/gotestsum v1.6.3/go.mod h1:fTR9ZhxC/TLAAx2/WMk/m3TkMB9eEI89gdEzhiRVJT8=
+gotest.tools/gotestsum v1.6.4 h1:HFkapG0hK/HWiOxWS78SbR/JK5EpbH8hFzUuCvvfbfQ=
+gotest.tools/gotestsum v1.6.4/go.mod h1:fTR9ZhxC/TLAAx2/WMk/m3TkMB9eEI89gdEzhiRVJT8=
 gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk=
 gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0=
 gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8=
diff --git a/pkg/cli/server.go b/pkg/cli/server.go
index f6575f2158..89af8c2de2 100644
--- a/pkg/cli/server.go
+++ b/pkg/cli/server.go
@@ -9,7 +9,6 @@ import (
 	"github.com/spf13/cobra"
 
 	"github.com/replicate/cog/pkg/console"
-	"github.com/replicate/cog/pkg/database"
 	"github.com/replicate/cog/pkg/docker"
 	"github.com/replicate/cog/pkg/server"
diff --git a/pkg/docker/cog.py b/pkg/docker/cog.py
index b7e829701d..4afb8cb72e 100644
--- a/pkg/docker/cog.py
+++ b/pkg/docker/cog.py
@@ -1,3 +1,7 @@
+import signal
+import requests
+from io import BytesIO
+import json
 import time
 import sys
 from contextlib import contextmanager
@@ -15,8 +19,12 @@
 from flask import Flask, send_file, request, jsonify, abort, Response
 from werkzeug.datastructures import FileStorage
+import redis
 
 # TODO(andreas): handle directory input
+# TODO(andreas): handle List[Dict[str, int]], etc.
+# TODO(andreas): model-level documentation
+
 _VALID_INPUT_TYPES = frozenset([str, int, float, bool, Path])
 _UNSPECIFIED = object()
@@ -26,8 +34,6 @@ class InputValidationError(Exception):
 
 
 class Model(ABC):
-    app: Flask
-
     @abstractmethod
     def setup(self):
         pass
@@ -36,14 +42,14 @@ def setup(self):
     def run(self, **kwargs):
         pass
 
-    def cli_run(self):
-        self.setup()
-        result = self.run()
-        print(result)
+
+class HTTPServer:
+    def __init__(self, model: Model):
+        self.model = model
 
     def make_app(self) -> Flask:
         start_time = time.time()
-        self.setup()
+        self.model.setup()
         app = Flask(__name__)
         setup_time = time.time() - start_time
@@ -63,17 +69,17 @@ def handle_request():
                 )
                 raw_inputs[key] = val
 
-            if hasattr(self.run, "_inputs"):
+            if hasattr(self.model.run, "_inputs"):
                 try:
-                    inputs = self.validate_and_convert_inputs(
-                        raw_inputs, cleanup_functions
+                    inputs = validate_and_convert_inputs(
+                        self.model, raw_inputs, cleanup_functions
                     )
                 except InputValidationError as e:
                     return _abort400(str(e))
             else:
                 inputs = raw_inputs
 
-            result = self.run(**inputs)
+            result = self.model.run(**inputs)
             run_time = time.time() - start_time
             return self.create_response(result, setup_time, run_time)
         finally:
@@ -90,8 +96,8 @@ def ping():
         @app.route("/help")
         def help():
             args = {}
-            if hasattr(self.run, "_inputs"):
-                input_specs = self.run._inputs
+            if hasattr(self.model.run, "_inputs"):
+                input_specs = self.model.run._inputs
                 for name, spec in input_specs.items():
                     arg = {
                         "type": _type_name(spec.type),
@@ -124,92 +130,302 @@ def create_response(self, result, setup_time, run_time):
         resp.headers["X-Run-Time"] = run_time
         return resp
 
-    def validate_and_convert_inputs(
-        self, raw_inputs: Dict[str, Any], cleanup_functions: List[Callable]
-    ) -> Dict[str, Any]:
-        input_specs = self.run._inputs
-        inputs = {}
-        for name, input_spec in input_specs.items():
-            if name in raw_inputs:
-                val = raw_inputs[name]
+
+class AIPlatformPredictionServer:
+    def __init__(self, model: Model):
+        sys.stderr.write(
+            "WARNING: AIPlatformPredictionServer is experimental, do not use this in production\n"
+        )
+        self.model = model
 
-                if input_spec.type == Path:
-                    if not isinstance(val, FileStorage):
-                        raise InputValidationError(
-                            f"Could not convert file input {name} to {_type_name(input_spec.type)}",
-                        )
-                    if val.filename is None:
-                        raise InputValidationError(
-                            f"No filename is provided for file input {name}"
+    def make_app(self) -> Flask:
+        self.model.setup()
+        app = Flask(__name__)
+
+        @app.route("/infer", methods=["POST"])
+        def handle_request():
+            cleanup_functions = []
+            try:
+                content = request.json
+                instances = content["instances"]
+                results = []
+                for instance in instances:
+                    try:
+                        validate_and_convert_inputs(
+                            self.model, instance, cleanup_functions
                         )
+                    except InputValidationError as e:
+                        return jsonify({"error": str(e)})
+                    results.append(self.model.run(**instance))
+                return jsonify(
+                    {
+                        "predictions": results,
+                    }
+                )
+            except Exception as e:
+                tb = traceback.format_exc()
+                return jsonify(
+                    {
+                        "error": tb,
+                    }
+                )
 
-                    temp_dir = tempfile.mkdtemp()
-                    cleanup_functions.append(lambda: shutil.rmtree(temp_dir))
+        @app.route("/ping")
+        def ping():
+            return "PONG"
 
-                    temp_path = os.path.join(temp_dir, val.filename)
-                    with open(temp_path, "wb") as f:
-                        f.write(val.stream.read())
-                    converted = Path(temp_path)
+        @app.route("/help")
+        def help():
+            args = {}
+            if hasattr(self.model.run, "_inputs"):
+                input_specs = self.model.run._inputs
+                for name, spec in input_specs.items():
+                    arg = {
+                        "type": _type_name(spec.type),
+                    }
+                    if spec.help:
+                        arg["help"] = spec.help
+                    if spec.default is not _UNSPECIFIED:
+                        arg["default"] = str(spec.default)  # TODO: don't string this
+                    if spec.min is not None:
+                        arg["min"] = str(spec.min)  # TODO: don't string this
+                    if spec.max is not None:
+                        arg["max"] = str(spec.max)  # TODO: don't string this
+                    args[name] = arg
+            return jsonify({"arguments": args})
 
-                elif input_spec.type == int:
-                    try:
-                        converted = int(val)
-                    except ValueError:
-                        raise InputValidationError(
-                            f"Could not convert {name}={val} to int"
-                        )
+        return app
 
-                elif input_spec.type == float:
-                    try:
-                        converted = float(val)
-                    except ValueError:
-                        raise InputValidationError(
-                            f"Could not convert {name}={val} to float"
-                        )
+    def start_server(self):
+        app = self.make_app()
+        app.run(host="0.0.0.0", port=5000)
 
-                elif input_spec.type == bool:
-                    if val not in [True, False]:
-                        raise InputValidationError(f"{name}={val} is not a boolean")
+    def create_response(self, result, setup_time, run_time):
+        if isinstance(result, Path):
+            resp = send_file(str(result))
+        elif isinstance(result, str):
+            resp = Response(result)
+        else:
+            resp = jsonify(result)
+        resp.headers["X-Setup-Time"] = setup_time
+        resp.headers["X-Run-Time"] = run_time
+        return resp
 
-                elif input_spec.type == str:
-                    if isinstance(val, FileStorage):
-                        raise InputValidationError(
-                            f"Could not convert file input {name} to str"
-                        )
-                    converted = val
 
-                else:
-                    raise TypeError(
-                        f"Internal error: Input type {input_spec} is not a valid input type"
-                    )
+# TODO: reliable queue
+class RedisQueueWorker:
+    def __init__(
+        self,
+        model: Model,
+        redis_host: str,
+        redis_port: int,
+        input_queue: str,
+        upload_url: str,
+        redis_db: int = 0,
+    ):
+        self.model = model
+        self.redis_host = redis_host
+        self.redis_port = redis_port
+        self.input_queue = input_queue
+        self.upload_url = upload_url
+        self.redis_db = redis_db
+        self.redis = redis.Redis(
+            host=self.redis_host, port=self.redis_port, db=self.redis_db
+        )
+        self.should_exit = False
+        sys.stderr.write(
+            f"Connected to Redis: {self.redis_host}:{self.redis_port} (db {self.redis_db})\n"
+        )
+
+    def signal_exit(self, signum, frame):
+        self.should_exit = True
+        sys.stderr.write("Caught SIGTERM, exiting...\n")
+
+    def start(self):
+        signal.signal(signal.SIGTERM, self.signal_exit)
+        self.model.setup()
+        while not self.should_exit:
+            try:
+                sys.stderr.write(f"Waiting for message on {self.input_queue}\n")
+                _, raw_message = self.redis.blpop([self.input_queue])
+                message = json.loads(raw_message)
+                message_id = message["id"]
+                response_queue = message["response_queue"]
+                sys.stderr.write(
+                    f"Received message {message_id} on {self.input_queue}\n"
+                )
+                cleanup_functions = []
+                try:
+                    self.handle_message(
+                        message_id, response_queue, message, cleanup_functions
+                    )
+                except Exception as e:
+                    tb = traceback.format_exc()
+                    sys.stderr.write(f"Failed to handle message: {tb}\n")
+                    self.push_error(response_queue, e)
+                finally:
+                    for cleanup_function in cleanup_functions:
+                        try:
+                            cleanup_function()
+                        except Exception as e:
+                            sys.stderr.write(f"Cleanup function caught error: {e}")
+            except Exception as e:
+                tb = traceback.format_exc()
+                sys.stderr.write(f"Failed to handle message: {tb}\n")
+
+    def handle_message(self, message_id, response_queue, message, cleanup_functions):
+        inputs = {}
+        raw_inputs = message["inputs"]
+        for k, v in raw_inputs.items():
+            if "value" in v and v["value"] != "":
+                inputs[k] = v["value"]
+            else:
+                file_url = v["file"]["url"]
+                sys.stderr.write(f"Downloading file from {file_url}\n")
+                value_bytes = self.download(file_url)
+                inputs[k] = FileStorage(
+                    stream=BytesIO(value_bytes), filename=v["file"]["name"]
+                )
+        try:
+            inputs = validate_and_convert_inputs(self.model, inputs, cleanup_functions)
+        except InputValidationError as e:
+            tb = traceback.format_exc()
+            sys.stderr.write(tb)
+            self.push_error(response_queue, e)
+            return
+
+        result = self.model.run(**inputs)
+        self.push_result(response_queue, result)
+
+    def download(self, url):
+        resp = requests.get(url)
+        resp.raise_for_status()
+        return resp.content
+
+    def push_error(self, response_queue, error):
+        message = json.dumps(
+            {
+                "error": str(error),
+            }
+        )
+        sys.stderr.write(f"Pushing error to {response_queue}\n")
+        self.redis.rpush(response_queue, message)
 
-                if _is_numeric_type(input_spec.type):
-                    if input_spec.max is not None and converted > input_spec.max:
-                        raise InputValidationError(
-                            f"Value {converted} is greater than the max value {input_spec.max}"
-                        )
-                    if input_spec.min is not None and converted < input_spec.min:
-                        raise InputValidationError(
-                            f"Value {converted} is less than the min value {input_spec.min}"
-                        )
+    def push_result(self, response_queue, result):
+        if isinstance(result, Path):
+            message = {
+                "file": {
+                    "url": self.upload_to_temp(result),
+                    "name": result.name,
+                }
+            }
+        elif isinstance(result, str):
+            message = {
+                "value": result,
+            }
+        else:
+            message = {
+                "value": json.dumps(result),
+            }
+
+        sys.stderr.write(f"Pushing successful result to {response_queue}\n")
+        self.redis.rpush(response_queue, json.dumps(message))
+
+    def upload_to_temp(self, path: Path) -> str:
+        sys.stderr.write(
+            f"Uploading {path.name} to temporary storage at {self.upload_url}\n"
+        )
+        resp = requests.put(
+            self.upload_url, files={"file": (path.name, path.open("rb"))}
+        )
+        resp.raise_for_status()
+        return resp.json()["url"]
+
+
+def validate_and_convert_inputs(
+    model: Model, raw_inputs: Dict[str, Any], cleanup_functions: List[Callable]
+) -> Dict[str, Any]:
+    input_specs = model.run._inputs
+    inputs = {}
+
+    for name, input_spec in input_specs.items():
+        if name in raw_inputs:
+            val = raw_inputs[name]
+
+            if input_spec.type == Path:
+                if not isinstance(val, FileStorage):
+                    raise InputValidationError(
+                        f"Could not convert file input {name} to {_type_name(input_spec.type)}",
+                    )
+                if val.filename is None:
+                    raise InputValidationError(
+                        f"No filename is provided for file input {name}"
+                    )
+
+                temp_dir = tempfile.mkdtemp()
+                cleanup_functions.append(lambda: shutil.rmtree(temp_dir))
+
+                temp_path = os.path.join(temp_dir, val.filename)
+                with open(temp_path, "wb") as f:
+                    f.write(val.stream.read())
+                converted = Path(temp_path)
+
+            elif input_spec.type == int:
+                try:
+                    converted = int(val)
+                except ValueError:
+                    raise InputValidationError(f"Could not convert {name}={val} to int")
+
+            elif input_spec.type == float:
+                try:
+                    converted = float(val)
+                except ValueError:
+                    raise InputValidationError(
+                        f"Could not convert {name}={val} to float"
+                    )
+
+            elif input_spec.type == bool:
+                if val not in [True, False]:
+                    raise InputValidationError(f"{name}={val} is not a boolean")
+
+            elif input_spec.type == str:
+                if isinstance(val, FileStorage):
+                    raise InputValidationError(
+                        f"Could not convert file input {name} to str"
+                    )
+                converted = val
 
             else:
-                if input_spec.default is not _UNSPECIFIED:
-                    converted = input_spec.default
-                else:
-                    raise InputValidationError(f"Missing expected argument: {name}")
-            inputs[name] = converted
-
-        expected_keys = set(self.run._inputs.keys())
-        raw_keys = set(raw_inputs.keys())
-        extraneous_keys = raw_keys - expected_keys
-        if extraneous_keys:
-            raise InputValidationError(
-                f"Extraneous input keys: {', '.join(extraneous_keys)}"
-            )
-
-        return inputs
+                raise TypeError(
+                    f"Internal error: Input type {input_spec} is not a valid input type"
+                )
+
+            if _is_numeric_type(input_spec.type):
+                if input_spec.max is not None and converted > input_spec.max:
+                    raise InputValidationError(
+                        f"Value {converted} is greater than the max value {input_spec.max}"
+                    )
+                if input_spec.min is not None and converted < input_spec.min:
+                    raise InputValidationError(
+                        f"Value {converted} is less than the min value {input_spec.min}"
+                    )
+
+        else:
+            if input_spec.default is not _UNSPECIFIED:
+                converted = input_spec.default
+            else:
+                raise InputValidationError(f"Missing expected argument: {name}")
+        inputs[name] = converted
+
+    expected_keys = set(input_specs.keys())
+    raw_keys = set(raw_inputs.keys())
+    extraneous_keys = raw_keys - expected_keys
+    if extraneous_keys:
+        raise InputValidationError(
+            f"Extraneous input keys: {', '.join(extraneous_keys)}"
+        )
+
+    return inputs
 
 
 @contextmanager
diff --git a/pkg/docker/generate.go b/pkg/docker/generate.go
index 18ebe061d8..1fca101da6 100644
--- a/pkg/docker/generate.go
+++ b/pkg/docker/generate.go
@@ -67,6 +67,7 @@ func (g *DockerfileGenerator) Generate() (string, error) {
 		g.installCog(),
 		g.preInstall(),
 		g.copyCode(),
+		g.installHelperScripts(),
 		g.workdir(),
 		g.postInstall(),
 		g.command(),
@@ -86,6 +87,7 @@ func (g *DockerfileGenerator) baseImage() (string, error) {
 func (g *DockerfileGenerator) preamble() string {
 	// TODO: other stuff
 	return `ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu`
 }
@@ -134,11 +136,51 @@ RUN apt-get update -q && apt-get install -qy --no-install-recommends \
 func (g *DockerfileGenerator) installCog() string {
 	cogLibB64 := base64.StdEncoding.EncodeToString(cogLibrary)
-	return g.sectionLabel(SectionInstallingCog) + fmt.Sprintf(`RUN pip install flask
+	return g.sectionLabel(SectionInstallingCog) +
+		fmt.Sprintf(`RUN pip install flask requests redis
 ENV PYTHONPATH=/usr/local/lib/cog
 RUN mkdir -p /usr/local/lib/cog && echo %s | base64 --decode > /usr/local/lib/cog/cog.py`, cogLibB64)
 }
 
+func (g *DockerfileGenerator) installHelperScripts() string {
+	return g.serverHelperScript("HTTPServer", "cog-http-server") +
+		g.serverHelperScript("AIPlatformPredictionServer", "cog-ai-platform-prediction-server") +
+		g.queueWorkerHelperScript()
+}
+
+func (g *DockerfileGenerator) serverHelperScript(serverClass string, filename string) string {
+	scriptPath := "/code/" + filename
+	name := g.Config.Model
+	parts := strings.Split(name, ".py:")
+	module := parts[0]
+	class := parts[1]
+	script := `#!/usr/bin/env python
+import cog
+from ` + module + ` import ` + class + `
+cog.` + serverClass + `(` + class + `()).start_server()`
+	scriptString := strings.ReplaceAll(script, "\n", "\\n")
+	return `
+RUN echo '` + scriptString + `' > ` + scriptPath + `
+RUN chmod +x ` + scriptPath
+}
+
+func (g *DockerfileGenerator) queueWorkerHelperScript() string {
+	scriptPath := "/code/cog-redis-queue-worker"
+	name := g.Config.Model
+	parts := strings.Split(name, ".py:")
+	module := parts[0]
+	class := parts[1]
+	script := `#!/usr/bin/env python
+import sys
+import cog
+from ` + module + ` import ` + class + `
+cog.RedisQueueWorker(` + class + `(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()`
+	scriptString := strings.ReplaceAll(script, "\n", "\\n")
+	return `
+RUN echo '` + scriptString + `' > ` + scriptPath + `
+RUN chmod +x ` + scriptPath
+}
+
 func (g *DockerfileGenerator) pythonRequirements() (string, error) {
 	reqs := g.Config.Environment.PythonRequirements
 	if reqs == "" {
@@ -179,11 +221,7 @@ func (g *DockerfileGenerator) copyCode() string {
 func (g *DockerfileGenerator) command() string {
 	// TODO: handle infer scripts in subdirectories
 	// TODO: check this actually exists
-	name := g.Config.Model
-	parts := strings.Split(name, ".py:")
-	module := parts[0]
-	class := parts[1]
-	return `CMD ["python", "-c", "from ` + module + ` import ` + class + `; ` + class + `().start_server()"]`
+	return `CMD /code/cog-http-server`
 }
 
 func (g *DockerfileGenerator) workdir() string {
diff --git a/pkg/docker/generate_test.go b/pkg/docker/generate_test.go
index 26b7f13fa1..ddb02c3ba9 100644
--- a/pkg/docker/generate_test.go
+++ b/pkg/docker/generate_test.go
@@ -13,7 +13,7 @@ import (
 func installCog() string {
 	cogLibB64 := base64.StdEncoding.EncodeToString(cogLibrary)
 	return fmt.Sprintf(`RUN ### --> Installing Cog
-RUN pip install flask
+RUN pip install flask requests redis
 ENV PYTHONPATH=/usr/local/lib/cog
 RUN mkdir -p /usr/local/lib/cog && echo %s | base64 --decode > /usr/local/lib/cog/cog.py`, cogLibB64)
 }
@@ -59,21 +59,37 @@ model: infer.py:Model
 
 	expectedCPU := `FROM ubuntu:20.04
 ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu
 ` + installPython("3.8") + installCog() + `
 RUN ### --> Copying code
 COPY . /code
+
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.HTTPServer(Model()).start_server()' > /code/cog-http-server
+RUN chmod +x /code/cog-http-server
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.AIPlatformPredictionServer(Model()).start_server()' > /code/cog-ai-platform-prediction-server
+RUN chmod +x /code/cog-ai-platform-prediction-server
+RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()' > /code/cog-redis-queue-worker
+RUN chmod +x /code/cog-redis-queue-worker
 WORKDIR /code
-CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
+CMD /code/cog-http-server`
 
 	expectedGPU := `FROM nvidia/cuda:11.0-cudnn8-devel-ubuntu16.04
 ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu
 ` + installPython("3.8") + installCog() + `
 RUN ### --> Copying code
 COPY . /code
+
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.HTTPServer(Model()).start_server()' > /code/cog-http-server
+RUN chmod +x /code/cog-http-server
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.AIPlatformPredictionServer(Model()).start_server()' > /code/cog-ai-platform-prediction-server
+RUN chmod +x /code/cog-ai-platform-prediction-server
+RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()' > /code/cog-redis-queue-worker
+RUN chmod +x /code/cog-redis-queue-worker
 WORKDIR /code
-CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
+CMD /code/cog-http-server`
 
 	gen := DockerfileGenerator{conf, "cpu"}
 	actualCPU, err := gen.Generate()
@@ -103,6 +119,7 @@ model: infer.py:Model
 
 	expectedCPU := `FROM ubuntu:20.04
 ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu
 ` + installPython("3.8") + `RUN ### --> Installing system packages
 RUN apt-get update -qq && apt-get install -qy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*
@@ -114,11 +131,19 @@ RUN pip install -f https://download.pytorch.org/whl/torch_stable.html torch==1
 ` + installCog() + `
 RUN ### --> Copying code
 COPY . /code
+
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.HTTPServer(Model()).start_server()' > /code/cog-http-server
+RUN chmod +x /code/cog-http-server
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.AIPlatformPredictionServer(Model()).start_server()' > /code/cog-ai-platform-prediction-server
+RUN chmod +x /code/cog-ai-platform-prediction-server
+RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()' > /code/cog-redis-queue-worker
+RUN chmod +x /code/cog-redis-queue-worker
 WORKDIR /code
-CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
+CMD /code/cog-http-server`
 
 	expectedGPU := `FROM nvidia/cuda:10.2-cudnn8-devel-ubuntu18.04
 ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
 ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu
 ` + installPython("3.8") + `RUN ### --> Installing system packages
 RUN apt-get update -qq && apt-get install -qy ffmpeg cowsay && rm -rf /var/lib/apt/lists/*
@@ -130,8 +155,15 @@ RUN pip install torch==1.5.1 pandas==1.2.0.12
 ` + installCog() + `
 RUN ### --> Copying code
 COPY . /code
+
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.HTTPServer(Model()).start_server()' > /code/cog-http-server
+RUN chmod +x /code/cog-http-server
+RUN echo '#!/usr/bin/env python\nimport cog\nfrom infer import Model\ncog.AIPlatformPredictionServer(Model()).start_server()' > /code/cog-ai-platform-prediction-server
+RUN chmod +x /code/cog-ai-platform-prediction-server
+RUN echo '#!/usr/bin/env python\nimport sys\nimport cog\nfrom infer import Model\ncog.RedisQueueWorker(Model(), redis_host=sys.argv[1], redis_port=sys.argv[2], input_queue=sys.argv[3], upload_url=sys.argv[4]).start()' > /code/cog-redis-queue-worker
+RUN chmod +x /code/cog-redis-queue-worker
 WORKDIR /code
-CMD ["python", "-c", "from infer import Model; Model().start_server()"]`
+CMD /code/cog-http-server`
 
 	gen := DockerfileGenerator{conf, "cpu"}
 	actualCPU, err := gen.Generate()
diff --git a/pkg/server/build.go b/pkg/server/build.go
index 80df1f4f4a..7ef3ce805b 100644
--- a/pkg/server/build.go
+++ b/pkg/server/build.go
@@ -179,12 +179,11 @@ func (s *Server) testModel(mod *model.Model, dir string, logWriter logger.Logger
 	}
 	logWriter.Infof(fmt.Sprintf("Inference result length: %d, mime type: %s", len(outputBytes), output.MimeType))
 	if expectedOutput != nil {
-		if !bytes.Equal(expectedOutput, outputBytes) {
-			if outputIsFile {
-				return nil, fmt.Errorf("Output file contents doesn't match expected %s", example.Output[1:])
-			} else {
-				return nil, fmt.Errorf("Output %s doesn't match expected: %s", string(outputBytes), example.Output)
-			}
+		if outputIsFile && !bytes.Equal(expectedOutput, outputBytes) {
+			return nil, fmt.Errorf("Output file contents doesn't match expected %s", example.Output[1:])
+		} else if !outputIsFile && strings.TrimSpace(string(outputBytes)) != strings.TrimSpace(example.Output) {
+			// TODO(andreas): are there cases where space is significant?
+			return nil, fmt.Errorf("Output %s doesn't match expected: %s", string(outputBytes), example.Output)
 		}
 	}
 }
diff --git a/pkg/server/web_hook.go b/pkg/server/web_hook.go
index 4c2ad0ba4f..5b868be6f9 100644
--- a/pkg/server/web_hook.go
+++ b/pkg/server/web_hook.go
@@ -39,6 +39,14 @@ func (wh *WebHook) run(user string, name string, mod *model.Model, dir string, l
 	}
 	modelJSONBase64 := base64.StdEncoding.EncodeToString(modelJSON)
 	modelPath := fmt.Sprintf("/v1/repos/%s/%s/models/%s", user, name, mod.ID)
+	dockerImageCPU := ""
+	if artifact, ok := mod.ArtifactFor(model.TargetDockerCPU); ok {
+		dockerImageCPU = artifact.URI
+	}
+	dockerImageGPU := ""
+	if artifact, ok := mod.ArtifactFor(model.TargetDockerGPU); ok {
+		dockerImageGPU = artifact.URI
+	}
 
 	logWriter.Infof("Posting model to %s", wh.url.Host)
 
@@ -46,6 +54,8 @@ func (wh *WebHook) run(user string, name string, mod *model.Model, dir string, l
 		"model_id":          {mod.ID},
 		"model_path":        {modelPath},
 		"model_json_base64": {modelJSONBase64},
+		"docker_image_cpu":  {dockerImageCPU},
+		"docker_image_gpu":  {dockerImageGPU},
 		"user":              {user},
 		"repo_name":         {name},
 		"secret":            {wh.secret},