Skip to content

Commit 0e0a0a5

Browse files
authored
In event of error, check equality (#2415)
* In event of error, checking equality * When checking if 2 version strings are greater or equal, if an error occurs parsing the version fall back to string equality instead of panicking * Use pep-440 for version comparison * Use go-pep440 to evaluate the project version specifier
1 parent 61cb550 commit 0e0a0a5

File tree

9 files changed

+268
-14
lines changed

9 files changed

+268
-14
lines changed

go.mod

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ require (
6969
github.com/alexkohler/prealloc v1.0.0 // indirect
7070
github.com/alingse/asasalint v0.0.11 // indirect
7171
github.com/alingse/nilnesserr v0.1.2 // indirect
72+
github.com/aquasecurity/go-pep440-version v0.0.1 // indirect
73+
github.com/aquasecurity/go-version v0.0.1 // indirect
7274
github.com/ashanbrown/forbidigo v1.6.0 // indirect
7375
github.com/ashanbrown/makezero v1.2.0 // indirect
7476
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
@@ -304,6 +306,7 @@ require (
304306
golang.org/x/text v0.25.0 // indirect
305307
golang.org/x/time v0.11.0 // indirect
306308
golang.org/x/tools v0.33.0 // indirect
309+
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect
307310
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect
308311
google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 // indirect
309312
google.golang.org/protobuf v1.36.6 // indirect

go.sum

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ github.com/anaskhan96/soup v1.2.5 h1:V/FHiusdTrPrdF4iA1YkVxsOpdNcgvqT1hG+YtcZ5hM
5858
github.com/anaskhan96/soup v1.2.5/go.mod h1:6YnEp9A2yywlYdM4EgDz9NEHclocMepEtku7wg6Cq3s=
5959
github.com/anchore/go-struct-converter v0.0.0-20221118182256-c68fdcfa2092 h1:aM1rlcoLz8y5B2r4tTLMiVTrMtpfY0O8EScKJxaSaEc=
6060
github.com/anchore/go-struct-converter v0.0.0-20221118182256-c68fdcfa2092/go.mod h1:rYqSE9HbjzpHTI74vwPvae4ZVYZd1lue2ta6xHPdblA=
61+
github.com/aquasecurity/go-pep440-version v0.0.1 h1:8VKKQtH2aV61+0hovZS3T//rUF+6GDn18paFTVS0h0M=
62+
github.com/aquasecurity/go-pep440-version v0.0.1/go.mod h1:3naPe+Bp6wi3n4l5iBFCZgS0JG8vY6FT0H4NGhFJ+i4=
63+
github.com/aquasecurity/go-version v0.0.1 h1:4cNl516agK0TCn5F7mmYN+xVs1E3S45LkgZk3cbaW2E=
64+
github.com/aquasecurity/go-version v0.0.1/go.mod h1:s1UU6/v2hctXcOa3OLwfj5d9yoXHa3ahf+ipSwEvGT0=
6165
github.com/ashanbrown/forbidigo v1.6.0 h1:D3aewfM37Yb3pxHujIPSpTf6oQk9sc9WZi8gerOIVIY=
6266
github.com/ashanbrown/forbidigo v1.6.0/go.mod h1:Y8j9jy9ZYAEHXdu723cUlraTqbzjKF1MUyfOKL+AjcU=
6367
github.com/ashanbrown/makezero v1.2.0 h1:/2Lp1bypdmK9wDIq7uWBlDF1iMUpIIS4A+pF6C9IEUU=
@@ -829,6 +833,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
829833
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
830834
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
831835
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
836+
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU=
837+
golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90=
832838
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a h1:nwKuGPlUAt+aR+pcrkfFRrTU1BVrSmYyYMxYbUIVHr0=
833839
google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a/go.mod h1:3kWAYMk1I75K4vykHtKt2ycnOgpA6974V7bREqbsenU=
834840
google.golang.org/genproto/googleapis/rpc v0.0.0-20250219182151-9fdb1cabc7b2 h1:DMTIbak9GhdaSxEjvVzAeNZvyc03I61duqNbnm3SU0M=

pkg/config/config.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -347,10 +347,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
347347

348348
includePackageNames := []string{}
349349
for _, pkg := range includePackages {
350-
packageName, err := requirements.PackageName(pkg)
351-
if err != nil {
352-
return "", err
353-
}
350+
packageName := requirements.PackageName(pkg)
354351
includePackageNames = append(includePackageNames, packageName)
355352
}
356353

@@ -372,7 +369,7 @@ func (c *Config) PythonRequirementsForArch(goos string, goarch string, includePa
372369
}
373370
}
374371

375-
packageName, _ := requirements.PackageName(archPkg)
372+
packageName := requirements.PackageName(archPkg)
376373
if packageName != "" {
377374
foundIdx := -1
378375
for i, includePkg := range includePackageNames {

pkg/docker/pipeline_push.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"strconv"
1414
"strings"
1515

16+
version "github.com/aquasecurity/go-pep440-version"
17+
1618
"github.com/replicate/cog/pkg/api"
1719
"github.com/replicate/cog/pkg/config"
1820
"github.com/replicate/cog/pkg/dockercontext"
@@ -22,7 +24,6 @@ import (
2224
"github.com/replicate/cog/pkg/util"
2325
"github.com/replicate/cog/pkg/util/console"
2426
"github.com/replicate/cog/pkg/util/files"
25-
"github.com/replicate/cog/pkg/util/version"
2627
)
2728

2829
const EtagHeader = "etag"
@@ -199,18 +200,39 @@ func validateRequirements(projectDir string, client *http.Client, cfg *config.Co
199200
}
200201

201202
for _, projectRequirement := range projectRequirements {
202-
projectPackage, projectVersion, _, _, err := requirements.SplitPinnedPythonRequirement(projectRequirement)
203-
if err != nil {
204-
return err
203+
projectPackage := requirements.PackageName(projectRequirement)
204+
projectVersionSpecifier := requirements.VersionSpecifier(projectRequirement)
205+
// Continue in case the project does not specify a specific version
206+
if projectVersionSpecifier == "" {
207+
continue
205208
}
206209
found := false
207210
for _, pipelineRequirement := range pipelineRequirements {
211+
if pipelineRequirement == projectRequirement {
212+
found = true
213+
break
214+
}
208215
pipelinePackage, pipelineVersion, _, _, err := requirements.SplitPinnedPythonRequirement(pipelineRequirement)
209216
if err != nil {
210217
return err
211218
}
212219
if pipelinePackage == projectPackage {
213-
found = pipelineVersion == "" || version.GreaterOrEqual(projectVersion, pipelineVersion)
220+
if pipelineVersion == "" {
221+
found = true
222+
} else {
223+
pipelineVersion, err := version.Parse(pipelineVersion)
224+
if err != nil {
225+
return err
226+
}
227+
specifier, err := version.NewSpecifiers(projectVersionSpecifier)
228+
if err != nil {
229+
return err
230+
}
231+
if specifier.Check(pipelineVersion) {
232+
found = true
233+
break
234+
}
235+
}
214236
break
215237
}
216238
}

pkg/docker/pipeline_push_test.go

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,191 @@ func TestPipelinePushFailWithExtraRequirements(t *testing.T) {
164164
err = PipelinePush(t.Context(), "r8.im/user/test", dir, apiClient, client, cfg)
165165
require.Error(t, err)
166166
}
167+
168+
func TestPipelinePushSuccessWithBetaPatch(t *testing.T) {
169+
// Setup mock web server for cog.replicate.com (token exchange)
170+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
171+
switch r.URL.Path {
172+
case "/api/token/user":
173+
// Mock token exchange response
174+
//nolint:gosec
175+
tokenResponse := `{
176+
"keys": {
177+
"cog": {
178+
"key": "test-api-token",
179+
"expires_at": "2024-12-31T23:59:59Z"
180+
}
181+
}
182+
}`
183+
w.WriteHeader(http.StatusOK)
184+
w.Write([]byte(tokenResponse))
185+
default:
186+
w.WriteHeader(http.StatusNotFound)
187+
}
188+
}))
189+
defer webServer.Close()
190+
191+
// Setup mock API server for api.replicate.com (version and release endpoints)
192+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
193+
switch r.URL.Path {
194+
case "/v1/models/user/test/versions":
195+
// Mock version creation response
196+
versionResponse := `{"id": "test-version-id"}`
197+
w.WriteHeader(http.StatusCreated)
198+
w.Write([]byte(versionResponse))
199+
case "/v1/models/user/test/releases":
200+
// Mock release creation response - empty body with 204 status
201+
w.WriteHeader(http.StatusNoContent)
202+
default:
203+
w.WriteHeader(http.StatusNotFound)
204+
}
205+
}))
206+
defer apiServer.Close()
207+
208+
cdnServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
209+
switch r.URL.Path {
210+
case "/requirements.txt":
211+
// Mock requirements.txt response
212+
requirementsResponse := "mycustompackage==1.1.0b2"
213+
w.Header().Add(EtagHeader, "a")
214+
w.WriteHeader(http.StatusOK)
215+
w.Write([]byte(requirementsResponse))
216+
default:
217+
w.WriteHeader(http.StatusNotFound)
218+
}
219+
}))
220+
defer cdnServer.Close()
221+
222+
webURL, err := url.Parse(webServer.URL)
223+
require.NoError(t, err)
224+
apiURL, err := url.Parse(apiServer.URL)
225+
require.NoError(t, err)
226+
cdnURL, err := url.Parse(cdnServer.URL)
227+
require.NoError(t, err)
228+
229+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
230+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
231+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
232+
t.Setenv(env.PipelinesRuntimeHostEnvVarName, cdnURL.Host)
233+
234+
dir := t.TempDir()
235+
236+
// Create mock predict
237+
predictPyPath := filepath.Join(dir, "predict.py")
238+
handle, err := os.Create(predictPyPath)
239+
require.NoError(t, err)
240+
handle.WriteString("import cog")
241+
handle.Close()
242+
dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"mycustompackage==1.1.0b2\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}"
243+
244+
// Setup mock command
245+
command := dockertest.NewMockCommand()
246+
client, err := cogHttp.ProvideHTTPClient(t.Context(), command)
247+
require.NoError(t, err)
248+
webClient := web.NewClient(command, client)
249+
apiClient := api.NewClient(command, client, webClient)
250+
251+
cfg := config.DefaultConfig()
252+
requirementsPath := filepath.Join(dir, "requirements.txt")
253+
handle, err = os.Create(requirementsPath)
254+
require.NoError(t, err)
255+
handle.WriteString("mycustompackage==1.1.0b2")
256+
handle.Close()
257+
cfg.Build.PythonRequirements = filepath.Base(requirementsPath)
258+
err = PipelinePush(t.Context(), "r8.im/user/test", dir, apiClient, client, cfg)
259+
require.NoError(t, err)
260+
}
261+
262+
func TestPipelinePushSuccessWithAlphaPatch(t *testing.T) {
263+
// Setup mock web server for cog.replicate.com (token exchange)
264+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
265+
switch r.URL.Path {
266+
case "/api/token/user":
267+
// Mock token exchange response
268+
//nolint:gosec
269+
tokenResponse := `{
270+
"keys": {
271+
"cog": {
272+
"key": "test-api-token",
273+
"expires_at": "2024-12-31T23:59:59Z"
274+
}
275+
}
276+
}`
277+
w.WriteHeader(http.StatusOK)
278+
w.Write([]byte(tokenResponse))
279+
default:
280+
w.WriteHeader(http.StatusNotFound)
281+
}
282+
}))
283+
defer webServer.Close()
284+
285+
// Setup mock API server for api.replicate.com (version and release endpoints)
286+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
287+
switch r.URL.Path {
288+
case "/v1/models/user/test/versions":
289+
// Mock version creation response
290+
versionResponse := `{"id": "test-version-id"}`
291+
w.WriteHeader(http.StatusCreated)
292+
w.Write([]byte(versionResponse))
293+
case "/v1/models/user/test/releases":
294+
// Mock release creation response - empty body with 204 status
295+
w.WriteHeader(http.StatusNoContent)
296+
default:
297+
w.WriteHeader(http.StatusNotFound)
298+
}
299+
}))
300+
defer apiServer.Close()
301+
302+
cdnServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
303+
switch r.URL.Path {
304+
case "/requirements.txt":
305+
// Mock requirements.txt response
306+
requirementsResponse := "mycustompackage==1.1.0b2"
307+
w.Header().Add(EtagHeader, "a")
308+
w.WriteHeader(http.StatusOK)
309+
w.Write([]byte(requirementsResponse))
310+
default:
311+
w.WriteHeader(http.StatusNotFound)
312+
}
313+
}))
314+
defer cdnServer.Close()
315+
316+
webURL, err := url.Parse(webServer.URL)
317+
require.NoError(t, err)
318+
apiURL, err := url.Parse(apiServer.URL)
319+
require.NoError(t, err)
320+
cdnURL, err := url.Parse(cdnServer.URL)
321+
require.NoError(t, err)
322+
323+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
324+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
325+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
326+
t.Setenv(env.PipelinesRuntimeHostEnvVarName, cdnURL.Host)
327+
328+
dir := t.TempDir()
329+
330+
// Create mock predict
331+
predictPyPath := filepath.Join(dir, "predict.py")
332+
handle, err := os.Create(predictPyPath)
333+
require.NoError(t, err)
334+
handle.WriteString("import cog")
335+
handle.Close()
336+
dockertest.MockCogConfig = "{\"build\":{\"python_version\":\"3.12\",\"python_packages\":[\"torch==2.5.0\",\"mycustompackage>=1.0\"],\"system_packages\":[\"git\"]},\"image\":\"test\",\"predict\":\"" + predictPyPath + ":Predictor\"}"
337+
338+
// Setup mock command
339+
command := dockertest.NewMockCommand()
340+
client, err := cogHttp.ProvideHTTPClient(t.Context(), command)
341+
require.NoError(t, err)
342+
webClient := web.NewClient(command, client)
343+
apiClient := api.NewClient(command, client, webClient)
344+
345+
cfg := config.DefaultConfig()
346+
requirementsPath := filepath.Join(dir, "requirements.txt")
347+
handle, err = os.Create(requirementsPath)
348+
require.NoError(t, err)
349+
handle.WriteString("mycustompackage>=1.0")
350+
handle.Close()
351+
cfg.Build.PythonRequirements = filepath.Base(requirementsPath)
352+
err = PipelinePush(t.Context(), "r8.im/user/test", dir, apiClient, client, cfg)
353+
require.NoError(t, err)
354+
}

pkg/requirements/requirements.go

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,21 @@ func SplitPinnedPythonRequirement(requirement string) (name string, version stri
176176
return name, version, findLinks, extraIndexURLs, nil
177177
}
178178

179-
func PackageName(pipRequirement string) (string, error) {
180-
name, _, _, _, err := SplitPinnedPythonRequirement(pipRequirement)
181-
return name, err
179+
func PackageName(pipRequirement string) string {
180+
re := regexp.MustCompile(`^([a-zA-Z0-9_\-\.]+(?:\[[^\]]+\])?)`)
181+
match := re.FindStringSubmatch(pipRequirement)
182+
if len(match) > 1 {
183+
return match[1]
184+
}
185+
return ""
186+
}
187+
188+
func VersionSpecifier(pipRequirement string) string {
189+
re := regexp.MustCompile(`^[a-zA-Z0-9_\-\.]+(?:\[[^\]]+\])?\s*([<>=!~]=?\s*[^;,#\s]+(?:\s*,\s*[<>=!~]=?\s*[^;,#\s]+)*(?:\s*\|\|\s*[<>=!~]=?\s*[^;,#\s]+(?:\s*,\s*[<>=!~]=?\s*[^;,#\s]+)*)*)?`)
190+
match := re.FindStringSubmatch(pipRequirement)
191+
if len(match) > 1 {
192+
// Optional: strip spaces for uniform output
193+
return strings.ReplaceAll(match[1], " ", "")
194+
}
195+
return ""
182196
}

pkg/requirements/requirements_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,16 @@ func TestReadRequirementsWithEditable(t *testing.T) {
434434
require.Equal(t, []string{"torch==2.5.1"}, requirements)
435435
}
436436

437+
func TestVersionSpecifier(t *testing.T) {
438+
specifier := VersionSpecifier("mypackage>= 1.0, < 1.4 || > 2.0")
439+
require.Equal(t, specifier, ">=1.0,<1.4||>2.0")
440+
}
441+
442+
func TestPackageName(t *testing.T) {
443+
name := PackageName("mypackage>= 1.0, < 1.4 || > 2.0")
444+
require.Equal(t, name, "mypackage")
445+
}
446+
437447
func checkRequirements(t *testing.T, expected []string, actual []string) {
438448
t.Helper()
439449
for n, expectLine := range expected {

pkg/util/version/version.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,15 @@ func Greater(v1 string, v2 string) bool {
112112
}
113113

114114
func GreaterOrEqual(v1 string, v2 string) bool {
115-
return MustVersion(v1).GreaterOrEqual(MustVersion(v2))
115+
leftVersion, err := NewVersion(v1)
116+
if err != nil {
117+
return v1 == v2
118+
}
119+
rightVersion, err := NewVersion(v2)
120+
if err != nil {
121+
return v1 == v2
122+
}
123+
return leftVersion.GreaterOrEqual(rightVersion)
116124
}
117125

118126
func (v *Version) Matches(other *Version) bool {

pkg/util/version/version_test.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,9 @@ func TestVersionMatchesModifier(t *testing.T) {
7979
matchVersion := "2.3.2+cu118"
8080
require.True(t, Matches(version, matchVersion))
8181
}
82+
83+
func TestGreaterThanOrEqualToWithInvalidPatch(t *testing.T) {
84+
leftVersion := "1.1.0b2"
85+
rightVersion := "1.1.0b2"
86+
require.True(t, GreaterOrEqual(leftVersion, rightVersion))
87+
}

0 commit comments

Comments
 (0)