Skip to content

Commit 243d297

Browse files
8W9aG, github-advanced-security[bot], and aron
authored
Add cog pull (#2386)
* Add API package * Create a client for interacting with the API * Add pull command * Potential fix for code scanning alert no. 25: Arbitrary file write extracting an archive containing symbolic links Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Signed-off-by: Will Sackfield <[email protected]> * Only look for cfg if image isn’t present * Fetch version number if required * Extract the tar within the source function * Allow automatic deferment of close on response * Fix sec warnings * Security warnings have been accounted for by resolving the path and making sure it is within the project directory. * Add tests for pulling source * Update pkg/api/client.go Co-authored-by: Aron Carroll <[email protected]> Signed-off-by: Will Sackfield <[email protected]> * Remove looking into cfg for image name * Add message for docker pull * Pull to a folder with model name * Add r8.im into image name if missing * Add explicit error handling for not found --------- Signed-off-by: Will Sackfield <[email protected]> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: Aron Carroll <[email protected]>
1 parent 0c23233 commit 243d297

File tree

10 files changed

+751
-193
lines changed

10 files changed

+751
-193
lines changed

pkg/api/client.go

Lines changed: 355 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,355 @@
1+
package api
2+
3+
import (
4+
"archive/tar"
5+
"bytes"
6+
"compress/gzip"
7+
"context"
8+
"encoding/json"
9+
"errors"
10+
"fmt"
11+
"io"
12+
"mime/multipart"
13+
"net/http"
14+
"net/url"
15+
"strconv"
16+
"strings"
17+
18+
"github.com/replicate/cog/pkg/docker/command"
19+
"github.com/replicate/cog/pkg/env"
20+
r8_errors "github.com/replicate/cog/pkg/errors"
21+
"github.com/replicate/cog/pkg/global"
22+
"github.com/replicate/cog/pkg/util"
23+
"github.com/replicate/cog/pkg/web"
24+
)
25+
26+
// ErrorBadResponseNewVersionEndpoint is the sentinel error used when the
// new-version endpoint answers with a status other than 200 OK or
// 201 Created. postNewVersion passes it to util.WrapError together with the
// status code it received.
var ErrorBadResponseNewVersionEndpoint = errors.New("Bad response from new version endpoint")
29+
30+
type Client struct {
31+
dockerCommand command.Command
32+
client *http.Client
33+
tokens map[string]string
34+
webClient *web.Client
35+
}
36+
37+
// Version is the part of a model version payload this package decodes: its
// identifier, carried in the JSON field "id".
type Version struct {
	// Id is the version identifier.
	Id string `json:"id"`
}
40+
41+
// CreateRelease is the JSON request body sent when creating a release.
type CreateRelease struct {
	// Version is the id of the version being released, encoded as "version".
	Version string `json:"version"`
}
44+
45+
type Model struct {
46+
LatestVersion Version `json:"latest_version"`
47+
}
48+
49+
func NewClient(dockerCommand command.Command, client *http.Client, webClient *web.Client) *Client {
50+
return &Client{
51+
dockerCommand: dockerCommand,
52+
client: client,
53+
tokens: map[string]string{},
54+
webClient: webClient,
55+
}
56+
}
57+
58+
func (c *Client) PostNewPipeline(ctx context.Context, image string, tarball *bytes.Buffer) error {
59+
id, err := c.postNewVersion(ctx, image, tarball)
60+
if err != nil {
61+
return err
62+
}
63+
64+
return c.postNewRelease(ctx, id, image)
65+
}
66+
67+
func (c *Client) PullSource(ctx context.Context, image string, tarFileProcess func(*tar.Header, *tar.Reader) error) error {
68+
_, entity, name, tag, err := decomposeImageName(image)
69+
if err != nil {
70+
return err
71+
}
72+
73+
// Check if we require the tag
74+
if tag == "" {
75+
model, err := c.getModel(ctx, entity, name)
76+
if err != nil {
77+
return err
78+
}
79+
tag = model.LatestVersion.Id
80+
}
81+
82+
// Fetch the source
83+
return c.getSource(ctx, entity, name, tag, tarFileProcess)
84+
}
85+
86+
func (c *Client) provideToken(ctx context.Context, entity string) (string, error) {
87+
token, ok := c.tokens[entity]
88+
if !ok {
89+
webToken, err := c.webClient.FetchAPIToken(ctx, entity)
90+
if err != nil {
91+
return "", err
92+
}
93+
token = webToken
94+
c.tokens[entity] = token
95+
}
96+
return token, nil
97+
}
98+
99+
func (c *Client) postNewVersion(ctx context.Context, image string, tarball *bytes.Buffer) (string, error) {
100+
// Fetch manifest
101+
manifest, err := c.dockerCommand.Inspect(ctx, image)
102+
if err != nil {
103+
return "", util.WrapError(err, "failed to inspect docker image")
104+
}
105+
106+
// Fetch token
107+
_, entity, name, _, err := decomposeImageName(image)
108+
if err != nil {
109+
return "", err
110+
}
111+
token, err := c.provideToken(ctx, entity)
112+
if err != nil {
113+
return "", err
114+
}
115+
116+
// Create form data body
117+
body := new(bytes.Buffer)
118+
mp := multipart.NewWriter(body)
119+
defer mp.Close()
120+
err = mp.WriteField("openapi_schema", manifest.Config.Labels[command.CogOpenAPISchemaLabelKey])
121+
if err != nil {
122+
return "", err
123+
}
124+
125+
dependencies := manifest.Config.Labels[command.CogModelDependenciesLabelKey]
126+
if dependencies != "" && dependencies != `[""]` {
127+
err = mp.WriteField("dependencies", dependencies)
128+
if err != nil {
129+
return "", err
130+
}
131+
}
132+
133+
var gzipBuf bytes.Buffer
134+
gzipWriter := gzip.NewWriter(&gzipBuf)
135+
_, err = io.Copy(gzipWriter, bytes.NewReader(tarball.Bytes()))
136+
if err != nil {
137+
return "", err
138+
}
139+
err = gzipWriter.Close()
140+
if err != nil {
141+
return "", err
142+
}
143+
144+
part, err := mp.CreateFormFile("source_archive", "source_archive.tar.gz")
145+
if err != nil {
146+
return "", err
147+
}
148+
149+
_, err = io.Copy(part, bytes.NewReader(gzipBuf.Bytes()))
150+
if err != nil {
151+
return "", err
152+
}
153+
mp.Close()
154+
155+
versionURL := newVersionsURL(entity, name)
156+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, versionURL.String(), bytes.NewReader(body.Bytes()))
157+
if err != nil {
158+
return "", err
159+
}
160+
req.Header.Set("Content-Type", mp.FormDataContentType())
161+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
162+
163+
// Make the request
164+
resp, err := c.client.Do(req)
165+
if err != nil {
166+
return "", err
167+
}
168+
defer resp.Body.Close()
169+
170+
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
171+
return "", util.WrapError(ErrorBadResponseNewVersionEndpoint, strconv.Itoa(resp.StatusCode))
172+
}
173+
174+
var version Version
175+
err = json.NewDecoder(resp.Body).Decode(&version)
176+
if err != nil {
177+
return "", err
178+
}
179+
180+
return version.Id, nil
181+
}
182+
183+
func (c *Client) postNewRelease(ctx context.Context, id string, image string) error {
184+
_, entity, name, _, err := decomposeImageName(image)
185+
if err != nil {
186+
return err
187+
}
188+
189+
token, err := c.provideToken(ctx, entity)
190+
if err != nil {
191+
return err
192+
}
193+
194+
releaseURL := newReleaseURL(entity, name)
195+
createRelease := CreateRelease{
196+
Version: id,
197+
}
198+
buf := new(bytes.Buffer)
199+
err = json.NewEncoder(buf).Encode(createRelease)
200+
if err != nil {
201+
return fmt.Errorf("Unable to encode JSON request body: %w", err)
202+
}
203+
204+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, releaseURL.String(), buf)
205+
if err != nil {
206+
return err
207+
}
208+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
209+
210+
// Make the request
211+
releaseResp, err := c.client.Do(req)
212+
if err != nil {
213+
return err
214+
}
215+
defer releaseResp.Body.Close()
216+
217+
if releaseResp.StatusCode != http.StatusNoContent {
218+
return fmt.Errorf("Bad response: %s attempting to create a release", strconv.Itoa(releaseResp.StatusCode))
219+
}
220+
221+
return nil
222+
}
223+
224+
func (c *Client) getSource(ctx context.Context, entity string, name string, tag string, tarFileProcess func(*tar.Header, *tar.Reader) error) error {
225+
token, err := c.provideToken(ctx, entity)
226+
if err != nil {
227+
return err
228+
}
229+
230+
sourceURL := newSourceURL(entity, name, tag)
231+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), nil)
232+
if err != nil {
233+
return err
234+
}
235+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
236+
237+
// Make the request
238+
resp, err := c.client.Do(req)
239+
if err != nil {
240+
return err
241+
}
242+
defer resp.Body.Close()
243+
244+
if resp.StatusCode == http.StatusNotFound {
245+
return fmt.Errorf("Model %s/%s does not have a source package associated with it.", entity, name)
246+
}
247+
248+
if resp.StatusCode >= 400 {
249+
return fmt.Errorf("Bad response: %s attempting to fetch the image source", strconv.Itoa(resp.StatusCode))
250+
}
251+
252+
tr := tar.NewReader(resp.Body)
253+
for {
254+
header, err := tr.Next()
255+
if err == io.EOF {
256+
break
257+
}
258+
if err != nil {
259+
return err
260+
}
261+
262+
err = tarFileProcess(header, tr)
263+
if err != nil {
264+
return err
265+
}
266+
}
267+
268+
return nil
269+
}
270+
271+
func (c *Client) getModel(ctx context.Context, entity string, name string) (*Model, error) {
272+
token, err := c.provideToken(ctx, entity)
273+
if err != nil {
274+
return nil, err
275+
}
276+
277+
modelURL := newModelURL(entity, name)
278+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, modelURL.String(), nil)
279+
if err != nil {
280+
return nil, err
281+
}
282+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
283+
284+
// Make the request
285+
resp, err := c.client.Do(req)
286+
if err != nil {
287+
return nil, err
288+
}
289+
defer resp.Body.Close()
290+
291+
if resp.StatusCode >= 400 {
292+
return nil, fmt.Errorf("Bad response: %s attempting to fetch the models versions", strconv.Itoa(resp.StatusCode))
293+
}
294+
295+
var model Model
296+
err = json.NewDecoder(resp.Body).Decode(&model)
297+
if err != nil {
298+
return nil, err
299+
}
300+
301+
return &model, nil
302+
}
303+
304+
func apiBaseURL() url.URL {
305+
return url.URL{
306+
Scheme: env.SchemeFromEnvironment(),
307+
Host: env.APIHostFromEnvironment(),
308+
}
309+
}
310+
311+
func newVersionsURL(entity string, name string) url.URL {
312+
newVersionUrl := apiBaseURL()
313+
newVersionUrl.Path = strings.Join([]string{"", "v1", "models", entity, name, "versions"}, "/")
314+
return newVersionUrl
315+
}
316+
317+
func newReleaseURL(entity string, name string) url.URL {
318+
newReleaseUrl := apiBaseURL()
319+
newReleaseUrl.Path = strings.Join([]string{"", "v1", "models", entity, name, "releases"}, "/")
320+
return newReleaseUrl
321+
}
322+
323+
func newSourceURL(entity string, name string, tag string) url.URL {
324+
newSourceUrl := apiBaseURL()
325+
newSourceUrl.Path = strings.Join([]string{"", "v1", "models", entity, name, "versions", tag, "source"}, "/")
326+
return newSourceUrl
327+
}
328+
329+
func newModelURL(entity string, name string) url.URL {
330+
newModelUrl := apiBaseURL()
331+
newModelUrl.Path = strings.Join([]string{"", "v1", "models", entity, name}, "/")
332+
return newModelUrl
333+
}
334+
335+
func decomposeImageName(image string) (string, string, string, string, error) {
336+
imageComponents := strings.Split(image, "/")
337+
338+
// Attempt normalisation of image
339+
if len(imageComponents) == 2 && imageComponents[0] != global.ReplicateRegistryHost {
340+
imageComponents = append([]string{global.ReplicateRegistryHost}, imageComponents...)
341+
}
342+
343+
if len(imageComponents) != 3 {
344+
return "", "", "", "", r8_errors.ErrorBadRegistryURL
345+
}
346+
if imageComponents[0] != global.ReplicateRegistryHost {
347+
return "", "", "", "", r8_errors.ErrorBadRegistryHost
348+
}
349+
tagComponents := strings.Split(image, ":")
350+
tag := ""
351+
if len(tagComponents) == 2 {
352+
tag = tagComponents[1]
353+
}
354+
return imageComponents[0], imageComponents[1], imageComponents[2], tag, nil
355+
}

0 commit comments

Comments
 (0)