diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6483980a91..eb1204157f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ on: jobs: test: - name: "Test" + name: "Test Go" strategy: fail-fast: false matrix: @@ -34,7 +34,30 @@ jobs: - name: "Build" run: make build - name: Test - run: make test + run: make test-go + + test-cog-library: + name: "Test Cog library" + runs-on: ubuntu-20.04 + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@master + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: 3.8 + - uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: pip-${{ hashFiles('**/pkg/docker/cog_test_requirements.txt') }} + restore-keys: | + pip-${{ secrets.CACHE_VERSION }}- + - name: Install requirements + run: pip install -r pkg/docker/cog_test_requirements.txt + - name: Test + run: make test-cog-library # cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150 test-end-to-end: diff --git a/Makefile b/Makefile index f733f56cc8..3f1eb6e6a6 100644 --- a/Makefile +++ b/Makefile @@ -23,11 +23,22 @@ clean: generate: go generate ./... -.PHONY: test -test: check-fmt vet lint +.PHONY: test-go +test-go: check-fmt vet lint go get gotest.tools/gotestsum go run gotest.tools/gotestsum -- -timeout 1200s -parallel 5 ./... $(ARGS) +.PHONY: test-end-to-end +test-end-to-end: install + cd end-to-end-test/ && $(MAKE) + +.PHONY: test-cog-library +test-cog-library: + cd pkg/docker/ && pytest cog_test.py + +.PHONY: test +test: test-go test-cog-library test-end-to-end + .PHONY: install install: go install $(LDFLAGS) $(MAIN) diff --git a/end-to-end-test/end_to_end_test/__init__.py b/end-to-end-test/end_to_end_test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/end-to-end-test/end_to_end_test/conftest.py b/end-to-end-test/end_to_end_test/conftest.py new file mode 100644 index 0000000000..a2cac9c080 --- /dev/null +++ b/end-to-end-test/end_to_end_test/conftest.py @@ -0,0 +1,101 @@ +import os +import subprocess +import tempfile +from waiting import wait +import requests +import pytest + +from .util import random_string, find_free_port, docker_run + + +@pytest.fixture +def cog_server_port(): + old_cwd = os.getcwd() + with tempfile.TemporaryDirectory() as cog_dir: + os.chdir(cog_dir) + port = str(find_free_port()) + server_proc = subprocess.Popen(["cog", "server", "--port", port]) + resp = wait( + lambda: requests.get("http://localhost:" + port + "/ping"), + timeout_seconds=60, + expected_exceptions=(requests.exceptions.ConnectionError,), + ) + assert resp.text == "pong" + + yield port + + os.chdir(old_cwd) + server_proc.kill() + + +@pytest.fixture +def project_dir(tmpdir_factory): + tmpdir = tmpdir_factory.mktemp("project") + with open(tmpdir / "infer.py", "w") as f: + f.write( + """ +import time +import tempfile +from pathlib import Path +import cog + +class Model(cog.Model): + def setup(self): + self.foo = "foo" + + @cog.input("text", type=str) + @cog.input("path", type=Path) + @cog.input("output_file", type=bool, default=False) + def run(self, text, path, output_file): + time.sleep(1) + with open(path) as f: + output = self.foo + text + f.read() + if output_file: + tmp = tempfile.NamedTemporaryFile(suffix=".txt") + tmp.close() + tmp_path = Path(tmp.name) + with tmp_path.open("w") as f: + f.write(output) + return tmp_path + return output + """ + ) + with open(tmpdir / "cog.yaml", "w") as f: + cog_yaml = """ +name: andreas/hello-world +model: infer.py:Model +examples: + - input: + text: "foo" + path: "@myfile.txt" + output: "foofoobaz" + - input: + text: "bar" + path: "@myfile.txt" + output: "foobarbaz" + - input: + text: "qux" + path: "@myfile.txt" +environment: + architectures: + - cpu + """ + f.write(cog_yaml) + + with open(tmpdir / "myfile.txt", "w") as f: + f.write("baz") + + return tmpdir + + +@pytest.fixture +def redis_port(): + container_name = "cog-test-redis-" + random_string(10) + port = find_free_port() + with docker_run( + "redis", + name=container_name, + publish=[{"host": port, "container": 6379}], + detach=True, + ): + yield port diff --git a/end-to-end-test/end_to_end_test/test_queue_worker.py b/end-to-end-test/end_to_end_test/test_queue_worker.py new file mode 100644 index 0000000000..6afa035427 --- /dev/null +++ b/end-to-end-test/end_to_end_test/test_queue_worker.py @@ -0,0 +1,150 @@ +import json +import redis +from contextlib import contextmanager +import multiprocessing +from flask import Flask, request, jsonify, Response + +from .util import ( + random_string, + set_model_url, + show_version, + push_with_log, + find_free_port, + docker_run, + get_local_ip, + wait_for_port, +) + + +def test_queue_worker(cog_server_port, project_dir, redis_port, tmpdir_factory): + user = random_string(10) + model_name = random_string(10) + model_url = f"http://localhost:{cog_server_port}/{user}/{model_name}" + + set_model_url(model_url, project_dir) + version_id = push_with_log(project_dir) + version_data = show_version(model_url, version_id) + + input_queue = multiprocessing.Queue() + output_queue = multiprocessing.Queue() + controller_port = find_free_port() + local_ip = get_local_ip() + upload_url = f"http://{local_ip}:{controller_port}/upload" + redis_host = local_ip + worker_name = "test-worker" + infer_queue_name = "infer-queue" + response_queue_name = "response-queue" + + wait_for_port(redis_host, redis_port) + + redis_client = redis.Redis(host=redis_host, port=redis_port, db=0) + + with queue_controller(input_queue, output_queue, controller_port), docker_run( + image=version_data["images"][0]["uri"], + interactive=True, + command=[ + "cog-redis-queue-worker", + redis_host, + str(redis_port), + infer_queue_name, + upload_url, + worker_name, + ], + ): + redis_client.xgroup_create( + mkstream=True, groupname=infer_queue_name, name=infer_queue_name, id="$" + ) + + infer_id = random_string(10) + redis_client.xadd( + name=infer_queue_name, + fields={ + "value": json.dumps( + { + "id": infer_id, + "inputs": { + "text": {"value": "bar"}, + "path": { + "file": { + "name": "myinput.txt", + "url": f"http://{local_ip}:{controller_port}/download", + } + }, + }, + "response_queue": response_queue_name, + } + ), + }, + ) + input_queue.put("test") + response = json.loads(redis_client.brpop(response_queue_name)[1])["value"] + assert response == "foobartest" + + infer_id = random_string(10) + redis_client.xadd( + name=infer_queue_name, + fields={ + "value": json.dumps( + { + "id": infer_id, + "inputs": { + "text": {"value": "bar"}, + "output_file": {"value": "true"}, + "path": { + "file": { + "name": "myinput.txt", + "url": f"http://{local_ip}:{controller_port}/download", + } + }, + }, + "response_queue": response_queue_name, + } + ), + }, + ) + input_queue.put("test") + response_contents = output_queue.get() + response = json.loads(redis_client.brpop(response_queue_name)[1])["file"] + assert response_contents.decode() == "foobartest" + assert response["url"] == "uploaded.txt" + + +@contextmanager +def queue_controller(input_queue, output_queue, controller_port): + controller = QueueController(input_queue, output_queue, controller_port) + controller.start() + yield controller + controller.kill() + + +class QueueController(multiprocessing.Process): + def __init__(self, input_queue, output_queue, port): + super().__init__() + self.input_queue = input_queue + self.output_queue = output_queue + self.port = port + + def run(self): + app = Flask("test-queue-controller") + + @app.route("/", methods=["GET"]) + def handle_index(): + return "OK" + + @app.route("/upload", methods=["PUT"]) + def handle_upload(): + f = request.files["file"] + contents = f.read() + self.output_queue.put(contents) + return jsonify({"url": "uploaded.txt"}) + + @app.route("/download", methods=["GET"]) + def handle_download(): + contents = self.input_queue.get() + return Response( + contents, + mimetype="text/plain", + headers={"Content-disposition": "attachment; filename=myinput.txt"}, + ) + + app.run(host="0.0.0.0", port=self.port, debug=False) diff --git a/end-to-end-test/end_to_end_test/test_server.py b/end-to-end-test/end_to_end_test/test_server.py index 73dbecea9c..ab2b817cbf 100644 --- a/end-to-end-test/end_to_end_test/test_server.py +++ b/end-to-end-test/end_to_end_test/test_server.py @@ -1,110 +1,21 @@ -import time -import json -import random -import string from glob import glob import os -import tempfile -import socket -from contextlib import closing import subprocess -import pytest import requests -from waiting import wait +from .util import random_string, set_model_url, show_version, push_with_log -@pytest.fixture -def cog_server_port_dir(): - old_cwd = os.getcwd() - with tempfile.TemporaryDirectory() as cog_dir: - os.chdir(cog_dir) - port = str(find_free_port()) - server_proc = subprocess.Popen(["cog", "server", "--port", port]) - resp = wait( - lambda: requests.get("http://localhost:" + port + "/ping"), - timeout_seconds=60, - expected_exceptions=(requests.exceptions.ConnectionError,), - ) - assert resp.text == "pong" - - yield port, cog_dir - - os.chdir(old_cwd) - server_proc.kill() - - -@pytest.fixture -def project_dir(tmpdir_factory): - tmpdir = tmpdir_factory.mktemp("project") - with open(tmpdir / "infer.py", "w") as f: - f.write( - """ -import time -from pathlib import Path -import cog - -class Model(cog.Model): - def setup(self): - self.foo = "foo" - - @cog.input("text", type=str) - @cog.input("path", type=Path) - def run(self, text, path): - time.sleep(1) - with open(path) as f: - return self.foo + text + f.read() - """ - ) - with open(tmpdir / "cog.yaml", "w") as f: - cog_yaml = """ -name: andreas/hello-world -model: infer.py:Model -examples: - - input: - text: "foo" - path: "@myfile.txt" - output: "foofoobaz" - - input: - text: "bar" - path: "@myfile.txt" - output: "foobarbaz" - - input: - text: "qux" - path: "@myfile.txt" -environment: - architectures: - - cpu - """ - f.write(cog_yaml) - - return tmpdir - - -def test_build_show_list_download_infer( - cog_server_port_dir, project_dir, tmpdir_factory -): - cog_port, cog_dir = cog_server_port_dir +def test_build_show_list_download_infer(cog_server_port, project_dir, tmpdir_factory): user = random_string(10) model_name = random_string(10) - model_url = f"http://localhost:{cog_port}/{user}/{model_name}" + model_url = f"http://localhost:{cog_server_port}/{user}/{model_name}" with open(os.path.join(project_dir, "cog.yaml")) as f: cog_yaml = f.read() - out, _ = subprocess.Popen( - ["cog", "model", "set", model_url], - stdout=subprocess.PIPE, - cwd=project_dir, - ).communicate() - assert ( - out.decode() - == f"Updated model: http://localhost:{cog_port}/{user}/{model_name}\n" - ) - - with open(project_dir / "myfile.txt", "w") as f: - f.write("baz") + set_model_url(model_url, project_dir) out, _ = subprocess.Popen( ["cog", "push"], @@ -173,7 +84,7 @@ def test_build_show_list_download_infer( with input_path.open("w") as f: f.write("input") - files_endpoint = f"http://localhost:{cog_port}/v1/models/{user}/{model_name}/versions/{version_id}/files" + files_endpoint = f"http://localhost:{cog_server_port}/v1/models/{user}/{model_name}/versions/{version_id}/files" assert requests.get(f"{files_endpoint}/cog.yaml").text == cog_yaml assert ( requests.get(f"{files_endpoint}/cog-example-output/output.02.txt").text @@ -201,57 +112,15 @@ def test_build_show_list_download_infer( assert f.read() == "foobazinput" -def test_push_log(cog_server_port_dir, project_dir): - cog_port, cog_dir = cog_server_port_dir - +def test_push_log(cog_server_port, project_dir): user = random_string(10) model_name = random_string(10) - model_url = f"http://localhost:{cog_port}/{user}/{model_name}" + model_url = f"http://localhost:{cog_server_port}/{user}/{model_name}" - out, _ = subprocess.Popen( - ["cog", "model", "set", model_url], - stdout=subprocess.PIPE, - cwd=project_dir, - ).communicate() - assert ( - out.decode() - == f"Updated model: http://localhost:{cog_port}/{user}/{model_name}\n" - ) - - with open(project_dir / "myfile.txt", "w") as f: - f.write("baz") - - out, _ = subprocess.Popen( - ["cog", "push", "--log"], - cwd=project_dir, - stdout=subprocess.PIPE, - ).communicate() - - assert out.decode().startswith("Successfully uploaded version "), ( - out.decode() + " doesn't start with 'Successfully uploaded version'" - ) - version_id = out.decode().strip().split("Successfully uploaded version ")[1] + set_model_url(model_url, project_dir) + version_id = push_with_log(project_dir) out = show_version(model_url, version_id) assert out["config"]["examples"][2]["output"] == "@cog-example-output/output.02.txt" assert out["images"][0]["arch"] == "cpu" assert out["images"][0]["run_arguments"]["text"]["type"] == "str" - - -def find_free_port(): - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: - s.bind(("", 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return s.getsockname()[1] - - -def random_string(length): - return "".join(random.choice(string.ascii_lowercase) for i in range(length)) - - -def show_version(model_url, version_id): - out, _ = subprocess.Popen( - ["cog", "--model", model_url, "show", "--json", version_id], - stdout=subprocess.PIPE, - ).communicate() - return json.loads(out) diff --git a/end-to-end-test/end_to_end_test/util.py b/end-to-end-test/end_to_end_test/util.py new file mode 100644 index 0000000000..a1ad86cf83 --- /dev/null +++ b/end-to-end-test/end_to_end_test/util.py @@ -0,0 +1,106 @@ +import time +from contextlib import contextmanager +import json +import subprocess +import random +import string +import socket +from typing import List, Optional +from contextlib import closing + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +def wait_for_port(host, port, timeout=60): + start = time.time() + while True: + try: + if time.time() - start > timeout: + raise Exception(f"Something is wrong. Timeout waiting for port {port}") + with socket.create_connection((host, port), timeout=5): + return + except socket.error: + pass + except socket.timeout: + raise + + +def random_string(length): + return "".join(random.choice(string.ascii_lowercase) for i in range(length)) + + +def show_version(model_url, version_id): + out, _ = subprocess.Popen( + ["cog", "--model", model_url, "show", "--json", version_id], + stdout=subprocess.PIPE, + ).communicate() + return json.loads(out) + + +def set_model_url(model_url, project_dir): + out, _ = subprocess.Popen( + ["cog", "model", "set", model_url], + stdout=subprocess.PIPE, + cwd=project_dir, + ).communicate() + assert out.decode() == f"Updated model: {model_url}\n" + + +def push_with_log(project_dir): + out, _ = subprocess.Popen( + ["cog", "push", "--log"], + cwd=project_dir, + stdout=subprocess.PIPE, + ).communicate() + + assert out.decode().startswith("Successfully uploaded version "), ( + out.decode() + " doesn't start with 'Successfully uploaded version'" + ) + version_id = out.decode().strip().split("Successfully uploaded version ")[1] + + return version_id + + +@contextmanager +def docker_run( + image, + name=None, + detach=False, + interactive=False, + publish: Optional[List[dict]] = None, + command: Optional[List[str]] = None, + env: Optional[dict] = None, +): + if name is None: + name = random_string(10) + + cmd = ["docker", "run", "--name", name] + if publish is not None: + for port_binding in publish: + host_port = port_binding["host"] + container_port = port_binding["container"] + cmd += ["--publish", f"{host_port}:{container_port}"] + if env is not None: + for key, value in env.items(): + cmd += ["-e", f"{key}={value}"] + if detach: + cmd += ["--detach"] + if interactive: + cmd += ["-i"] + cmd += [image] + if command: + cmd += command + try: + subprocess.Popen(cmd) + yield + finally: + subprocess.Popen(["docker", "rm", "--force", name]).wait() + + +def get_local_ip(): + return socket.gethostbyname(socket.gethostname()) diff --git a/end-to-end-test/requirements.txt b/end-to-end-test/requirements.txt index 51de3bde09..e96d0082fc 100644 --- a/end-to-end-test/requirements.txt +++ b/end-to-end-test/requirements.txt @@ -1,3 +1,5 @@ pytest==6.1.1 waiting==1.4.1 requests==2.25.1 +redis==3.5.3 +flask==2.0.0 diff --git a/pkg/cli/infer.go b/pkg/cli/infer.go index c76f0257ee..532686bdbf 100644 --- a/pkg/cli/infer.go +++ b/pkg/cli/infer.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io" - "mime" "os" "strings" @@ -13,12 +12,13 @@ import ( "github.com/mitchellh/go-homedir" "github.com/spf13/cobra" - "github.com/replicate/cog/pkg/util/console" - "github.com/replicate/cog/pkg/client" "github.com/replicate/cog/pkg/logger" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/serving" + "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/util/mime" + "github.com/replicate/cog/pkg/util/slices" ) var ( @@ -43,6 +43,10 @@ func newInferCommand() *cobra.Command { } func cmdInfer(cmd *cobra.Command, args []string) error { + if !slices.ContainsString([]string{"cpu", "gpu"}, inferArch) { + return fmt.Errorf("--arch must be either 'cpu' or 'gpu'") + } + mod, err := getModel() if err != nil { return err @@ -67,8 +71,7 @@ func cmdInfer(cmd *cobra.Command, args []string) error { return err } logWriter := logger.NewConsoleLogger() - // TODO(andreas): GPU inference - useGPU := false + useGPU := inferArch == "gpu" deployment, err := servingPlatform.Deploy(context.Background(), image.URI, useGPU, logWriter) if err != nil { return err @@ -79,22 +82,11 @@ func cmdInfer(cmd *cobra.Command, args []string) error { } }() - keyVals := map[string]string{} - for _, input := range inputs { - var name, value string + return inferIndividualInputs(deployment, inputs, outPath, logWriter) +} - // Default input name is "input" - if !strings.Contains(input, "=") { - name = "input" - value = input - } else { - split := strings.SplitN(input, "=", 2) - name = split[0] - value = split[1] - } - keyVals[name] = value - } - example := serving.NewExample(keyVals) +func inferIndividualInputs(deployment serving.Deployment, inputs []string, outputPath string, logWriter logger.Logger) error { + example := parseInferInputs(inputs) result, err := deployment.RunInference(context.Background(), example, logWriter) if err != nil { return err @@ -103,7 +95,7 @@ func cmdInfer(cmd *cobra.Command, args []string) error { output := result.Values["output"] // Write to stdout - if outPath == "" { + if outputPath == "" { // Is it something we can sensibly write to stdout? if output.MimeType == "plain/text" { _, err := io.Copy(os.Stdout, output.Buffer) @@ -121,23 +113,23 @@ func cmdInfer(cmd *cobra.Command, args []string) error { return nil } // Otherwise, fall back to writing file - outPath = "output" - extension, _ := mime.ExtensionsByType(output.MimeType) - if len(extension) > 0 { - outPath += extension[0] + outputPath = "output" + extension := mime.ExtensionByType(output.MimeType) + if extension != "" { + outputPath += extension } } // Ignore @, to make it behave the same as -i - outPath = strings.TrimPrefix(outPath, "@") + outputPath = strings.TrimPrefix(outputPath, "@") - outPath, err := homedir.Expand(outPath) + outputPath, err = homedir.Expand(outputPath) if err != nil { return err } // Write to file - outFile, err := os.OpenFile(outPath, os.O_WRONLY|os.O_CREATE, 0755) + outFile, err := os.OpenFile(outputPath, os.O_WRONLY|os.O_CREATE, 0755) if err != nil { return err } @@ -146,6 +138,28 @@ func cmdInfer(cmd *cobra.Command, args []string) error { return err } - fmt.Println("Written output to " + outPath) + fmt.Println("Written output to " + outputPath) return nil } + +func parseInferInputs(inputs []string) *serving.Example { + keyVals := map[string]string{} + for _, input := range inputs { + var name, value string + + // Default input name is "input" + if !strings.Contains(input, "=") { + name = "input" + value = input + } else { + split := strings.SplitN(input, "=", 2) + name = split[0] + value = split[1] + } + if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) { + value = value[1 : len(value)-1] + } + keyVals[name] = value + } + return serving.NewExample(keyVals) +} diff --git a/pkg/docker/cog.py b/pkg/docker/cog.py index 22bbf2dafe..5af1bb5b4b 100644 --- a/pkg/docker/cog.py +++ b/pkg/docker/cog.py @@ -17,7 +17,7 @@ from typing import Optional, Any, Type, List, Callable, Dict from numbers import Number -from flask import Flask, send_file, request, jsonify, abort, Response +from flask import Flask, send_file, request, jsonify, Response from werkzeug.datastructures import FileStorage import redis @@ -79,7 +79,7 @@ def handle_request(): else: inputs = raw_inputs - result = self.model.run(**inputs) + result = run_model(self.model, inputs, cleanup_functions) run_time = time.time() - start_time return self.create_response(result, setup_time, run_time) finally: @@ -158,7 +158,7 @@ def handle_request(): ) except InputValidationError as e: return jsonify({"error": str(e)}) - results.append(self.model.run(**instance)) + results.append(run_model(self.model, instance, cleanup_functions)) return jsonify( { "predictions": results, @@ -171,6 +171,12 @@ def handle_request(): "error": tb, } ) + finally: + for cleanup_function in cleanup_functions: + try: + cleanup_function() + except Exception as e: + sys.stderr.write(f"Cleanup function caught error: {e}") @app.route("/ping") def ping(): @@ -214,7 +220,6 @@ def create_response(self, result, setup_time, run_time): return resp -# TODO: reliable queue class RedisQueueWorker: def __init__( self, @@ -251,7 +256,6 @@ def signal_exit(self, signum, frame): self.should_exit = True sys.stderr.write("Caught SIGTERM, exiting...\n") - # TODO(andreas): test this def receive_message(self): # first, try to autoclaim old messages from pending queue _, raw_messages = self.redis.execute_command( @@ -362,7 +366,7 @@ def handle_message(self, response_queue, message, cleanup_functions): self.push_error(response_queue, e) return - result = self.model.run(**inputs) + result = run_model(self.model, inputs, cleanup_functions) self.push_result(response_queue, result) def download(self, url): @@ -504,22 +508,15 @@ def validate_and_convert_inputs( return inputs -@contextmanager -def unzip_to_tempdir(zip_path): - with tempfile.TemporaryDirectory() as tempdir: - shutil.unpack_archive(zip_path, tempdir, "zip") - yield tempdir - - -def make_temp_path(filename): - temp_dir = make_temp_dir() - return Path(os.path.join(temp_dir, filename)) - - -def make_temp_dir(): - # TODO(andreas): cleanup - temp_dir = tempfile.mkdtemp() - return temp_dir +def run_model(model, inputs, cleanup_functions): + """ + Run the model on the inputs, and append resulting paths + to cleanup functions for removal. + """ + result = model.run(**inputs) + if isinstance(result, Path): + cleanup_functions.append(result.unlink) + return result @dataclass diff --git a/pkg/docker/cog_test.py b/pkg/docker/cog_test.py index 39a0ae23a3..90518c3c1f 100644 --- a/pkg/docker/cog_test.py +++ b/pkg/docker/cog_test.py @@ -358,7 +358,6 @@ def setup(self): @cog.input("text", type=str) def run(self, text): - # TODO(andreas): how to clean up files? temp_dir = tempfile.mkdtemp() temp_path = os.path.join(temp_dir, "my_file.txt") with open(temp_path, "w") as f: @@ -387,7 +386,8 @@ def run(self): client = make_client(Model()) resp = client.post("/infer") assert resp.status_code == 200 - assert resp.content_type == "image/bmp" + # need both image/bmp and image/x-ms-bmp until https://bugs.python.org/issue44211 is fixed + assert resp.content_type in ["image/bmp", "image/x-ms-bmp"] assert resp.content_length == 195894 @@ -483,22 +483,3 @@ def run(self): assert resp.status_code == 200 assert float(resp.headers["X-Setup-Time"]) < 0.5 assert float(resp.headers["X-Run-Time"]) < 0.5 - - -def test_unzip_to_tempdir(tmpdir_factory): - input_dir = tmpdir_factory.mktemp("input") - with open(os.path.join(input_dir, "hello.txt"), "w") as f: - f.write("hello") - os.mkdir(os.path.join(input_dir, "mydir")) - with open(os.path.join(input_dir, "mydir", "world.txt"), "w") as f: - f.write("world") - - zip_dir = tmpdir_factory.mktemp("zip") - zip_path = os.path.join(zip_dir, "my-archive.zip") - - shutil.make_archive(zip_path.split(".zip")[0], "zip", input_dir) - with cog.unzip_to_tempdir(zip_path) as tempdir: - with open(os.path.join(tempdir, "hello.txt")) as f: - assert f.read() == "hello" - with open(os.path.join(tempdir, "mydir", "world.txt")) as f: - assert f.read() == "world" diff --git a/pkg/docker/cog_test_requirements.txt b/pkg/docker/cog_test_requirements.txt new file mode 100644 index 0000000000..d22198bff3 --- /dev/null +++ b/pkg/docker/cog_test_requirements.txt @@ -0,0 +1,5 @@ +flask==2.0.1 +pillow==8.2.0 +pytest==6.2.4 +redis==3.5.3 +requests==2.25.1 diff --git a/pkg/serving/test.go b/pkg/serving/test.go index 9c49dc6412..55955a1db7 100644 --- a/pkg/serving/test.go +++ b/pkg/serving/test.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "io" - "mime" "os" "path/filepath" "strings" @@ -16,6 +15,7 @@ import ( "github.com/replicate/cog/pkg/logger" "github.com/replicate/cog/pkg/model" "github.com/replicate/cog/pkg/util/console" + "github.com/replicate/cog/pkg/util/mime" ) // TODO(andreas): put this somewhere else since it's used by server? @@ -144,155 +144,6 @@ func validateServingExampleInput(args map[string]*model.RunArgument, input map[s return nil } -func extensionByType(mimeType string) string { - switch mimeType { - case "audio/aac": - return ".aac" - case "application/x-abiword": - return ".abw" - case "application/x-freearc": - return ".arc" - case "video/x-msvideo": - return ".avi" - case "application/vnd.amazon.ebook": - return ".azw" - case "application/octet-stream": - return ".bin" - case "image/bmp": - return ".bmp" - case "application/x-bzip": - return ".bz" - case "application/x-bzip2": - return ".bz2" - case "application/x-csh": - return ".csh" - case "text/css": - return ".css" - case "text/csv": - return ".csv" - case "application/msword": - return ".doc" - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return ".docx" - case "application/vnd.ms-fontobject": - return ".eot" - case "application/epub+zip": - return ".epub" - case "application/gzip": - return ".gz" - case "image/gif": - return ".gif" - case "text/html": - return ".html" - case "image/vnd.microsoft.icon": - return ".ico" - case "text/calendar": - return ".ics" - case "application/java-archive": - return ".jar" - case "image/jpeg": - return ".jpg" - case "text/javascript": - return ".js" - case "application/json": - return ".json" - case "application/ld+json": - return ".jsonld" - case "audio/midi audio/x-midi": - return ".midi" - case "audio/mpeg": - return ".mp3" - case "application/x-cdf": - return ".cda" - case "video/mp4": - return ".mp4" - case "video/mpeg": - return ".mpeg" - case "application/vnd.apple.installer+xml": - return ".mpkg" - case "application/vnd.oasis.opendocument.presentation": - return ".odp" - case "application/vnd.oasis.opendocument.spreadsheet": - return ".ods" - case "application/vnd.oasis.opendocument.text": - return ".odt" - case "audio/ogg": - return ".oga" - case "video/ogg": - return ".ogv" - case "application/ogg": - return ".ogx" - case "audio/opus": - return ".opus" - case "font/otf": - return ".otf" - case "image/png": - return ".png" - case "application/pdf": - return ".pdf" - case "application/x-httpd-php": - return ".php" - case "application/vnd.ms-powerpoint": - return ".ppt" - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return ".pptx" - case "application/vnd.rar": - return ".rar" - case "application/rtf": - return ".rtf" - case "application/x-sh": - return ".sh" - case "image/svg+xml": - return ".svg" - case "application/x-shockwave-flash": - return ".swf" - case "application/x-tar": - return ".tar" - case "image/tiff": - return ".tiff" - case "video/mp2t": - return ".ts" - case "font/ttf": - return ".ttf" - case "text/plain": - return ".txt" - case "application/vnd.visio": - return ".vsd" - case "audio/wav": - return ".wav" - case "audio/webm": - return ".weba" - case "video/webm": - return ".webm" - case "image/webp": - return ".webp" - case "font/woff": - return ".woff" - case "font/woff2": - return ".woff2" - case "application/xhtml+xml": - return ".xhtml" - case "application/vnd.ms-excel": - return ".xls" - case "application/xml": - return ".xml" - case "application/zip": - return ".zip" - case "video/3gpp": - return ".3gp" - case "video/3gpp2": - return ".3gp2" - case "application/x-7z-compressed": - return ".7z" - default: - extensions, _ := mime.ExtensionsByType(mimeType) - if len(extensions) == 0 { - return "" - } - return extensions[0] - } -} - func copyExamples(examples []*model.Example) []*model.Example { copy := []*model.Example{} for _, ex := range examples { @@ -353,7 +204,7 @@ func setAggregateStats(modelStats *model.Stats, setupTimes []float64, runTimes [ func updateExampleOutput(example *model.Example, newExampleOutputs map[string][]byte, outputBytes []byte, mimeType string, index int) { filename := fmt.Sprintf("output.%02d", index) - if ext := extensionByType(mimeType); ext != "" { + if ext := mime.ExtensionByType(mimeType); ext != "" { filename += ext } outputPath := filepath.Join(ExampleOutputDir, filename) diff --git a/pkg/serving/test_test.go b/pkg/serving/test_test.go index 86c7907c74..1eca3f148f 100644 --- a/pkg/serving/test_test.go +++ b/pkg/serving/test_test.go @@ -150,14 +150,6 @@ func TestValidateServingExampleInput(t *testing.T) { })) } -func TestExtensionByType(t *testing.T) { - require.Equal(t, ".txt", extensionByType("text/plain")) - require.Equal(t, ".jpg", extensionByType("image/jpeg")) - require.Equal(t, ".png", extensionByType("image/png")) - require.Equal(t, ".json", extensionByType("application/json")) - require.Equal(t, "", extensionByType("asdfasdf")) -} - func TestOutputBytesFromExample(t *testing.T) { tmpDir, err := os.MkdirTemp("", "test") require.NoError(t, err) diff --git a/pkg/util/mime/mime.go b/pkg/util/mime/mime.go new file mode 100644 index 0000000000..eb16e0db24 --- /dev/null +++ b/pkg/util/mime/mime.go @@ -0,0 +1,154 @@ +package mime + +import ( + "mime" +) + +func ExtensionByType(mimeType string) string { + switch mimeType { + case "audio/aac": + return ".aac" + case "application/x-abiword": + return ".abw" + case "application/x-freearc": + return ".arc" + case "video/x-msvideo": + return ".avi" + case "application/vnd.amazon.ebook": + return ".azw" + case "application/octet-stream": + return ".bin" + case "image/bmp": + return ".bmp" + case "application/x-bzip": + return ".bz" + case "application/x-bzip2": + return ".bz2" + case "application/x-csh": + return ".csh" + case "text/css": + return ".css" + case "text/csv": + return ".csv" + case "application/msword": + return ".doc" + case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": + return ".docx" + case "application/vnd.ms-fontobject": + return ".eot" + case "application/epub+zip": + return ".epub" + case "application/gzip": + return ".gz" + case "image/gif": + return ".gif" + case "text/html": + return ".html" + case "image/vnd.microsoft.icon": + return ".ico" + case "text/calendar": + return ".ics" + case "application/java-archive": + return ".jar" + case "image/jpeg": + return ".jpg" + case "text/javascript": + return ".js" + case "application/json": + return ".json" + case "application/ld+json": + return ".jsonld" + case "audio/midi audio/x-midi": + return ".midi" + case "audio/mpeg": + return ".mp3" + case "application/x-cdf": + return ".cda" + case "video/mp4": + return ".mp4" + case "video/mpeg": + return ".mpeg" + case "application/vnd.apple.installer+xml": + return ".mpkg" + case "application/vnd.oasis.opendocument.presentation": + return ".odp" + case "application/vnd.oasis.opendocument.spreadsheet": + return ".ods" + case "application/vnd.oasis.opendocument.text": + return ".odt" + case "audio/ogg": + return ".oga" + case "video/ogg": + return ".ogv" + case "application/ogg": + return ".ogx" + case "audio/opus": + return ".opus" + case "font/otf": + return ".otf" + case "image/png": + return ".png" + case "application/pdf": + return ".pdf" + case "application/x-httpd-php": + return ".php" + case "application/vnd.ms-powerpoint": + return ".ppt" + case "application/vnd.openxmlformats-officedocument.presentationml.presentation": + return ".pptx" + case "application/vnd.rar": + return ".rar" + case "application/rtf": + return ".rtf" + case "application/x-sh": + return ".sh" + case "image/svg+xml": + return ".svg" + case "application/x-shockwave-flash": + return ".swf" + case "application/x-tar": + return ".tar" + case "image/tiff": + return ".tiff" + case "video/mp2t": + return ".ts" + case "font/ttf": + return ".ttf" + case "text/plain": + return ".txt" + case "application/vnd.visio": + return ".vsd" + case "audio/wav": + return ".wav" + case "audio/webm": + return ".weba" + case "video/webm": + return ".webm" + case "image/webp": + return ".webp" + case "font/woff": + return ".woff" + case "font/woff2": + return ".woff2" + case "application/xhtml+xml": + return ".xhtml" + case "application/vnd.ms-excel": + return ".xls" + case "application/xml": + return ".xml" + case "application/zip": + return ".zip" + case "video/3gpp": + return ".3gp" + case "video/3gpp2": + return ".3gp2" + case "application/x-7z-compressed": + return ".7z" + default: + extensions, _ := mime.ExtensionsByType(mimeType) + if len(extensions) == 0 { + return "" + } + return extensions[0] + } +} diff --git a/pkg/util/mime/mime_test.go b/pkg/util/mime/mime_test.go new file mode 100644 index 0000000000..106328c750 --- /dev/null +++ b/pkg/util/mime/mime_test.go @@ -0,0 +1,15 @@ +package mime + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestExtensionByType(t *testing.T) { + require.Equal(t, ".txt", ExtensionByType("text/plain")) + require.Equal(t, ".jpg", ExtensionByType("image/jpeg")) + require.Equal(t, ".png", ExtensionByType("image/png")) + require.Equal(t, ".json", ExtensionByType("application/json")) + require.Equal(t, "", ExtensionByType("asdfasdf")) +}