Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func predictIndividualInputs(predictor predict.Predictor, inputFlags []string, o
return err
}

inputs, err := parseInputFlags(inputFlags)
inputs, err := parseInputFlags(inputFlags, schema)
if err != nil {
return err
}
Expand Down Expand Up @@ -380,7 +380,7 @@ func writeDataURLOutput(outputString string, outputPath string, addExtension boo
return nil
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error) {
keyVals := map[string][]string{}
for _, input := range inputs {
var name, value string
Expand All @@ -402,7 +402,7 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) {
keyVals[name] = append(keyVals[name], value)
}

return predict.NewInputs(keyVals), nil
return predict.NewInputs(keyVals, schema)
}

func addSetupTimeoutFlag(cmd *cobra.Command) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/docker/docker_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (c *DockerCommand) exec(name string, capture bool, args ...string) (string,
console.Debug("$ " + strings.Join(cmd.Args, " "))
err := cmd.Run()
if err != nil {
return "", err
return "", fmt.Errorf("%s: %w", out.String(), err)
}
return out.String(), nil
}
Expand Down
34 changes: 32 additions & 2 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package predict

import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"

"github.com/getkin/kin-openapi/openapi3"
"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

Expand All @@ -16,11 +18,20 @@ type Input struct {
String *string
File *string
Array *[]any
Json *json.RawMessage
}

type Inputs map[string]Input

func NewInputs(keyVals map[string][]string) Inputs {
func NewInputs(keyVals map[string][]string, schema *openapi3.T) (Inputs, error) {
var inputComponent *openapi3.SchemaRef
for name, component := range schema.Components.Schemas {
if name == "Input" {
inputComponent = component
break
}
}

input := Inputs{}
for key, vals := range keyVals {
if len(vals) == 1 {
Expand All @@ -29,6 +40,23 @@ func NewInputs(keyVals map[string][]string) Inputs {
val = val[1:]
input[key] = Input{File: &val}
} else {
// Check if we should explicitly parse the JSON based on a known schema
if inputComponent != nil {
properties, err := inputComponent.JSONLookup("properties")
if err != nil {
return input, err
}
propertiesSchemas := properties.(openapi3.Schemas)
property, err := propertiesSchemas.JSONLookup(key)
if err == nil {
propertySchema := property.(*openapi3.Schema)
if propertySchema.Type.Is("object") {
encodedVal := json.RawMessage(val)
input[key] = Input{Json: &encodedVal}
continue
}
}
}
input[key] = Input{String: &val}
}
} else if len(vals) > 1 {
Expand All @@ -39,7 +67,7 @@ func NewInputs(keyVals map[string][]string) Inputs {
input[key] = Input{Array: &anyVals}
}
}
return input
return input, nil
}

func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
Expand Down Expand Up @@ -86,6 +114,8 @@ func (inputs *Inputs) toMap() (map[string]any, error) {
}
}
keyVals[key] = dataURLs
case input.Json != nil:
keyVals[key] = *input.Json
}
}
return keyVals, nil
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
build:
python_version: "3.11"
fast: true
python_requirements: "requirements.txt"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from anthropic.types import MessageParam
from cog import BasePredictor, Input
from cog.coder import json_coder # noqa: F401


class Predictor(BasePredictor):
def predict(
self,
message: MessageParam = Input(description="Messages API."),
) -> str:
return "Content: " + message["content"]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
anthropic[vertex]==0.45.2
29 changes: 28 additions & 1 deletion test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,34 @@ def test_predict_optional_project(tmpdir_factory):
assert result.stdout == "hello No One\n"


def test_predict_complex_types(docker_image):
project_dir = Path(__file__).parent / "fixtures/complex-types"

build_process = subprocess.run(
["cog", "build", "-t", docker_image, "--x-fast", "--x-localimage"],
cwd=project_dir,
capture_output=True,
)
assert build_process.returncode == 0
result = subprocess.run(
[
"cog",
"predict",
"--debug",
docker_image,
"-i",
'message={"content": "Hi There", "role": "user"}',
],
cwd=project_dir,
check=True,
capture_output=True,
text=True,
timeout=DEFAULT_TIMEOUT,
)
assert result.returncode == 0
assert result.stdout == "Content: Hi There\n"


def test_predict_overrides_project(docker_image):
project_dir = Path(__file__).parent / "fixtures/overrides-project"
build_process = subprocess.run(
Expand Down Expand Up @@ -462,7 +490,6 @@ def test_predict_zsh_package(docker_image):
capture_output=True,
)
assert build_process.returncode == 0

result = subprocess.run(
["cog", "predict", "--debug", docker_image],
cwd=project_dir,
Expand Down
Loading