Skip to content

Commit 1e44312

Browse files
committed
goenv: add Compare function to compare two Go version strings
1 parent 53a9418 commit 1e44312

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

goenv/version.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,28 @@ func WantGoVersion(s string, major, minor int) bool {
7474
return ma > major || (ma == major && mi >= minor)
7575
}
7676

77+
// Compare compares two Go version strings.
78+
// The result will be 0 if a == b, -1 if a < b, and +1 if a > b.
79+
// If either a or b is not a valid Go version, it is treated as "go0.0"
80+
// and compared lexicographically.
81+
// See [Parse] for more information.
82+
func Compare(a, b string) int {
83+
aMajor, aMinor, _ := Parse(a)
84+
bMajor, bMinor, _ := Parse(b)
85+
switch {
86+
case aMajor < bMajor:
87+
return -1
88+
case aMajor > bMajor:
89+
return +1
90+
case aMinor < bMinor:
91+
return -1
92+
case aMinor > bMinor:
93+
return +1
94+
default:
95+
return strings.Compare(a, b)
96+
}
97+
}
98+
7799
// GorootVersionString returns the version string as reported by the Go
78100
// toolchain. It is usually of the form `go1.x.y` but can have some variations
79101
// (for beta releases, for example).

goenv/version_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,33 @@ func TestWantGoVersion(t *testing.T) {
6666
})
6767
}
6868
}
69+
70+
func TestCompare(t *testing.T) {
71+
tests := []struct {
72+
a string
73+
b string
74+
want int
75+
}{
76+
{"", "", 0},
77+
{"go0", "go0", 0},
78+
{"go0", "go1", -1},
79+
{"go1", "go0", 1},
80+
{"go1", "go2", -1},
81+
{"go2", "go1", 1},
82+
{"go1.1", "go1.2", -1},
83+
{"go1.2", "go1.1", 1},
84+
{"go1.1.0", "go1.2.0", -1},
85+
{"go1.2.0", "go1.1.0", 1},
86+
{"go1.2.0", "go2.3.0", -1},
87+
// {"go1.23.2", "go1.23.10", -1}, // FIXME: parse patch number
88+
}
89+
for _, tt := range tests {
90+
t.Run(tt.a+" "+tt.b, func(t *testing.T) {
91+
got := Compare(tt.a, tt.b)
92+
if got != tt.want {
93+
t.Errorf("Compare(%q, %q): expected %d; got %d",
94+
tt.a, tt.b, tt.want, got)
95+
}
96+
})
97+
}
98+
}

0 commit comments

Comments
 (0)