Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions pkg/docker/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def handle_request():
raw_inputs[key] = val
for key, val in request.files.items():
if key in raw_inputs:
return abort(
400, f"Duplicated argument name in form and files: {key}"
return _abort400(
f"Duplicated argument name in form and files: {key}"
)
raw_inputs[key] = val

Expand All @@ -69,7 +69,7 @@ def handle_request():
raw_inputs, cleanup_functions
)
except InputValidationError as e:
return abort(400, str(e))
return _abort400(str(e))
else:
inputs = raw_inputs

Expand Down Expand Up @@ -186,9 +186,13 @@ def validate_and_convert_inputs(

if _is_numeric_type(input_spec.type):
if input_spec.max is not None and converted > input_spec.max:
raise InputValidationError(f"Value {converted} is greater than the max value {input_spec.max}")
raise InputValidationError(
f"Value {converted} is greater than the max value {input_spec.max}"
)
if input_spec.min is not None and converted < input_spec.min:
raise InputValidationError(f"Value {converted} is less than the min value {input_spec.min}")
raise InputValidationError(
f"Value {converted} is less than the min value {input_spec.min}"
)

else:
if input_spec.default is not _UNSPECIFIED:
Expand Down Expand Up @@ -290,3 +294,9 @@ def _is_numeric_type(typ: Type) -> bool:

def _method_arg_names(f) -> List[str]:
return inspect.getfullargspec(f)[0][1:] # 0 is self


def _abort400(message):
resp = jsonify({"message": message})
resp.status_code = 400
return resp
13 changes: 13 additions & 0 deletions pkg/serving/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,19 @@ func (d *LocalDockerDeployment) RunInference(input *Example, logWriter logger.Lo
}
defer resp.Body.Close()

if resp.StatusCode == http.StatusBadRequest {
body := struct {
Message string `json:"message"`
}{}
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
return nil, fmt.Errorf("/infer call return status 400, and the response body failed to decode: %w", err)
}
if body.Message == "" {
return nil, fmt.Errorf("Bad request")
}
return nil, fmt.Errorf("Bad request: %s", body.Message)
}

if resp.StatusCode != http.StatusOK {
d.writeContainerLogs(logWriter)
return nil, fmt.Errorf("/infer call returned status %d", resp.StatusCode)
Expand Down