Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions pkg/server/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ func (s *Server) testModel(mod *model.Model, dir string, logWriter logger.Logger
if err := validateServingExampleInput(help, example.Input); err != nil {
return nil, fmt.Errorf("Example input doesn't match run arguments: %w", err)
}
var expectedOutput []byte = nil
outputIsFile := false
if example.Output != "" {
if strings.HasPrefix(example.Output, "@") {
outputIsFile = true
expectedOutput, err = os.ReadFile(filepath.Join(dir, example.Output[1:]))
if err != nil {
return nil, fmt.Errorf("Failed to read example output file %s: %w", example.Output[1:], err)
}
} else {
expectedOutput = []byte(example.Output)
}
}

input := serving.NewExampleWithBaseDir(example.Input, dir)

result, err := deployment.RunInference(input, logWriter)
Expand All @@ -164,8 +178,14 @@ func (s *Server) testModel(mod *model.Model, dir string, logWriter logger.Logger
return nil, fmt.Errorf("Failed to read output: %w", err)
}
logWriter.Infof(fmt.Sprintf("Inference result length: %d, mime type: %s", len(outputBytes), output.MimeType))
if example.Output != "" && strings.TrimSpace(string(outputBytes)) != example.Output {
return nil, fmt.Errorf("Output %s doesn't match expected: %s", outputBytes, example.Output)
if expectedOutput != nil {
if !bytes.Equal(expectedOutput, outputBytes) {
if outputIsFile {
return nil, fmt.Errorf("Output file contents doesn't match expected %s", example.Output[1:])
} else {
return nil, fmt.Errorf("Output %s doesn't match expected: %s", string(outputBytes), example.Output)
}
}
}
}

Expand Down