Skip to content

Commit b737a7b

Browse files
authored
Add specific handling for pushing procedure with versions (#2421)
* Add specific handling for pushing procedure with versions * Parse the error body * Determine if the error is referencing models having versions attached to them. * Send an explicit error if this case exists. * Use new copy for error
1 parent 07f6a78 commit b737a7b

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

pkg/api/client.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var (
2929
ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
3030
ErrorBadDraftFormat = errors.New("Bad draft format")
3131
ErrorBadDraftUsernameDigestFormat = errors.New("Bad draft username/digest format")
32+
ErrorBadRequestModelHasVersions = errors.New("This model already has versions associated with it, and can't be used with procedures.")
3233
)
3334

3435
type Client struct {
@@ -50,6 +51,18 @@ type Model struct {
5051
LatestVersion Version `json:"latest_version"`
5152
}
5253

54+
type SubError struct {
55+
Detail string `json:"detail"`
56+
Pointer string `json:"pointer"`
57+
}
58+
59+
type Error struct {
60+
Detail string `json:"detail"`
61+
Errors []SubError `json:"errors"`
62+
Status int `json:"status"`
63+
Title string `json:"title"`
64+
}
65+
5366
func NewClient(dockerCommand command.Command, client *http.Client, webClient *web.Client) *Client {
5467
return &Client{
5568
dockerCommand: dockerCommand,
@@ -180,6 +193,14 @@ func (c *Client) postNewVersion(ctx context.Context, image string, tarball *byte
180193
defer resp.Body.Close()
181194

182195
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
196+
var bodyError Error
197+
err = json.NewDecoder(resp.Body).Decode(&bodyError)
198+
if err != nil {
199+
return "", err
200+
}
201+
if bodyError.Errors[0].Detail == "This endpoint does not support models that have versions published with `cog push`." {
202+
return "", util.WrapError(ErrorBadRequestModelHasVersions, strconv.Itoa(resp.StatusCode))
203+
}
183204
return "", util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode))
184205
}
185206

pkg/api/client_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,3 +332,70 @@ func TestPullSourceWithTag(t *testing.T) {
332332
})
333333
require.NoError(t, err)
334334
}
335+
336+
func TestPostPipelineFailsModelAlreadyHasVersions(t *testing.T) {
337+
// Setup mock web server for cog.replicate.com (token exchange)
338+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
339+
switch r.URL.Path {
340+
case "/api/token/user":
341+
// Mock token exchange response
342+
//nolint:gosec
343+
tokenResponse := `{
344+
"keys": {
345+
"cog": {
346+
"key": "test-api-token",
347+
"expires_at": "2024-12-31T23:59:59Z"
348+
}
349+
}
350+
}`
351+
w.WriteHeader(http.StatusOK)
352+
w.Write([]byte(tokenResponse))
353+
default:
354+
w.WriteHeader(http.StatusNotFound)
355+
}
356+
}))
357+
defer webServer.Close()
358+
359+
// Setup mock API server for api.replicate.com (version and release endpoints)
360+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
361+
switch r.URL.Path {
362+
case "/v1/models/user/test/versions":
363+
// Mock version creation response
364+
versionResponse := "{\"detail\": \"The following errors occurred:\n- This endpoint does not support models that have versions published with `cog push`.\",\"errors\":[{\"detail\":\"This endpoint does not support models that have versions published with `cog push`.\",\"pointer\": \"/\",}],\"status\":400,\"title\":\"Validation failed\"}"
365+
w.WriteHeader(http.StatusBadRequest)
366+
w.Write([]byte(versionResponse))
367+
case "/v1/models/user/test/releases":
368+
// Mock release creation response - empty body with 204 status
369+
w.WriteHeader(http.StatusNoContent)
370+
default:
371+
w.WriteHeader(http.StatusNotFound)
372+
}
373+
}))
374+
defer apiServer.Close()
375+
376+
webURL, err := url.Parse(webServer.URL)
377+
require.NoError(t, err)
378+
apiURL, err := url.Parse(apiServer.URL)
379+
require.NoError(t, err)
380+
381+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
382+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
383+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
384+
385+
dir := t.TempDir()
386+
387+
// Create mock predict
388+
predictPyPath := filepath.Join(dir, "predict.py")
389+
handle, err := os.Create(predictPyPath)
390+
require.NoError(t, err)
391+
handle.WriteString("import cog")
392+
dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"beautifulsoup4==4.12.3\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}"
393+
394+
// Setup mock command
395+
command := dockertest.NewMockCommand()
396+
webClient := web.NewClient(command, http.DefaultClient)
397+
398+
client := NewClient(command, http.DefaultClient, webClient)
399+
err = client.PostNewPipeline(t.Context(), "r8.im/user/test", new(bytes.Buffer))
400+
require.Error(t, err)
401+
}

0 commit comments

Comments
 (0)