Skip to content

Commit 171b1d5

Browse files
authored
feat: add registry migration support with credential fallback (#2541)
1 parent 76bf285 commit 171b1d5

File tree

6 files changed

+248
-46
lines changed

6 files changed

+248
-46
lines changed

pkg/cli/login.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ type VerifyResponse struct {
2525
}
2626

2727
func newLoginCommand() *cobra.Command {
28-
var cmd = &cobra.Command{
28+
cmd := &cobra.Command{
2929
Use: "login",
3030
SuggestFor: []string{"auth", "authenticate", "authorize"},
3131
Short: "Log in to Replicate Docker registry",
@@ -34,19 +34,15 @@ func newLoginCommand() *cobra.Command {
3434
}
3535

3636
cmd.Flags().Bool("token-stdin", false, "Pass login token on stdin instead of opening a browser. You can find your Replicate login token at https://replicate.com/auth/token")
37-
cmd.Flags().String("registry", global.ReplicateRegistryHost, "Registry host")
38-
_ = cmd.Flags().MarkHidden("registry")
3937

4038
return cmd
4139
}
4240

4341
func login(cmd *cobra.Command, args []string) error {
4442
ctx := cmd.Context()
4543

46-
registryHost, err := cmd.Flags().GetString("registry")
47-
if err != nil {
48-
return err
49-
}
44+
// Use global registry host (can be set via --registry flag or COG_REGISTRY_HOST env var)
45+
registryHost := global.ReplicateRegistryHost
5046
tokenStdin, err := cmd.Flags().GetBool("token-stdin")
5147
if err != nil {
5248
return err

pkg/cli/root.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,7 @@ func setPersistentFlags(cmd *cobra.Command) {
5656
cmd.PersistentFlags().BoolVar(&global.Debug, "debug", false, "Show debugging output")
5757
cmd.PersistentFlags().BoolVar(&global.ProfilingEnabled, "profile", false, "Enable profiling")
5858
cmd.PersistentFlags().Bool("version", false, "Show version of Cog")
59+
cmd.PersistentFlags().StringVar(&global.ReplicateRegistryHost, "registry", global.ReplicateRegistryHost, "Registry host")
5960
_ = cmd.PersistentFlags().MarkHidden("profile")
61+
_ = cmd.PersistentFlags().MarkHidden("registry")
6062
}

pkg/docker/api_client.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import (
3131
"github.com/replicate/go/types/ptr"
3232

3333
"github.com/replicate/cog/pkg/docker/command"
34+
"github.com/replicate/cog/pkg/global"
3435
"github.com/replicate/cog/pkg/util/console"
3536
)
3637

@@ -70,17 +71,13 @@ func NewAPIClient(ctx context.Context, opts ...Option) (*apiClient, error) {
7071
return nil, fmt.Errorf("error pinging docker daemon: %w", err)
7172
}
7273

73-
authConfig := make(map[string]registry.AuthConfig)
74-
userInfo, err := loadUserInformation(ctx, "r8.im")
74+
// Load authentication for configured registry and any other registries that might be needed
75+
authConfig, err := loadRegistryAuths(ctx, global.ReplicateRegistryHost)
7576
if err != nil {
7677
return nil, fmt.Errorf("error loading user information: %w, you may need to authenticate using cog login", err)
7778
}
78-
authConfig["r8.im"] = registry.AuthConfig{
79-
Username: userInfo.Username,
80-
Password: userInfo.Token,
81-
ServerAddress: "r8.im",
82-
}
8379

80+
// Add any additional auth configs passed via options
8481
for _, opt := range clientOptions.authConfigs {
8582
authConfig[opt.ServerAddress] = opt
8683
}
@@ -209,8 +206,19 @@ func (c *apiClient) Push(ctx context.Context, imageRef string) error {
209206

210207
// eagerly set auth config, or do it async
211208
var authConfig registry.AuthConfig
212-
if auth, ok := c.authConfig[parsedName.Context().RegistryStr()]; ok {
209+
registryHost := parsedName.Context().RegistryStr()
210+
if auth, ok := c.authConfig[registryHost]; ok {
213211
authConfig = auth
212+
} else {
213+
// Dynamically load authentication for this registry if not already loaded
214+
authConfigs, err := loadRegistryAuths(ctx, registryHost)
215+
if err == nil {
216+
if auth, ok := authConfigs[registryHost]; ok {
217+
authConfig = auth
218+
// Cache the auth config for future use
219+
c.authConfig[registryHost] = auth
220+
}
221+
}
214222
}
215223

216224
var opts image.PushOptions

pkg/docker/credentials.go

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/docker/docker/api/types/registry"
1616

1717
"github.com/replicate/cog/pkg/docker/command"
18+
"github.com/replicate/cog/pkg/global"
1819
"github.com/replicate/cog/pkg/util/console"
1920
)
2021

@@ -47,49 +48,62 @@ func loadAuthFromConfig(conf *configfile.ConfigFile, registryHost string) (types
4748

4849
func loadRegistryAuths(ctx context.Context, registryHosts ...string) (map[string]registry.AuthConfig, error) {
4950
conf := config.LoadDefaultConfigFile(os.Stderr)
50-
5151
out := make(map[string]registry.AuthConfig)
5252

5353
for _, host := range registryHosts {
54-
console.Debugf("=== loadRegistryAuths %s", host)
55-
// check the credentials store first if set
56-
if conf.CredentialsStore != "" {
57-
console.Debugf("=== loadRegistryAuths %s: credentials store set", host)
58-
credsHelper, err := loadAuthFromCredentialsStore(ctx, conf.CredentialsStore, host)
59-
if err != nil {
60-
console.Debugf("=== loadRegistryAuths %s: error loading credentials store: %s", host, err)
61-
return nil, err
62-
}
63-
console.Debugf("=== loadRegistryAuths %s: credentials store loaded", host)
64-
out[host] = registry.AuthConfig{
65-
Username: credsHelper.Username,
66-
Password: credsHelper.Secret,
67-
ServerAddress: host,
68-
}
54+
// Try loading auth for the requested host
55+
auth, err := tryLoadAuthForHost(ctx, conf, host)
56+
if err == nil && auth != nil {
57+
out[host] = *auth
6958
continue
7059
}
7160

72-
// next, check if the auth config exists in the config file
73-
if auth, ok := conf.AuthConfigs[host]; ok {
74-
console.Debugf("=== loadRegistryAuths %s: auth config found in config file", host)
75-
out[host] = registry.AuthConfig{
76-
Username: auth.Username,
77-
Password: auth.Password,
78-
Auth: auth.Auth,
79-
Email: auth.Email,
80-
ServerAddress: host,
81-
IdentityToken: auth.IdentityToken,
82-
RegistryToken: auth.RegistryToken,
61+
// FALLBACK: If requesting alternate registry and no auth found,
62+
// try reusing r8.im credentials
63+
if host != global.DefaultReplicateRegistryHost {
64+
auth, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
65+
if err == nil && auth != nil {
66+
// Reuse credentials for the alternate registry
67+
auth.ServerAddress = host // Update to new host
68+
out[host] = *auth
69+
console.Infof("Using existing %s credentials for %s", global.DefaultReplicateRegistryHost, host)
70+
continue
8371
}
84-
continue
8572
}
86-
87-
console.Debugf("=== loadRegistryAuths %s: no auth config found", host)
8873
}
8974

9075
return out, nil
9176
}
9277

78+
func tryLoadAuthForHost(ctx context.Context, conf *configfile.ConfigFile, host string) (*registry.AuthConfig, error) {
79+
// Try credentials store first (e.g., osxkeychain, pass)
80+
if conf.CredentialsStore != "" {
81+
credsHelper, err := loadAuthFromCredentialsStore(ctx, conf.CredentialsStore, host)
82+
if err == nil {
83+
return &registry.AuthConfig{
84+
Username: credsHelper.Username,
85+
Password: credsHelper.Secret,
86+
ServerAddress: host,
87+
}, nil
88+
}
89+
}
90+
91+
// Fallback to config file
92+
if auth, ok := conf.AuthConfigs[host]; ok {
93+
return &registry.AuthConfig{
94+
Username: auth.Username,
95+
Password: auth.Password,
96+
Auth: auth.Auth,
97+
Email: auth.Email,
98+
ServerAddress: host,
99+
IdentityToken: auth.IdentityToken,
100+
RegistryToken: auth.RegistryToken,
101+
}, nil
102+
}
103+
104+
return nil, fmt.Errorf("no credentials found for %s", host)
105+
}
106+
93107
func loadAuthFromCredentialsStore(ctx context.Context, credsStore string, registryHost string) (*CredentialHelperInput, error) {
94108
var out strings.Builder
95109
binary := dockerCredentialBinary(credsStore)

pkg/docker/credentials_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package docker
2+
3+
import (
4+
"context"
5+
"path/filepath"
6+
"testing"
7+
8+
"github.com/docker/cli/cli/config/configfile"
9+
"github.com/docker/cli/cli/config/types"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/replicate/cog/pkg/global"
14+
)
15+
16+
func TestLoadRegistryAuths_Fallback(t *testing.T) {
17+
ctx := context.Background()
18+
19+
t.Run("uses credentials for requested host when available", func(t *testing.T) {
20+
// Create a mock config with credentials for the requested host
21+
conf := &configfile.ConfigFile{
22+
AuthConfigs: map[string]types.AuthConfig{
23+
"registry.example.com": {
24+
Username: "user1",
25+
Password: "pass1",
26+
},
27+
},
28+
}
29+
30+
auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
31+
require.NoError(t, err)
32+
require.NotNil(t, auth)
33+
assert.Equal(t, "user1", auth.Username)
34+
assert.Equal(t, "pass1", auth.Password)
35+
assert.Equal(t, "registry.example.com", auth.ServerAddress)
36+
})
37+
38+
t.Run("falls back to default registry credentials when alternate registry has no credentials", func(t *testing.T) {
39+
// Set up a temporary docker config file
40+
tmpDir := t.TempDir()
41+
dockerConfigPath := filepath.Join(tmpDir, "config.json")
42+
43+
// Create a config file with credentials only for the default registry
44+
conf := &configfile.ConfigFile{
45+
Filename: dockerConfigPath,
46+
AuthConfigs: map[string]types.AuthConfig{
47+
global.DefaultReplicateRegistryHost: {
48+
Username: "defaultuser",
49+
Password: "defaultpass",
50+
},
51+
},
52+
}
53+
require.NoError(t, conf.Save())
54+
55+
// Point Docker to our test config
56+
t.Setenv("DOCKER_CONFIG", tmpDir)
57+
58+
// Try loading auth for an alternate registry that doesn't have credentials
59+
auths, err := loadRegistryAuths(ctx, "registry.example.com")
60+
require.NoError(t, err)
61+
require.NotNil(t, auths)
62+
63+
// Should have fallen back to default registry credentials
64+
auth, ok := auths["registry.example.com"]
65+
require.True(t, ok, "should have auth for registry.example.com")
66+
assert.Equal(t, "defaultuser", auth.Username)
67+
assert.Equal(t, "defaultpass", auth.Password)
68+
assert.Equal(t, "registry.example.com", auth.ServerAddress, "server address should be updated to the requested host")
69+
})
70+
71+
t.Run("does not fallback when requesting default registry", func(t *testing.T) {
72+
// This test uses tryLoadAuthForHost directly to avoid credential store issues
73+
conf := &configfile.ConfigFile{
74+
AuthConfigs: map[string]types.AuthConfig{},
75+
}
76+
77+
// Try loading auth for the default registry
78+
auth, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
79+
require.Error(t, err, "should error when no credentials found")
80+
assert.Nil(t, auth)
81+
assert.Contains(t, err.Error(), "no credentials found")
82+
})
83+
84+
t.Run("prefers direct credentials over fallback", func(t *testing.T) {
85+
// Create a mock config with credentials for both registries
86+
conf := &configfile.ConfigFile{
87+
AuthConfigs: map[string]types.AuthConfig{
88+
global.DefaultReplicateRegistryHost: {
89+
Username: "defaultuser",
90+
Password: "defaultpass",
91+
},
92+
"registry.example.com": {
93+
Username: "directuser",
94+
Password: "directpass",
95+
},
96+
},
97+
}
98+
99+
// Try loading auth for the alternate registry
100+
auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
101+
require.NoError(t, err)
102+
require.NotNil(t, auth)
103+
104+
// Should use direct credentials, not fallback
105+
assert.Equal(t, "directuser", auth.Username)
106+
assert.Equal(t, "directpass", auth.Password)
107+
assert.Equal(t, "registry.example.com", auth.ServerAddress)
108+
})
109+
110+
t.Run("returns empty map when no credentials available", func(t *testing.T) {
111+
// This test uses tryLoadAuthForHost to avoid credential store issues
112+
// The loadRegistryAuths function doesn't error when no credentials are found,
113+
// it just returns an empty map
114+
conf := &configfile.ConfigFile{
115+
AuthConfigs: map[string]types.AuthConfig{},
116+
}
117+
118+
// Try loading auth for an alternate registry (will fail)
119+
auth1, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
120+
require.Error(t, err)
121+
assert.Nil(t, auth1)
122+
123+
// Try loading auth for default registry (will also fail)
124+
auth2, err := tryLoadAuthForHost(ctx, conf, global.DefaultReplicateRegistryHost)
125+
require.Error(t, err)
126+
assert.Nil(t, auth2)
127+
128+
// Since both fail, loadRegistryAuths would return an empty map
129+
// (it doesn't error, just silently skips hosts without credentials)
130+
})
131+
}
132+
133+
func TestTryLoadAuthForHost(t *testing.T) {
134+
ctx := context.Background()
135+
136+
t.Run("loads auth from config file", func(t *testing.T) {
137+
conf := &configfile.ConfigFile{
138+
AuthConfigs: map[string]types.AuthConfig{
139+
"registry.example.com": {
140+
Username: "testuser",
141+
Password: "testpass",
142+
Auth: "dGVzdHVzZXI6dGVzdHBhc3M=",
143+
144+
},
145+
},
146+
}
147+
148+
auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
149+
require.NoError(t, err)
150+
require.NotNil(t, auth)
151+
assert.Equal(t, "testuser", auth.Username)
152+
assert.Equal(t, "testpass", auth.Password)
153+
assert.Equal(t, "dGVzdHVzZXI6dGVzdHBhc3M=", auth.Auth)
154+
assert.Equal(t, "[email protected]", auth.Email)
155+
assert.Equal(t, "registry.example.com", auth.ServerAddress)
156+
})
157+
158+
t.Run("returns error when no auth found", func(t *testing.T) {
159+
conf := &configfile.ConfigFile{
160+
AuthConfigs: map[string]types.AuthConfig{},
161+
}
162+
163+
auth, err := tryLoadAuthForHost(ctx, conf, "registry.example.com")
164+
require.Error(t, err)
165+
assert.Nil(t, auth)
166+
assert.Contains(t, err.Error(), "no credentials found")
167+
})
168+
}

pkg/global/global.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
package global
22

3+
import "os"
4+
5+
const (
6+
DefaultReplicateRegistryHost = "r8.im"
7+
)
8+
39
var (
410
Version = "dev"
511
Commit = ""
612
BuildTime = "none"
713
Debug = false
814
ProfilingEnabled = false
9-
ReplicateRegistryHost = "r8.im"
15+
ReplicateRegistryHost = getDefaultRegistryHost()
1016
ReplicateWebsiteHost = "replicate.com"
1117
LabelNamespace = "run.cog."
1218
CogBuildArtifactsFolder = ".cog"
1319
)
20+
21+
func getDefaultRegistryHost() string {
22+
// Priority: flag will override at runtime, but env var provides default
23+
if host := os.Getenv("COG_REGISTRY_HOST"); host != "" {
24+
return host
25+
}
26+
return DefaultReplicateRegistryHost
27+
}

0 commit comments

Comments
 (0)