|
9 | 9 |
|
10 | 10 | import pytest |
11 | 11 | from flask.testing import FlaskClient |
| 12 | +import numpy as np |
12 | 13 | from PIL import Image |
13 | 14 |
|
14 | 15 | import cog |
@@ -220,7 +221,7 @@ def predict(self, num1, num2, num3): |
220 | 221 | client = make_client(Model()) |
221 | 222 | resp = client.post("/predict", data={"num1": 3, "num2": -4, "num3": -4}) |
222 | 223 | assert resp.status_code == 200 |
223 | | - assert resp.data == b"-5.0\n" |
| 224 | + assert resp.data == b"-5.0" |
224 | 225 | resp = client.post("/predict", data={"num1": 2, "num2": -4, "num3": -4}) |
225 | 226 | assert resp.status_code == 400 |
226 | 227 | resp = client.post("/predict", data={"num1": 3, "num2": -4.1, "num3": -4}) |
@@ -392,6 +393,21 @@ def predict(self): |
392 | 393 | assert resp.content_length == 195894 |
393 | 394 |
|
394 | 395 |
|
| 396 | +def test_json_output_numpy(): |
| 397 | + class Model(cog.Model): |
| 398 | + def setup(self): |
| 399 | + pass |
| 400 | + |
| 401 | + def predict(self): |
| 402 | + return {"foo": np.float32(1.0)} |
| 403 | + |
| 404 | + client = make_client(Model()) |
| 405 | + resp = client.post("/predict") |
| 406 | + assert resp.status_code == 200 |
| 407 | + assert resp.content_type == "application/json" |
| 408 | + assert resp.data == b'{"foo": 1.0}' |
| 409 | + |
| 410 | + |
395 | 411 | def test_multiple_arguments(): |
396 | 412 | class Model(cog.Model): |
397 | 413 | def setup(self): |
|
0 commit comments