Skip to content

Commit 82de546

Browse files
authored
Add support for pulling draft sources (#2395)
* Add support for pulling draft sources * Draft sources can now be pulled using draft:user/digest * Extracts tarball logic into lower function, only Significant change is the URL processing * Remove console log
1 parent 7723476 commit 82de546

File tree

2 files changed

+129
-2
lines changed

2 files changed

+129
-2
lines changed

pkg/api/client.go

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ import (
2323
"github.com/replicate/cog/pkg/web"
2424
)
2525

26+
const DraftsPrefix = "draft:"
27+
2628
var (
2729
ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
30+
ErrorBadDraftFormat = errors.New("Bad draft format")
31+
ErrorBadDraftUsernameDigestFormat = errors.New("Bad draft username/digest format")
2832
)
2933

3034
type Client struct {
@@ -65,6 +69,14 @@ func (c *Client) PostNewPipeline(ctx context.Context, image string, tarball *byt
6569
}
6670

6771
func (c *Client) PullSource(ctx context.Context, image string, tarFileProcess func(*tar.Header, *tar.Reader) error) error {
72+
if strings.HasPrefix(image, DraftsPrefix) {
73+
username, digest, err := decomposeDraftSlug(image)
74+
if err != nil {
75+
return err
76+
}
77+
return c.getDraftSource(ctx, username, digest, tarFileProcess)
78+
}
79+
6880
_, entity, name, tag, err := decomposeImageName(image)
6981
if err != nil {
7082
return err
@@ -228,7 +240,21 @@ func (c *Client) getSource(ctx context.Context, entity string, name string, tag
228240
}
229241

230242
sourceURL := newSourceURL(entity, name, tag)
231-
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), nil)
243+
return c.downloadTarball(ctx, token, sourceURL, strings.Join([]string{entity, name}, "/"), tarFileProcess)
244+
}
245+
246+
func (c *Client) getDraftSource(ctx context.Context, username string, digest string, tarFileProcess func(*tar.Header, *tar.Reader) error) error {
247+
token, err := c.provideToken(ctx, username)
248+
if err != nil {
249+
return err
250+
}
251+
252+
draftURL := newDraftSourceURL(digest)
253+
return c.downloadTarball(ctx, token, draftURL, DraftsPrefix+strings.Join([]string{username, digest}, "/"), tarFileProcess)
254+
}
255+
256+
func (c *Client) downloadTarball(ctx context.Context, token string, url url.URL, slug string, tarFileProcess func(*tar.Header, *tar.Reader) error) error {
257+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil)
232258
if err != nil {
233259
return err
234260
}
@@ -242,7 +268,7 @@ func (c *Client) getSource(ctx context.Context, entity string, name string, tag
242268
defer resp.Body.Close()
243269

244270
if resp.StatusCode == http.StatusNotFound {
245-
return fmt.Errorf("Model %s/%s does not have a source package associated with it.", entity, name)
271+
return fmt.Errorf("Entity %s does not have a source package associated with it.", slug)
246272
}
247273

248274
if resp.StatusCode >= 400 {
@@ -332,6 +358,12 @@ func newModelURL(entity string, name string) url.URL {
332358
return newModelUrl
333359
}
334360

361+
func newDraftSourceURL(digest string) url.URL {
362+
newDraftSourceUrl := apiBaseURL()
363+
newDraftSourceUrl.Path = strings.Join([]string{"", "v1", "drafts", digest, "source"}, "/")
364+
return newDraftSourceUrl
365+
}
366+
335367
func decomposeImageName(image string) (string, string, string, string, error) {
336368
imageComponents := strings.Split(image, "/")
337369

@@ -353,3 +385,17 @@ func decomposeImageName(image string) (string, string, string, string, error) {
353385
}
354386
return imageComponents[0], imageComponents[1], imageComponents[2], tag, nil
355387
}
388+
389+
func decomposeDraftSlug(slug string) (string, string, error) {
390+
slugComponents := strings.Split(slug, ":")
391+
if len(slugComponents) != 2 {
392+
return "", "", ErrorBadDraftFormat
393+
}
394+
395+
draftComponents := strings.Split(slugComponents[1], "/")
396+
if len(draftComponents) != 2 {
397+
return "", "", ErrorBadDraftUsernameDigestFormat
398+
}
399+
400+
return draftComponents[0], draftComponents[1], nil
401+
}

pkg/api/client_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,84 @@ func TestPullSource(t *testing.T) {
170170
})
171171
require.NoError(t, err)
172172
}
173+
174+
func TestPullDraftSource(t *testing.T) {
175+
// Create file to pull
176+
dir := t.TempDir()
177+
predictPyPath := filepath.Join(dir, "predict.py")
178+
handle, err := os.Create(predictPyPath)
179+
require.NoError(t, err)
180+
handle.WriteString("import cog")
181+
err = handle.Close()
182+
require.NoError(t, err)
183+
info, err := os.Stat(predictPyPath)
184+
require.NoError(t, err)
185+
186+
// Setup mock web server for cog.replicate.com (token exchange)
187+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
188+
switch r.URL.Path {
189+
case "/api/token/user":
190+
// Mock token exchange response
191+
//nolint:gosec
192+
tokenResponse := `{
193+
"keys": {
194+
"cog": {
195+
"key": "test-api-token",
196+
"expires_at": "2024-12-31T23:59:59Z"
197+
}
198+
}
199+
}`
200+
w.WriteHeader(http.StatusOK)
201+
w.Write([]byte(tokenResponse))
202+
default:
203+
w.WriteHeader(http.StatusNotFound)
204+
}
205+
}))
206+
defer webServer.Close()
207+
208+
// Setup mock API server for api.replicate.com (model and source endpoints)
209+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
210+
switch r.URL.Path {
211+
case "/v1/drafts/digest/source":
212+
// Mock draft source pull endpoint
213+
var buf bytes.Buffer
214+
tw := tar.NewWriter(&buf)
215+
header, err := tar.FileInfoHeader(info, info.Name())
216+
require.NoError(t, err)
217+
header.Name = "predict.py"
218+
err = tw.WriteHeader(header)
219+
require.NoError(t, err)
220+
file, err := os.Open(predictPyPath)
221+
require.NoError(t, err)
222+
defer file.Close()
223+
_, err = io.Copy(tw, file)
224+
require.NoError(t, err)
225+
err = tw.Close()
226+
require.NoError(t, err)
227+
w.WriteHeader(http.StatusOK)
228+
w.Write(buf.Bytes())
229+
default:
230+
w.WriteHeader(http.StatusNotFound)
231+
}
232+
}))
233+
defer apiServer.Close()
234+
235+
webURL, err := url.Parse(webServer.URL)
236+
require.NoError(t, err)
237+
apiURL, err := url.Parse(apiServer.URL)
238+
require.NoError(t, err)
239+
240+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
241+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
242+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
243+
244+
// Setup mock command
245+
command := dockertest.NewMockCommand()
246+
webClient := web.NewClient(command, http.DefaultClient)
247+
248+
client := NewClient(command, http.DefaultClient, webClient)
249+
err = client.PullSource(t.Context(), "draft:user/digest", func(header *tar.Header, tr *tar.Reader) error {
250+
return nil
251+
})
252+
require.NoError(t, err)
253+
}

0 commit comments

Comments
 (0)