
Commit f843830

Support more types of object in JSON output
Signed-off-by: Ben Firshman <[email protected]>
1 parent f451eea commit f843830

File tree

7 files changed: +77, -8 lines

pkg/cli/predict.go (1 addition, 1 deletion)

@@ -138,7 +138,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
         console.Output(string(output))
         return nil
     } else if output.MimeType == "application/json" {
-        var obj map[string]interface{}
+        var obj interface{}
         dec := json.NewDecoder(output.Buffer)
         if err := dec.Decode(&obj); err != nil {
             return err

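Decoding into a plain interface{} means the CLI no longer assumes the prediction output is a JSON object: arrays, numbers, and strings decode just as well. As a rough illustration (a hypothetical model, not part of this commit), a predictor whose output serializes to a JSON array now flows through this code path:

import cog

# Hypothetical example model: its output serializes to the JSON array
# "[1, 2, 3]", which the CLI can now decode into interface{}.
class ListModel(cog.Model):
    def setup(self):
        pass

    def predict(self):
        return [1, 2, 3]
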
python/cog/json.py (46 additions, 0 deletions, new file)

@@ -0,0 +1,46 @@
+import json
+
+# Based on keepsake.json
+
+# We load numpy but not torch or tensorflow because numpy loads very fast and
+# they're probably using it anyway
+# fmt: off
+try:
+    import numpy as np  # type: ignore
+    has_numpy = True
+except ImportError:
+    has_numpy = False
+# fmt: on
+
+# Tensorflow takes a solid 10 seconds to import on a modern Macbook Pro, so instead of importing,
+# do this instead
+def _is_tensorflow_tensor(obj):
+    # e.g. __module__='tensorflow.python.framework.ops', __name__='EagerTensor'
+    return (
+        obj.__class__.__module__.split(".")[0] == "tensorflow"
+        and "Tensor" in obj.__class__.__name__
+    )
+
+
+def _is_torch_tensor(obj):
+    return (obj.__class__.__module__, obj.__class__.__name__) == ("torch", "Tensor")
+
+
+class CustomJSONEncoder(json.JSONEncoder):
+    def default(self, o):
+        if has_numpy:
+            if isinstance(o, np.integer):
+                return int(o)
+            elif isinstance(o, np.floating):
+                return float(o)
+            elif isinstance(o, np.ndarray):
+                return o.tolist()
+        if _is_torch_tensor(o):
+            return o.detach().tolist()
+        if _is_tensorflow_tensor(o):
+            return o.numpy().tolist()
+        return json.JSONEncoder.default(self, o)
+
+
+def to_json(obj):
+    return json.dumps(obj, cls=CustomJSONEncoder)

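A rough usage sketch of the new encoder (the values and the "from cog.json import to_json" import path are illustrative assumptions, not taken from the diff): numpy scalars and arrays are coerced to plain Python numbers and lists before serialization, and anything the encoder does not recognize still falls through to json.JSONEncoder.default, which raises TypeError.

import numpy as np

from cog.json import to_json  # assumed import path for the module added above

# numpy values are converted to plain ints/floats/lists before serialization
to_json({"count": np.int64(2), "scores": np.array([0.25, 0.75])})
# expected: '{"count": 2, "scores": [0.25, 0.75]}'

# unrecognized objects fall through to the base encoder and raise TypeError
to_json({"unserializable": object()})  # raises TypeError
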
python/cog/server/ai_platform.py (8 additions, 4 deletions)

@@ -9,6 +9,7 @@
     get_type_name,
     UNSPECIFIED,
 )
+from ..json import to_json
 from ..model import Model, run_model, load_model
 
 
@@ -38,10 +39,13 @@ def handle_request():
             except InputValidationError as e:
                 return jsonify({"error": str(e)})
             results.append(run_model(self.model, instance, cleanup_functions))
-        return jsonify(
-            {
-                "predictions": results,
-            }
+        return Response(
+            to_json(
+                {
+                    "predictions": results,
+                }
+            ),
+            mimetype="application/json",
         )
     except Exception as e:
         tb = traceback.format_exc()

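The reason for swapping jsonify for a plain Response here is that Flask's default JSON encoder, like the stdlib one, rejects numpy, torch, and tensorflow values with a TypeError; building the body with to_json keeps the application/json mimetype while routing serialization through CustomJSONEncoder. A minimal standalone sketch of the same pattern (the route, payload, and import path are made up for illustration):

import numpy as np
from flask import Flask, Response

from cog.json import to_json  # assumed import path

app = Flask(__name__)

@app.route("/example")
def example():
    # jsonify(payload) would raise TypeError on the np.float32 value;
    # to_json() converts it to a plain float before dumping.
    payload = {"predictions": [{"score": np.float32(0.5)}]}
    return Response(to_json(payload), mimetype="application/json")
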
python/cog/server/http.py (2 additions, 1 deletion)

@@ -11,6 +11,7 @@
     get_type_name,
     UNSPECIFIED,
 )
+from ..json import to_json
 from ..model import Model, run_model, load_model
 
 
@@ -99,7 +100,7 @@ def create_response(self, result, setup_time, run_time):
         elif isinstance(result, str):
            resp = Response(result, mimetype="text/plain")
         else:
-            resp = jsonify(result)
+            resp = Response(to_json(result), mimetype="application/json")
         resp.headers["X-Setup-Time"] = setup_time
         resp.headers["X-Run-Time"] = run_time
         return resp

python/cog/server/redis_queue.py (2 additions, 1 deletion)

@@ -14,6 +14,7 @@
 
 
 from ..input import InputValidationError, validate_and_convert_inputs
+from ..json import to_json
 from ..model import Model, run_model, load_model
 
 
@@ -210,7 +211,7 @@ def push_result(self, response_queue, result):
             }
         else:
            message = {
-                "value": json.dumps(result),
+                "value": to_json(result),
             }
 
         sys.stderr.write(f"Pushing successful result to {response_queue}\n")

python/cog_test.py (17 additions, 1 deletion)

@@ -9,6 +9,7 @@
 
 import pytest
 from flask.testing import FlaskClient
+import numpy as np
 from PIL import Image
 
 import cog
@@ -220,7 +221,7 @@ def predict(self, num1, num2, num3):
     client = make_client(Model())
     resp = client.post("/predict", data={"num1": 3, "num2": -4, "num3": -4})
     assert resp.status_code == 200
-    assert resp.data == b"-5.0\n"
+    assert resp.data == b"-5.0"
     resp = client.post("/predict", data={"num1": 2, "num2": -4, "num3": -4})
     assert resp.status_code == 400
     resp = client.post("/predict", data={"num1": 3, "num2": -4.1, "num3": -4})
@@ -392,6 +393,21 @@ def predict(self):
     assert resp.content_length == 195894
 
 
+def test_json_output_numpy():
+    class Model(cog.Model):
+        def setup(self):
+            pass
+
+        def predict(self):
+            return {"foo": np.float32(1.0)}
+
+    client = make_client(Model())
+    resp = client.post("/predict")
+    assert resp.status_code == 200
+    assert resp.content_type == "application/json"
+    assert resp.data == b'{"foo": 1.0}'
+
+
 def test_multiple_arguments():
     class Model(cog.Model):
         def setup(self):

requirements-dev.txt (1 addition, 0 deletions)

@@ -1,4 +1,5 @@
 flask==2.0.1
+numpy==1.21.1
 pillow==8.2.0
 pytest==6.2.4
 PyYAML==5.4.1
