Skip to content

Commit fdd1fd8

Browse files
authored
Support running pipelines locally (#2426)
* Move validation logic to procedure package * Add support for running pipelines locally * Add 3.13 to tests * Remove check * This is handled below * Fix fixture location * Fix stdout * Fix call graph to use replicate.use * Fix python tests for call graph
1 parent b6f67fc commit fdd1fd8

File tree

19 files changed

+643
-192
lines changed

19 files changed

+643
-192
lines changed

pkg/cli/build.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ func newBuildCommand() *cobra.Command {
5656
addFastFlag(cmd)
5757
addLocalImage(cmd)
5858
addConfigFlag(cmd)
59+
addPipelineImage(cmd)
5960
cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'")
6061
return cmd
6162
}
@@ -117,7 +118,8 @@ func buildCommand(cmd *cobra.Command, args []string) error {
117118
nil,
118119
buildLocalImage,
119120
dockerClient,
120-
registryClient); err != nil {
121+
registryClient,
122+
pipelinesImage); err != nil {
121123
logClient.EndBuild(ctx, err, logCtx)
122124
return err
123125
}

pkg/cli/predict.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ the prediction on that.`,
6969
addFastFlag(cmd)
7070
addLocalImage(cmd)
7171
addConfigFlag(cmd)
72+
addPipelineImage(cmd)
7273

7374
cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i [email protected]")
7475
cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path")
@@ -187,7 +188,7 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
187188
}
188189

189190
client := registry.NewRegistryClient()
190-
if buildFast {
191+
if buildFast || pipelinesImage {
191192
imageName = config.DockerImageName(projectDir)
192193
if err := image.Build(
193194
ctx,
@@ -208,7 +209,8 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
208209
nil,
209210
buildLocalImage,
210211
dockerClient,
211-
client); err != nil {
212+
client,
213+
pipelinesImage); err != nil {
212214
return err
213215
}
214216
} else {

pkg/cli/pull.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func pull(cmd *cobra.Command, args []string) error {
153153
}
154154

155155
// Check if we are in a pipeline
156-
if !pushPipeline {
156+
if !pipelinesImage {
157157
err = errors.New("Please use docker pull " + image + " to download this model.")
158158
logClient.EndPull(ctx, err, logCtx)
159159
return err

pkg/cli/push.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
"github.com/replicate/cog/pkg/util/console"
2020
)
2121

22-
var pushPipeline bool
22+
var pipelinesImage bool
2323

2424
func newPushCommand() *cobra.Command {
2525
cmd := &cobra.Command{
@@ -120,7 +120,8 @@ func push(cmd *cobra.Command, args []string) error {
120120
annotations,
121121
buildLocalImage,
122122
dockerClient,
123-
registryClient); err != nil {
123+
registryClient,
124+
pipelinesImage); err != nil {
124125
return err
125126
}
126127

@@ -134,7 +135,7 @@ func push(cmd *cobra.Command, args []string) error {
134135
err = docker.Push(ctx, imageName, buildFast, projectDir, dockerClient, docker.BuildInfo{
135136
BuildTime: buildDuration,
136137
BuildID: buildID.String(),
137-
Pipeline: pushPipeline,
138+
Pipeline: pipelinesImage,
138139
}, client, cfg)
139140
if err != nil {
140141
if strings.Contains(err.Error(), "404") {
@@ -167,6 +168,6 @@ func push(cmd *cobra.Command, args []string) error {
167168

168169
func addPipelineImage(cmd *cobra.Command) {
169170
const pipeline = "x-pipeline"
170-
cmd.Flags().BoolVar(&pushPipeline, pipeline, false, "Whether to use the experimental pipeline push feature")
171+
cmd.Flags().BoolVar(&pipelinesImage, pipeline, false, "Whether to use the experimental pipeline feature")
171172
_ = cmd.Flags().MarkHidden(pipeline)
172173
}

pkg/cli/run.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ func newRunCommand() *cobra.Command {
3939
addFastFlag(cmd)
4040
addLocalImage(cmd)
4141
addConfigFlag(cmd)
42+
addPipelineImage(cmd)
4243

4344
flags := cmd.Flags()
4445
// Flags after first argument are considered args and passed to command
@@ -67,7 +68,7 @@ func run(cmd *cobra.Command, args []string) error {
6768
}
6869

6970
var imageName string
70-
if cfg.Build.Fast || buildFast {
71+
if cfg.Build.Fast || buildFast || pipelinesImage {
7172
imageName = config.DockerImageName(projectDir)
7273
err = image.Build(
7374
ctx,
@@ -88,7 +89,8 @@ func run(cmd *cobra.Command, args []string) error {
8889
nil,
8990
buildLocalImage,
9091
dockerClient,
91-
client)
92+
client,
93+
pipelinesImage)
9294
if err != nil {
9395
return err
9496
}

pkg/config/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func DefaultConfig() *Config {
8585
return &Config{
8686
Build: &Build{
8787
GPU: false,
88-
PythonVersion: "3.12",
88+
PythonVersion: "3.13",
8989
},
9090
}
9191
}

pkg/config/config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,7 +726,7 @@ func TestConfigMarshal(t *testing.T) {
726726
data, err := yaml.Marshal(cfg)
727727
require.NoError(t, err)
728728
require.Equal(t, `build:
729-
python_version: "3.12"
729+
python_version: "3.13"
730730
fast: false
731731
predict: ""
732732
`, string(data))

pkg/docker/pipeline_push.go

Lines changed: 2 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -4,39 +4,19 @@ import (
44
"archive/tar"
55
"bytes"
66
"context"
7-
"errors"
87
"io"
98
"net/http"
10-
"net/url"
119
"os"
1210
"path/filepath"
13-
"strconv"
14-
"strings"
15-
16-
version "github.com/aquasecurity/go-pep440-version"
1711

1812
"github.com/replicate/cog/pkg/api"
1913
"github.com/replicate/cog/pkg/config"
20-
"github.com/replicate/cog/pkg/dockercontext"
2114
"github.com/replicate/cog/pkg/dockerignore"
22-
"github.com/replicate/cog/pkg/env"
23-
"github.com/replicate/cog/pkg/requirements"
24-
"github.com/replicate/cog/pkg/util"
25-
"github.com/replicate/cog/pkg/util/console"
26-
"github.com/replicate/cog/pkg/util/files"
27-
)
28-
29-
const EtagHeader = "etag"
30-
31-
var (
32-
ErrorBadStatus = errors.New("Bad status from pipelines-runtime requirements.txt endpoint")
33-
ErrorPythonPackage = errors.New("Python package not available in pipelines runtime")
34-
ErrorPythonPackages = errors.New("Python packages is not supported in pipelines runtime")
35-
ErrorETagHeaderNotFound = errors.New("ETag header was not found on pipelines runtime requirements.txt")
15+
"github.com/replicate/cog/pkg/procedure"
3616
)
3717

3818
func PipelinePush(ctx context.Context, image string, projectDir string, apiClient *api.Client, client *http.Client, cfg *config.Config) error {
39-
err := validateRequirements(projectDir, client, cfg)
19+
err := procedure.Validate(projectDir, client, cfg, false)
4020
if err != nil {
4121
return err
4222
}
@@ -104,145 +84,3 @@ func createTarball(folder string) (*bytes.Buffer, error) {
10484

10585
return &buf, nil
10686
}
107-
108-
func downloadRequirements(projectDir string, client *http.Client) (string, error) {
109-
tmpDir, err := dockercontext.CogBuildArtifactsDirPath(projectDir)
110-
if err != nil {
111-
return "", err
112-
}
113-
url := requirementsURL()
114-
115-
resp, err := client.Head(url.String())
116-
if err != nil {
117-
return "", err
118-
}
119-
defer resp.Body.Close()
120-
exists := false
121-
var requirementsFilePath string
122-
if resp.StatusCode >= 400 {
123-
console.Warn("Failed to fetch HEAD for pipelines-runtime requirements.txt")
124-
} else {
125-
etag := strings.ReplaceAll(filepath.Base(resp.Header.Get(EtagHeader)), "\"", "")
126-
requirementsFilePath = filepath.Join(tmpDir, "pipelines_runtime_requirements_"+etag+".txt")
127-
exists, err = files.Exists(requirementsFilePath)
128-
if err != nil {
129-
return "", err
130-
}
131-
}
132-
133-
if !exists {
134-
resp, err = client.Get(url.String())
135-
if err != nil {
136-
return "", err
137-
}
138-
139-
if resp.StatusCode >= 400 {
140-
return "", util.WrapError(ErrorBadStatus, strconv.Itoa(resp.StatusCode))
141-
}
142-
143-
etag := strings.ReplaceAll(filepath.Base(resp.Header.Get(EtagHeader)), "\"", "")
144-
if etag == "." {
145-
return "", ErrorETagHeaderNotFound
146-
}
147-
requirementsFilePath = filepath.Join(tmpDir, "pipelines_runtime_requirements_"+etag+".txt")
148-
149-
file, err := os.Create(requirementsFilePath)
150-
if err != nil {
151-
console.Info("CREATION FAILED!")
152-
return "", err
153-
}
154-
defer file.Close()
155-
156-
_, err = io.Copy(file, resp.Body)
157-
if err != nil {
158-
return "", err
159-
}
160-
}
161-
162-
return requirementsFilePath, nil
163-
}
164-
165-
func requirementsURL() url.URL {
166-
requirementsURL := pipelinesRuntimeBaseURL()
167-
requirementsURL.Path = "requirements.txt"
168-
return requirementsURL
169-
}
170-
171-
func pipelinesRuntimeBaseURL() url.URL {
172-
return url.URL{
173-
Scheme: env.SchemeFromEnvironment(),
174-
Host: env.PipelinesRuntimeHostFromEnvironment(),
175-
}
176-
}
177-
178-
func validateRequirements(projectDir string, client *http.Client, cfg *config.Config) error {
179-
if len(cfg.Build.PythonPackages) > 0 {
180-
return ErrorPythonPackages
181-
}
182-
183-
if cfg.Build.PythonRequirements == "" {
184-
return nil
185-
}
186-
187-
requirementsFilePath, err := downloadRequirements(projectDir, client)
188-
if err != nil {
189-
return err
190-
}
191-
192-
pipelineRequirements, err := requirements.ReadRequirements(requirementsFilePath)
193-
if err != nil {
194-
return err
195-
}
196-
197-
projectRequirements, err := requirements.ReadRequirements(cfg.RequirementsFile(projectDir))
198-
if err != nil {
199-
return err
200-
}
201-
202-
for _, projectRequirement := range projectRequirements {
203-
projectPackage := requirements.PackageName(projectRequirement)
204-
projectVersionSpecifier := requirements.VersionSpecifier(projectRequirement)
205-
// Continue in case the project does not specify a specific version
206-
if projectVersionSpecifier == "" {
207-
continue
208-
}
209-
found := false
210-
for _, pipelineRequirement := range pipelineRequirements {
211-
if pipelineRequirement == projectRequirement {
212-
found = true
213-
break
214-
}
215-
if strings.Contains(pipelineRequirement, "@") {
216-
continue
217-
}
218-
pipelinePackage, pipelineVersion, _, _, err := requirements.SplitPinnedPythonRequirement(pipelineRequirement)
219-
if err != nil {
220-
return err
221-
}
222-
if pipelinePackage == projectPackage {
223-
if pipelineVersion == "" {
224-
found = true
225-
} else {
226-
pipelineVersion, err := version.Parse(pipelineVersion)
227-
if err != nil {
228-
return err
229-
}
230-
specifier, err := version.NewSpecifiers(projectVersionSpecifier)
231-
if err != nil {
232-
return err
233-
}
234-
if specifier.Check(pipelineVersion) {
235-
found = true
236-
break
237-
}
238-
}
239-
break
240-
}
241-
}
242-
if !found {
243-
return util.WrapError(ErrorPythonPackage, projectRequirement)
244-
}
245-
}
246-
247-
return nil
248-
}

pkg/docker/pipeline_push_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/replicate/cog/pkg/docker/dockertest"
1616
"github.com/replicate/cog/pkg/env"
1717
cogHttp "github.com/replicate/cog/pkg/http"
18+
"github.com/replicate/cog/pkg/procedure"
1819
"github.com/replicate/cog/pkg/web"
1920
)
2021

@@ -210,7 +211,7 @@ func TestPipelinePushSuccessWithBetaPatch(t *testing.T) {
210211
case "/requirements.txt":
211212
// Mock requirements.txt response
212213
requirementsResponse := "mycustompackage==1.1.0b2"
213-
w.Header().Add(EtagHeader, "a")
214+
w.Header().Add(procedure.EtagHeader, "a")
214215
w.WriteHeader(http.StatusOK)
215216
w.Write([]byte(requirementsResponse))
216217
default:
@@ -304,7 +305,7 @@ func TestPipelinePushSuccessWithAlphaPatch(t *testing.T) {
304305
case "/requirements.txt":
305306
// Mock requirements.txt response
306307
requirementsResponse := "mycustompackage==1.1.0b2"
307-
w.Header().Add(EtagHeader, "a")
308+
w.Header().Add(procedure.EtagHeader, "a")
308309
w.WriteHeader(http.StatusOK)
309310
w.Write([]byte(requirementsResponse))
310311
default:
@@ -398,7 +399,7 @@ func TestPipelinePushSuccessWithURLInstallPath(t *testing.T) {
398399
case "/requirements.txt":
399400
// Mock requirements.txt response
400401
requirementsResponse := "mycustompackage==1.1.0b2\ncoglet @ https://github.com/replicate/cog-runtime/releases/download/v0.1.0-alpha29/coglet-0.1.0a29-py3-none-any.whl"
401-
w.Header().Add(EtagHeader, "a")
402+
w.Header().Add(procedure.EtagHeader, "a")
402403
w.WriteHeader(http.StatusOK)
403404
w.Write([]byte(requirementsResponse))
404405
default:

pkg/image/build.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121
"github.com/replicate/cog/pkg/dockerfile"
2222
"github.com/replicate/cog/pkg/dockerignore"
2323
"github.com/replicate/cog/pkg/global"
24+
"github.com/replicate/cog/pkg/http"
25+
"github.com/replicate/cog/pkg/procedure"
2426
"github.com/replicate/cog/pkg/registry"
2527
"github.com/replicate/cog/pkg/util/console"
2628
"github.com/replicate/cog/pkg/weights"
@@ -54,12 +56,24 @@ func Build(
5456
annotations map[string]string,
5557
localImage bool,
5658
dockerCommand command.Command,
57-
client registry.Client) error {
59+
client registry.Client,
60+
pipelinesImage bool) error {
5861
console.Infof("Building Docker image from environment in cog.yaml as %s...", imageName)
5962
if fastFlag {
6063
console.Info("Fast build enabled.")
6164
}
6265

66+
if pipelinesImage {
67+
httpClient, err := http.ProvideHTTPClient(ctx, dockerCommand)
68+
if err != nil {
69+
return err
70+
}
71+
err = procedure.Validate(dir, httpClient, cfg, true)
72+
if err != nil {
73+
return err
74+
}
75+
}
76+
6377
// remove bundled schema files that may be left from previous builds
6478
_ = os.Remove(bundledSchemaFile)
6579

0 commit comments

Comments
 (0)