@@ -251,3 +251,84 @@ func TestPullDraftSource(t *testing.T) {
251251 })
252252 require .NoError (t , err )
253253}
254+
255+ func TestPullSourceWithTag (t * testing.T ) {
256+ // Create file to pull
257+ dir := t .TempDir ()
258+ predictPyPath := filepath .Join (dir , "predict.py" )
259+ handle , err := os .Create (predictPyPath )
260+ require .NoError (t , err )
261+ handle .WriteString ("import cog" )
262+ err = handle .Close ()
263+ require .NoError (t , err )
264+ info , err := os .Stat (predictPyPath )
265+ require .NoError (t , err )
266+
267+ // Setup mock web server for cog.replicate.com (token exchange)
268+ webServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
269+ switch r .URL .Path {
270+ case "/api/token/user" :
271+ // Mock token exchange response
272+ //nolint:gosec
273+ tokenResponse := `{
274+ "keys": {
275+ "cog": {
276+ "key": "test-api-token",
277+ "expires_at": "2024-12-31T23:59:59Z"
278+ }
279+ }
280+ }`
281+ w .WriteHeader (http .StatusOK )
282+ w .Write ([]byte (tokenResponse ))
283+ default :
284+ w .WriteHeader (http .StatusNotFound )
285+ }
286+ }))
287+ defer webServer .Close ()
288+
289+ // Setup mock API server for api.replicate.com (model and source endpoints)
290+ apiServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
291+ switch r .URL .Path {
292+ case "/v1/models/user/test/versions/12435/source" :
293+ // Mock source pull endpoint
294+ var buf bytes.Buffer
295+ tw := tar .NewWriter (& buf )
296+ header , err := tar .FileInfoHeader (info , info .Name ())
297+ require .NoError (t , err )
298+ header .Name = "predict.py"
299+ err = tw .WriteHeader (header )
300+ require .NoError (t , err )
301+ file , err := os .Open (predictPyPath )
302+ require .NoError (t , err )
303+ defer file .Close ()
304+ _ , err = io .Copy (tw , file )
305+ require .NoError (t , err )
306+ err = tw .Close ()
307+ require .NoError (t , err )
308+ w .WriteHeader (http .StatusOK )
309+ w .Write (buf .Bytes ())
310+ default :
311+ w .WriteHeader (http .StatusNotFound )
312+ }
313+ }))
314+ defer apiServer .Close ()
315+
316+ webURL , err := url .Parse (webServer .URL )
317+ require .NoError (t , err )
318+ apiURL , err := url .Parse (apiServer .URL )
319+ require .NoError (t , err )
320+
321+ t .Setenv (env .SchemeEnvVarName , webURL .Scheme )
322+ t .Setenv (env .WebHostEnvVarName , webURL .Host )
323+ t .Setenv (env .APIHostEnvVarName , apiURL .Host )
324+
325+ // Setup mock command
326+ command := dockertest .NewMockCommand ()
327+ webClient := web .NewClient (command , http .DefaultClient )
328+
329+ client := NewClient (command , http .DefaultClient , webClient )
330+ err = client .PullSource (t .Context (), "r8.im/user/test:12435" , func (header * tar.Header , tr * tar.Reader ) error {
331+ return nil
332+ })
333+ require .NoError (t , err )
334+ }
0 commit comments