Skip to content

Commit 728a01e

Browse files
authored
Keep requirements.txt in sync for pipeline models (#2515)
* Always download and update requirements.txt * lint * Overwrite requirements.txt before validation * skip test that expects validation errors since we're always overwriting the file * lint fixes * use CogBuildArtifactsFolder global
1 parent 08cef1a commit 728a01e

File tree

9 files changed

+272
-8
lines changed

9 files changed

+272
-8
lines changed

go.sum

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxK
175175
github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
176176
github.com/dnephin/pflag v1.0.7 h1:oxONGlWxhmUct0YzKTgrpQv9AUA1wtPBn7zuSjJqptk=
177177
github.com/dnephin/pflag v1.0.7/go.mod h1:uxE91IoWURlOiTUIA8Mq5ZZkAv3dPUfZNaT80Zm7OQE=
178-
github.com/docker/cli v28.1.1+incompatible h1:eyUemzeI45DY7eDPuwUcmDyDj1pM98oD5MdSpiItp8k=
179-
github.com/docker/cli v28.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
180178
github.com/docker/cli v28.3.0+incompatible h1:s+ttruVLhB5ayeuf2BciwDVxYdKi+RoUlxmwNHV3Vfo=
181179
github.com/docker/cli v28.3.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
182180
github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=

pkg/cli/init.go

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ import (
55
"fmt"
66
"io"
77
"net/http"
8+
"net/url"
89
"os"
910
"path"
1011
"time"
1112

1213
"github.com/spf13/cobra"
1314

15+
"github.com/replicate/cog/pkg/env"
1416
"github.com/replicate/cog/pkg/util/console"
1517
"github.com/replicate/cog/pkg/util/files"
1618
)
@@ -119,8 +121,10 @@ func processTemplateFile(fs embed.FS, templateDir, filename, cwd string) error {
119121

120122
var content []byte
121123

122-
// Special handling for AGENTS.md - try to download from Replicate docs
123-
if filename == "AGENTS.md" {
124+
// Special handling for specific template files
125+
switch {
126+
case filename == "AGENTS.md":
127+
// Try to download from Replicate docs
124128
downloadedContent, err := downloadAgentsFile()
125129
if err != nil {
126130
console.Infof("Failed to download AGENTS.md: %v", err)
@@ -133,7 +137,21 @@ func processTemplateFile(fs embed.FS, templateDir, filename, cwd string) error {
133137
} else {
134138
content = downloadedContent
135139
}
136-
} else {
140+
case filename == "requirements.txt" && pipelineTemplate:
141+
// Special handling for requirements.txt in pipeline templates - download from runtime
142+
downloadedContent, err := downloadPipelineRequirementsFile()
143+
if err != nil {
144+
console.Infof("Failed to download pipeline requirements.txt: %v", err)
145+
console.Infof("Using template version instead...")
146+
// Fall back to template version
147+
content, err = fs.ReadFile(path.Join(templateDir, filename))
148+
if err != nil {
149+
return fmt.Errorf("Error reading template %s: %w", filename, err)
150+
}
151+
} else {
152+
content = downloadedContent
153+
}
154+
default:
137155
// Regular template file processing
138156
content, err = fs.ReadFile(path.Join(templateDir, filename))
139157
if err != nil {
@@ -174,6 +192,40 @@ func downloadAgentsFile() ([]byte, error) {
174192
return content, nil
175193
}
176194

195+
func downloadPipelineRequirementsFile() ([]byte, error) {
196+
requirementsURL := pipelinesRuntimeRequirementsURL()
197+
198+
client := &http.Client{
199+
Timeout: 10 * time.Second,
200+
}
201+
202+
resp, err := client.Get(requirementsURL.String())
203+
if err != nil {
204+
return nil, fmt.Errorf("%w", err)
205+
}
206+
defer resp.Body.Close()
207+
208+
if resp.StatusCode != http.StatusOK {
209+
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
210+
}
211+
212+
content, err := io.ReadAll(resp.Body)
213+
if err != nil {
214+
return nil, fmt.Errorf("failed to read response body: %w", err)
215+
}
216+
217+
return content, nil
218+
}
219+
220+
func pipelinesRuntimeRequirementsURL() url.URL {
221+
baseURL := url.URL{
222+
Scheme: env.SchemeFromEnvironment(),
223+
Host: env.PipelinesRuntimeHostFromEnvironment(),
224+
}
225+
baseURL.Path = "requirements.txt"
226+
return baseURL
227+
}
228+
177229
func addPipelineInit(cmd *cobra.Command) {
178230
const pipeline = "x-pipeline"
179231
cmd.Flags().BoolVar(&pipelineTemplate, pipeline, false, "Initialize a pipeline template")

pkg/docker/pipeline_push.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,29 @@ import (
88
"net/http"
99
"os"
1010
"path/filepath"
11+
"strings"
1112

1213
"github.com/replicate/cog/pkg/api"
1314
"github.com/replicate/cog/pkg/config"
1415
"github.com/replicate/cog/pkg/dockerignore"
16+
"github.com/replicate/cog/pkg/global"
1517
"github.com/replicate/cog/pkg/procedure"
1618
)
1719

1820
func PipelinePush(ctx context.Context, image string, projectDir string, apiClient *api.Client, client *http.Client, cfg *config.Config) error {
19-
err := procedure.Validate(projectDir, client, cfg, false)
21+
err := procedure.Validate(projectDir, client, cfg, true)
2022
if err != nil {
2123
return err
2224
}
2325

24-
tarball, err := createTarball(projectDir)
26+
tarball, err := createTarball(projectDir, cfg)
2527
if err != nil {
2628
return err
2729
}
2830
return apiClient.PostNewPipeline(ctx, image, tarball)
2931
}
3032

31-
func createTarball(folder string) (*bytes.Buffer, error) {
33+
func createTarball(folder string, cfg *config.Config) (*bytes.Buffer, error) {
3234
var buf bytes.Buffer
3335
tw := tar.NewWriter(&buf)
3436

@@ -37,6 +39,25 @@ func createTarball(folder string) (*bytes.Buffer, error) {
3739
return nil, err
3840
}
3941

42+
// Track if we need to add downloaded requirements to the tarball
43+
var downloadedRequirementsPath string
44+
var downloadedRequirementsContent []byte
45+
46+
// If config points to downloaded requirements (outside project directory),
47+
// we need to include them in the tarball as requirements.txt
48+
if cfg.Build.PythonRequirements != "" {
49+
reqPath := cfg.RequirementsFile(folder)
50+
if !strings.HasPrefix(reqPath, folder) || strings.Contains(reqPath, global.CogBuildArtifactsFolder) {
51+
// This is a downloaded requirements file, read its content
52+
content, err := os.ReadFile(reqPath)
53+
if err != nil {
54+
return nil, err
55+
}
56+
downloadedRequirementsPath = "requirements.txt"
57+
downloadedRequirementsContent = content
58+
}
59+
}
60+
4061
err = dockerignore.Walk(folder, matcher, func(path string, info os.FileInfo, err error) error {
4162
if err != nil {
4263
return err
@@ -51,6 +72,12 @@ func createTarball(folder string) (*bytes.Buffer, error) {
5172
return err
5273
}
5374

75+
// If this is the local requirements.txt and we have downloaded requirements,
76+
// skip the local one (we'll add the downloaded version instead)
77+
if downloadedRequirementsPath != "" && relPath == "requirements.txt" {
78+
return nil
79+
}
80+
5481
file, err := os.Open(path)
5582
if err != nil {
5683
return err
@@ -78,6 +105,25 @@ func createTarball(folder string) (*bytes.Buffer, error) {
78105
return nil, err
79106
}
80107

108+
// Add downloaded requirements as requirements.txt if we have them
109+
if downloadedRequirementsPath != "" {
110+
header := &tar.Header{
111+
Name: downloadedRequirementsPath,
112+
Mode: 0o644,
113+
Size: int64(len(downloadedRequirementsContent)),
114+
}
115+
116+
err = tw.WriteHeader(header)
117+
if err != nil {
118+
return nil, err
119+
}
120+
121+
_, err = tw.Write(downloadedRequirementsContent)
122+
if err != nil {
123+
return nil, err
124+
}
125+
}
126+
81127
if err := tw.Close(); err != nil {
82128
return nil, err
83129
}

pkg/docker/pipeline_push_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func TestPipelinePush(t *testing.T) {
9090
}
9191

9292
func TestPipelinePushFailWithExtraRequirements(t *testing.T) {
93+
t.Skip("Skipping for now, requirements.txt is always overwritten, and hopefully we replace that with support for custom requirements, if not this test comes back")
9394
// Setup mock web server for cog.replicate.com (token exchange)
9495
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
9596
switch r.URL.Path {

pkg/procedure/validate.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package procedure
22

33
import (
44
"errors"
5+
"fmt"
56
"io"
67
"net/http"
78
"net/url"
@@ -118,6 +119,15 @@ func validateRequirements(projectDir string, client *http.Client, cfg *config.Co
118119
return err
119120
}
120121

122+
// Update local requirements.txt to match production before validation if filling
123+
if fill {
124+
err := updateLocalRequirementsFile(projectDir, requirementsFilePath)
125+
if err != nil {
126+
// Log warning but don't fail the build - the downloaded requirements will still be used
127+
console.Warn(fmt.Sprintf("Failed to update local requirements.txt: %v", err))
128+
}
129+
}
130+
121131
if cfg.Build.PythonRequirements != "" {
122132
pipelineRequirements, err := requirements.ReadRequirements(filepath.Join(projectDir, requirementsFilePath))
123133
if err != nil {
@@ -243,6 +253,27 @@ func downloadRequirements(projectDir string, client *http.Client) (string, error
243253
return requirementsFilePath, nil
244254
}
245255

256+
// updateLocalRequirementsFile copies the downloaded requirements to the local requirements.txt file
257+
// This keeps the local file in sync with what's actually available in the runtime
258+
func updateLocalRequirementsFile(projectDir, downloadedRequirementsPath string) error {
259+
// Read the downloaded requirements
260+
downloadedPath := filepath.Join(projectDir, downloadedRequirementsPath)
261+
downloadedContent, err := os.ReadFile(downloadedPath)
262+
if err != nil {
263+
return fmt.Errorf("failed to read downloaded requirements: %w", err)
264+
}
265+
266+
// Write to local requirements.txt
267+
localRequirementsPath := filepath.Join(projectDir, "requirements.txt")
268+
err = os.WriteFile(localRequirementsPath, downloadedContent, 0o644)
269+
if err != nil {
270+
return fmt.Errorf("failed to write local requirements.txt: %w", err)
271+
}
272+
273+
console.Infof("Updated local requirements.txt with runtime requirements")
274+
return nil
275+
}
276+
246277
func requirementsURL() url.URL {
247278
requirementsURL := pipelinesRuntimeBaseURL()
248279
requirementsURL.Path = "requirements.txt"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Ignore dynamically generated requirements.txt file during tests
2+
# This file is created by test_predict_pipeline_downloaded_requirements
3+
# to simulate out-of-sync scenario and gets cleaned up automatically
4+
requirements.txt
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
build:
2+
python_version: "3.13"
3+
python_requirements: "requirements.txt"
4+
predict: "predict.py:Predictor"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from cog import BasePredictor
2+
import importlib.metadata
3+
4+
class Predictor(BasePredictor):
5+
def predict(self) -> str:
6+
"""Test function that verifies downloaded requirements are available"""
7+
8+
try:
9+
# Get all installed packages and their versions
10+
packages = []
11+
for dist in importlib.metadata.distributions():
12+
packages.append(f"{dist.metadata['name']}=={dist.version}")
13+
14+
# Sort for consistent output
15+
packages.sort()
16+
17+
# Create output with prompt and all packages listed
18+
return '\n'.join(packages)
19+
20+
except Exception as e:
21+
return f"ERROR: Unexpected error - {str(e)}"

0 commit comments

Comments
 (0)