diff --git a/pkg/docker/cog.py b/pkg/docker/cog.py index 448a2718b2..b7e829701d 100644 --- a/pkg/docker/cog.py +++ b/pkg/docker/cog.py @@ -58,8 +58,8 @@ def handle_request(): raw_inputs[key] = val for key, val in request.files.items(): if key in raw_inputs: - return abort( - 400, f"Duplicated argument name in form and files: {key}" + return _abort400( + f"Duplicated argument name in form and files: {key}" ) raw_inputs[key] = val @@ -69,7 +69,7 @@ def handle_request(): raw_inputs, cleanup_functions ) except InputValidationError as e: - return abort(400, str(e)) + return _abort400(str(e)) else: inputs = raw_inputs @@ -186,9 +186,13 @@ def validate_and_convert_inputs( if _is_numeric_type(input_spec.type): if input_spec.max is not None and converted > input_spec.max: - raise InputValidationError(f"Value {converted} is greater than the max value {input_spec.max}") + raise InputValidationError( + f"Value {converted} is greater than the max value {input_spec.max}" + ) if input_spec.min is not None and converted < input_spec.min: - raise InputValidationError(f"Value {converted} is less than the min value {input_spec.min}") + raise InputValidationError( + f"Value {converted} is less than the min value {input_spec.min}" + ) else: if input_spec.default is not _UNSPECIFIED: @@ -290,3 +294,9 @@ def _is_numeric_type(typ: Type) -> bool: def _method_arg_names(f) -> List[str]: return inspect.getfullargspec(f)[0][1:] # 0 is self + + +def _abort400(message): + resp = jsonify({"message": message}) + resp.status_code = 400 + return resp diff --git a/pkg/serving/local.go b/pkg/serving/local.go index 02f3f9a8e6..ca94e4c0e8 100644 --- a/pkg/serving/local.go +++ b/pkg/serving/local.go @@ -212,6 +212,19 @@ func (d *LocalDockerDeployment) RunInference(input *Example, logWriter logger.Lo } defer resp.Body.Close() + if resp.StatusCode == http.StatusBadRequest { + body := struct { + Message string `json:"message"` + }{} + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return nil, fmt.Errorf("/infer call return status 400, and the response body failed to decode: %w", err) + } + if body.Message == "" { + return nil, fmt.Errorf("Bad request") + } + return nil, fmt.Errorf("Bad request: %s", body.Message) + } + if resp.StatusCode != http.StatusOK { d.writeContainerLogs(logWriter) return nil, fmt.Errorf("/infer call returned status %d", resp.StatusCode)