@@ -23,8 +23,12 @@ import (
2323 "github.com/replicate/cog/pkg/web"
2424)
2525
26+ const DraftsPrefix = "draft:"
27+
2628var (
2729 ErrorBadResponseNewVersionEndpoint = errors .New ("Bad response from new version endpoint" )
30+ ErrorBadDraftFormat = errors .New ("Bad draft format" )
31+ ErrorBadDraftUsernameDigestFormat = errors .New ("Bad draft username/digest format" )
2832)
2933
3034type Client struct {
@@ -65,6 +69,14 @@ func (c *Client) PostNewPipeline(ctx context.Context, image string, tarball *byt
6569}
6670
6771func (c * Client ) PullSource (ctx context.Context , image string , tarFileProcess func (* tar.Header , * tar.Reader ) error ) error {
72+ if strings .HasPrefix (image , DraftsPrefix ) {
73+ username , digest , err := decomposeDraftSlug (image )
74+ if err != nil {
75+ return err
76+ }
77+ return c .getDraftSource (ctx , username , digest , tarFileProcess )
78+ }
79+
6880 _ , entity , name , tag , err := decomposeImageName (image )
6981 if err != nil {
7082 return err
@@ -228,7 +240,21 @@ func (c *Client) getSource(ctx context.Context, entity string, name string, tag
228240 }
229241
230242 sourceURL := newSourceURL (entity , name , tag )
231- req , err := http .NewRequestWithContext (ctx , http .MethodGet , sourceURL .String (), nil )
243+ return c .downloadTarball (ctx , token , sourceURL , strings .Join ([]string {entity , name }, "/" ), tarFileProcess )
244+ }
245+
246+ func (c * Client ) getDraftSource (ctx context.Context , username string , digest string , tarFileProcess func (* tar.Header , * tar.Reader ) error ) error {
247+ token , err := c .provideToken (ctx , username )
248+ if err != nil {
249+ return err
250+ }
251+
252+ draftURL := newDraftSourceURL (digest )
253+ return c .downloadTarball (ctx , token , draftURL , DraftsPrefix + strings .Join ([]string {username , digest }, "/" ), tarFileProcess )
254+ }
255+
256+ func (c * Client ) downloadTarball (ctx context.Context , token string , url url.URL , slug string , tarFileProcess func (* tar.Header , * tar.Reader ) error ) error {
257+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , url .String (), nil )
232258 if err != nil {
233259 return err
234260 }
@@ -242,7 +268,7 @@ func (c *Client) getSource(ctx context.Context, entity string, name string, tag
242268 defer resp .Body .Close ()
243269
244270 if resp .StatusCode == http .StatusNotFound {
245- return fmt .Errorf ("Model %s/%s does not have a source package associated with it." , entity , name )
271+ return fmt .Errorf ("Entity %s does not have a source package associated with it." , slug )
246272 }
247273
248274 if resp .StatusCode >= 400 {
@@ -332,6 +358,12 @@ func newModelURL(entity string, name string) url.URL {
332358 return newModelUrl
333359}
334360
361+ func newDraftSourceURL (digest string ) url.URL {
362+ newDraftSourceUrl := apiBaseURL ()
363+ newDraftSourceUrl .Path = strings .Join ([]string {"" , "v1" , "drafts" , digest , "source" }, "/" )
364+ return newDraftSourceUrl
365+ }
366+
335367func decomposeImageName (image string ) (string , string , string , string , error ) {
336368 imageComponents := strings .Split (image , "/" )
337369
@@ -353,3 +385,17 @@ func decomposeImageName(image string) (string, string, string, string, error) {
353385 }
354386 return imageComponents [0 ], imageComponents [1 ], imageComponents [2 ], tag , nil
355387}
388+
389+ func decomposeDraftSlug (slug string ) (string , string , error ) {
390+ slugComponents := strings .Split (slug , ":" )
391+ if len (slugComponents ) != 2 {
392+ return "" , "" , ErrorBadDraftFormat
393+ }
394+
395+ draftComponents := strings .Split (slugComponents [1 ], "/" )
396+ if len (draftComponents ) != 2 {
397+ return "" , "" , ErrorBadDraftUsernameDigestFormat
398+ }
399+
400+ return draftComponents [0 ], draftComponents [1 ], nil
401+ }
0 commit comments