Skip to content

Commit dbaf981

Browse files
authored
fix: handle error messages better for cog push with pipelines (#2435)
1 parent 3e6d5e6 commit dbaf981

File tree

3 files changed

+90
-75
lines changed

3 files changed

+90
-75
lines changed

pkg/api/client.go

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ var (
2929
ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
3030
ErrorBadDraftFormat = errors.New("Bad draft format")
3131
ErrorBadDraftUsernameDigestFormat = errors.New("Bad draft username/digest format")
32-
ErrorBadRequestModelHasVersions = errors.New("This model already has versions associated with it, and can't be used with procedures.")
3332
)
3433

3534
type Client struct {
@@ -142,15 +141,14 @@ func (c *Client) postNewVersion(ctx context.Context, image string, tarball *byte
142141
body := new(bytes.Buffer)
143142
mp := multipart.NewWriter(body)
144143
defer mp.Close()
145-
err = mp.WriteField("openapi_schema", manifest.Config.Labels[command.CogOpenAPISchemaLabelKey])
146-
if err != nil {
144+
145+
if err := mp.WriteField("openapi_schema", manifest.Config.Labels[command.CogOpenAPISchemaLabelKey]); err != nil {
147146
return "", err
148147
}
149148

150149
dependencies := manifest.Config.Labels[command.CogModelDependenciesLabelKey]
151150
if dependencies != "" && dependencies != `[""]` {
152-
err = mp.WriteField("dependencies", dependencies)
153-
if err != nil {
151+
if err := mp.WriteField("dependencies", dependencies); err != nil {
154152
return "", err
155153
}
156154
}
@@ -161,8 +159,7 @@ func (c *Client) postNewVersion(ctx context.Context, image string, tarball *byte
161159
if err != nil {
162160
return "", err
163161
}
164-
err = gzipWriter.Close()
165-
if err != nil {
162+
if err := gzipWriter.Close(); err != nil {
166163
return "", err
167164
}
168165

@@ -194,19 +191,17 @@ func (c *Client) postNewVersion(ctx context.Context, image string, tarball *byte
194191

195192
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
196193
var bodyError Error
197-
err = json.NewDecoder(resp.Body).Decode(&bodyError)
198-
if err != nil {
194+
if err := json.NewDecoder(resp.Body).Decode(&bodyError); err != nil {
199195
return "", err
200196
}
201-
if bodyError.Errors[0].Detail == "This endpoint does not support models that have versions published with `cog push`." {
202-
return "", util.WrapError(ErrorBadRequestModelHasVersions, strconv.Itoa(resp.StatusCode))
197+
if len(bodyError.Errors) > 0 && bodyError.Errors[0].Detail != "" {
198+
return "", errors.New(bodyError.Errors[0].Detail)
203199
}
204-
return "", util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode))
200+
return "", ErrorBadResponseNewVersionEndpoint
205201
}
206202

207203
var version Version
208-
err = json.NewDecoder(resp.Body).Decode(&version)
209-
if err != nil {
204+
if err := json.NewDecoder(resp.Body).Decode(&version); err != nil {
210205
return "", err
211206
}
212207

@@ -292,7 +287,7 @@ func (c *Client) downloadTarball(ctx context.Context, token string, url url.URL,
292287
return fmt.Errorf("Entity %s does not have a source package associated with it.", slug)
293288
}
294289

295-
if resp.StatusCode >= 400 {
290+
if resp.StatusCode >= http.StatusBadRequest {
296291
return fmt.Errorf("Bad response: %s attempting to fetch the image source", strconv.Itoa(resp.StatusCode))
297292
}
298293

@@ -306,8 +301,7 @@ func (c *Client) downloadTarball(ctx context.Context, token string, url url.URL,
306301
return err
307302
}
308303

309-
err = tarFileProcess(header, tr)
310-
if err != nil {
304+
if err := tarFileProcess(header, tr); err != nil {
311305
return err
312306
}
313307
}
@@ -335,13 +329,12 @@ func (c *Client) getModel(ctx context.Context, entity string, name string) (*Mod
335329
}
336330
defer resp.Body.Close()
337331

338-
if resp.StatusCode >= 400 {
332+
if resp.StatusCode >= http.StatusBadRequest {
339333
return nil, fmt.Errorf("Bad response: %s attempting to fetch the models versions", strconv.Itoa(resp.StatusCode))
340334
}
341335

342336
var model Model
343-
err = json.NewDecoder(resp.Body).Decode(&model)
344-
if err != nil {
337+
if err := json.NewDecoder(resp.Body).Decode(&model); err != nil {
345338
return nil, err
346339
}
347340

pkg/api/client_test.go

Lines changed: 73 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -333,69 +333,91 @@ func TestPullSourceWithTag(t *testing.T) {
333333
require.NoError(t, err)
334334
}
335335

336-
func TestPostPipelineFailsModelAlreadyHasVersions(t *testing.T) {
337-
// Setup mock web server for cog.replicate.com (token exchange)
338-
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
339-
switch r.URL.Path {
340-
case "/api/token/user":
341-
// Mock token exchange response
342-
//nolint:gosec
343-
tokenResponse := `{
336+
func TestPostPipelineFails(t *testing.T) {
337+
338+
type testCase struct {
339+
name string
340+
body string
341+
wantError string
342+
}
343+
344+
for _, tt := range []testCase{
345+
{
346+
name: "model already has versions",
347+
body: "{\"detail\": \"The following errors occurred:\\n- This endpoint does not support models that have versions published with `cog push`.\",\"errors\":[{\"detail\":\"This endpoint does not support models that have versions published with `cog push`.\",\"pointer\": \"/\"}],\"status\":400,\"title\":\"Validation failed\"}",
348+
wantError: "This endpoint does not support models that have versions published with `cog push`.",
349+
},
350+
{
351+
name: "model uses procedures",
352+
body: "{\"detail\": \"The following errors occurred:\\n- You cannot use this mechanism to push versions of a model that uses pipelines.\",\"errors\":[{\"detail\":\"You cannot use this mechanism to push versions of a model that uses pipelines.\",\"pointer\": \"/\"}],\"status\":400,\"title\":\"Validation failed\"}",
353+
wantError: "You cannot use this mechanism to push versions of a model that uses pipelines.",
354+
},
355+
} {
356+
t.Run(tt.name, func(t *testing.T) {
357+
// Setup mock web server for cog.replicate.com (token exchange)
358+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
359+
switch r.URL.Path {
360+
case "/api/token/user":
361+
// Mock token exchange response
362+
//nolint:gosec
363+
tokenResponse := `{
344364
"keys": {
345365
"cog": {
346366
"key": "test-api-token",
347367
"expires_at": "2024-12-31T23:59:59Z"
348368
}
349369
}
350370
}`
351-
w.WriteHeader(http.StatusOK)
352-
w.Write([]byte(tokenResponse))
353-
default:
354-
w.WriteHeader(http.StatusNotFound)
355-
}
356-
}))
357-
defer webServer.Close()
371+
w.WriteHeader(http.StatusOK)
372+
w.Write([]byte(tokenResponse))
373+
default:
374+
w.WriteHeader(http.StatusNotFound)
375+
}
376+
}))
377+
defer webServer.Close()
378+
379+
// Setup mock API server for api.replicate.com (version and release endpoints)
380+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
381+
switch r.URL.Path {
382+
case "/v1/models/user/test/versions":
383+
// Mock version creation response
384+
w.WriteHeader(http.StatusBadRequest)
385+
w.Write([]byte(tt.body))
386+
case "/v1/models/user/test/releases":
387+
// Mock release creation response - empty body with 204 status
388+
w.WriteHeader(http.StatusNoContent)
389+
default:
390+
w.WriteHeader(http.StatusNotFound)
391+
}
392+
}))
393+
defer apiServer.Close()
358394

359-
// Setup mock API server for api.replicate.com (version and release endpoints)
360-
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361-
switch r.URL.Path {
362-
case "/v1/models/user/test/versions":
363-
// Mock version creation response
364-
versionResponse := "{\"detail\": \"The following errors occurred:\n- This endpoint does not support models that have versions published with `cog push`.\",\"errors\":[{\"detail\":\"This endpoint does not support models that have versions published with `cog push`.\",\"pointer\": \"/\",}],\"status\":400,\"title\":\"Validation failed\"}"
365-
w.WriteHeader(http.StatusBadRequest)
366-
w.Write([]byte(versionResponse))
367-
case "/v1/models/user/test/releases":
368-
// Mock release creation response - empty body with 204 status
369-
w.WriteHeader(http.StatusNoContent)
370-
default:
371-
w.WriteHeader(http.StatusNotFound)
372-
}
373-
}))
374-
defer apiServer.Close()
395+
webURL, err := url.Parse(webServer.URL)
396+
require.NoError(t, err)
397+
apiURL, err := url.Parse(apiServer.URL)
398+
require.NoError(t, err)
375399

376-
webURL, err := url.Parse(webServer.URL)
377-
require.NoError(t, err)
378-
apiURL, err := url.Parse(apiServer.URL)
379-
require.NoError(t, err)
400+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
401+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
402+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
380403

381-
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
382-
t.Setenv(env.WebHostEnvVarName, webURL.Host)
383-
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
404+
dir := t.TempDir()
384405

385-
dir := t.TempDir()
406+
// Create mock predict
407+
predictPyPath := filepath.Join(dir, "predict.py")
408+
handle, err := os.Create(predictPyPath)
409+
require.NoError(t, err)
410+
handle.WriteString("import cog")
411+
dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}"
386412

387-
// Create mock predict
388-
predictPyPath := filepath.Join(dir, "predict.py")
389-
handle, err := os.Create(predictPyPath)
390-
require.NoError(t, err)
391-
handle.WriteString("import cog")
392-
dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}"
413+
// Setup mock command
414+
command := dockertest.NewMockCommand()
415+
webClient := web.NewClient(command, http.DefaultClient)
393416

394-
// Setup mock command
395-
command := dockertest.NewMockCommand()
396-
webClient := web.NewClient(command, http.DefaultClient)
417+
client := NewClient(command, http.DefaultClient, webClient)
418+
err = client.PostNewPipeline(t.Context(), "r8.im/user/test", new(bytes.Buffer))
419+
require.EqualError(t, err, tt.wantError)
420+
})
421+
}
397422

398-
client := NewClient(command, http.DefaultClient, webClient)
399-
err = client.PostNewPipeline(t.Context(), "r8.im/user/test", new(bytes.Buffer))
400-
require.Error(t, err)
401423
}

pkg/coglog/client.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func (c *Client) EndBuild(ctx context.Context, err error, logContext BuildLogCon
7070
errorStr = &errStr
7171
}
7272
buildLog := buildLog{
73-
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
73+
DurationMs: float32(time.Since(logContext.started).Milliseconds()),
7474
BuildError: errorStr,
7575
Fast: logContext.fast,
7676
LocalImage: logContext.localImage,
@@ -107,7 +107,7 @@ func (c *Client) EndPush(ctx context.Context, err error, logContext PushLogConte
107107
errorStr = &errStr
108108
}
109109
pushLog := pushLog{
110-
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
110+
DurationMs: float32(time.Since(logContext.started).Milliseconds()),
111111
BuildError: errorStr,
112112
Fast: logContext.fast,
113113
LocalImage: logContext.localImage,
@@ -140,7 +140,7 @@ func (c *Client) EndMigrate(ctx context.Context, err error, logContext *MigrateL
140140
errorStr = &errStr
141141
}
142142
migrateLog := migrateLog{
143-
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
143+
DurationMs: float32(time.Since(logContext.started).Milliseconds()),
144144
BuildError: errorStr,
145145
Accept: logContext.accept,
146146
PythonPackageStatus: logContext.PythonPackageStatus,
@@ -178,7 +178,7 @@ func (c *Client) EndPull(ctx context.Context, err error, logContext PullLogConte
178178
errorStr = &errStr
179179
}
180180
pushLog := pullLog{
181-
DurationMs: float32(time.Now().Sub(logContext.started).Milliseconds()),
181+
DurationMs: float32(time.Since(logContext.started).Milliseconds()),
182182
BuildError: errorStr,
183183
}
184184

0 commit comments

Comments
 (0)