Skip to content

Commit 50f8c98

Browse files
authored
Implement option to disable private IP blocking for remote registries (#744)
1 parent b69de2e commit 50f8c98

File tree

9 files changed

+167
-63
lines changed

9 files changed

+167
-63
lines changed

cmd/thv/app/config.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
rt "github.com/stacklok/toolhive/pkg/container/runtime"
1717
"github.com/stacklok/toolhive/pkg/labels"
1818
"github.com/stacklok/toolhive/pkg/logger"
19+
"github.com/stacklok/toolhive/pkg/networking"
1920
"github.com/stacklok/toolhive/pkg/transport"
2021
)
2122

@@ -114,6 +115,10 @@ var unsetRegistryURLCmd = &cobra.Command{
114115
RunE: unsetRegistryURLCmdFunc,
115116
}
116117

118+
var (
119+
allowPrivateRegistryIp bool
120+
)
121+
117122
func init() {
118123
// Add config command to root command
119124
rootCmd.AddCommand(configCmd)
@@ -126,8 +131,16 @@ func init() {
126131
configCmd.AddCommand(getCACertCmd)
127132
configCmd.AddCommand(unsetCACertCmd)
128133
configCmd.AddCommand(setRegistryURLCmd)
134+
setRegistryURLCmd.Flags().BoolVarP(
135+
&allowPrivateRegistryIp,
136+
"allow-private-ip",
137+
"p",
138+
false,
139+
"Allow setting the registry URL, even if it references a private IP address",
140+
)
129141
configCmd.AddCommand(getRegistryURLCmd)
130142
configCmd.AddCommand(unsetRegistryURLCmd)
143+
131144
}
132145

133146
func registerClientCmdFunc(cmd *cobra.Command, args []string) error {
@@ -385,15 +398,33 @@ func setRegistryURLCmdFunc(_ *cobra.Command, args []string) error {
385398
return fmt.Errorf("registry URL must start with http:// or https://")
386399
}
387400

401+
if !allowPrivateRegistryIp {
402+
registryClient := networking.GetHttpClient(false)
403+
_, err := registryClient.Get(registryURL)
404+
if err != nil && strings.Contains(fmt.Sprint(err), networking.ErrPrivateIpAddress) {
405+
return err
406+
}
407+
}
408+
388409
// Update the configuration
389410
err := config.UpdateConfig(func(c *config.Config) {
390411
c.RegistryUrl = registryURL
412+
c.AllowPrivateRegistryIp = allowPrivateRegistryIp
391413
})
392414
if err != nil {
393415
return fmt.Errorf("failed to update configuration: %w", err)
394416
}
395417

396418
fmt.Printf("Successfully set registry URL: %s\n", registryURL)
419+
if allowPrivateRegistryIp {
420+
fmt.Print("Successfully enabled use of private IP addresses for the remote registry\n")
421+
fmt.Print("Caution: allowing registry URLs containing private IP addresses may decrease your security.\n" +
422+
"Make sure you trust any remote registries you configure with ToolHive.")
423+
} else {
424+
fmt.Printf("Use of private IP addresses for the remote registry has been disabled" +
425+
" as it's not needed for the provided registry.\n")
426+
}
427+
397428
return nil
398429
}
399430

docs/cli/thv_config_set-registry-url.md

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/config/config.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ const lockTimeout = 1 * time.Second
2323

2424
// Config represents the configuration of the application.
2525
type Config struct {
26-
Secrets Secrets `yaml:"secrets"`
27-
Clients Clients `yaml:"clients"`
28-
RegistryUrl string `yaml:"registry_url"`
29-
CACertificatePath string `yaml:"ca_certificate_path,omitempty"`
26+
Secrets Secrets `yaml:"secrets"`
27+
Clients Clients `yaml:"clients"`
28+
RegistryUrl string `yaml:"registry_url"`
29+
AllowPrivateRegistryIp bool `yaml:"allow_private_registry_ip"`
30+
CACertificatePath string `yaml:"ca_certificate_path,omitempty"`
3031
}
3132

3233
// Secrets contains the settings for secrets management.
@@ -89,7 +90,8 @@ func createNewConfigWithDefaults() Config {
8990
ProviderType: "", // No default provider - user must run setup
9091
SetupCompleted: false,
9192
},
92-
RegistryUrl: "",
93+
RegistryUrl: "",
94+
AllowPrivateRegistryIp: false,
9395
}
9496
}
9597

pkg/config/config_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,48 @@ func TestRegistryURLConfig(t *testing.T) {
208208
}
209209
})
210210
})
211+
212+
t.Run("TestAllowPrivateRegistryIp", func(t *testing.T) {
213+
t.Parallel()
214+
tempDir, configPath := SetupTestConfig(t, &Config{
215+
Secrets: Secrets{
216+
ProviderType: string(secrets.EncryptedType),
217+
},
218+
Clients: Clients{
219+
RegisteredClients: []string{},
220+
},
221+
RegistryUrl: "",
222+
AllowPrivateRegistryIp: false,
223+
})
224+
225+
// Test enabling
226+
err := UpdateConfigAtPath(configPath, func(c *Config) {
227+
c.AllowPrivateRegistryIp = true
228+
})
229+
require.NoError(t, err)
230+
231+
// Load the config and verify the setting was toggled to true
232+
config, err := LoadOrCreateConfigWithPath(configPath)
233+
require.NoError(t, err)
234+
assert.Equal(t, true, config.AllowPrivateRegistryIp)
235+
236+
// Test toggling setting to false
237+
err = UpdateConfigAtPath(configPath, func(c *Config) {
238+
c.AllowPrivateRegistryIp = false
239+
})
240+
require.NoError(t, err)
241+
242+
// Load the config and verify the setting was toggled to false
243+
config, err = LoadOrCreateConfigWithPath(configPath)
244+
require.NoError(t, err)
245+
assert.Equal(t, false, config.AllowPrivateRegistryIp)
246+
247+
t.Cleanup(func() {
248+
if err := os.RemoveAll(tempDir); err != nil {
249+
t.Logf("Failed to remove temp dir: %v", err)
250+
}
251+
})
252+
})
211253
}
212254

213255
func TestSecrets_GetProviderType_EnvironmentVariable(t *testing.T) {

pkg/networking/http_client.go

Lines changed: 15 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package networking
22

33
import (
4-
"errors"
54
"fmt"
65
"net"
76
"net/http"
@@ -15,51 +14,14 @@ var privateIPBlocks []*net.IPNet
1514
// HttpTimeout is the timeout for outgoing HTTP requests
1615
const HttpTimeout = 30 * time.Second
1716

18-
func init() {
19-
for _, cidr := range []string{
20-
"127.0.0.0/8", // IPv4 loopback
21-
"10.0.0.0/8", // RFC1918
22-
"172.16.0.0/12", // RFC1918
23-
"192.168.0.0/16", // RFC1918
24-
"169.254.0.0/16", // RFC3927 link-local
25-
"::1/128", // IPv6 loopback
26-
"fe80::/10", // IPv6 link-local
27-
"fc00::/7", // IPv6 unique local addr
28-
} {
29-
_, block, err := net.ParseCIDR(cidr)
30-
if err != nil {
31-
panic(fmt.Errorf("parse error on %q: %v", cidr, err))
32-
}
33-
privateIPBlocks = append(privateIPBlocks, block)
34-
}
35-
}
36-
37-
func isPrivateIP(ip net.IP) bool {
38-
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
39-
return true
40-
}
41-
for _, block := range privateIPBlocks {
42-
if block.Contains(ip) {
43-
return true
44-
}
45-
}
46-
return false
47-
}
48-
4917
// Dialer control function for validating addresses prior to connection
50-
func protectedDialerControl(network, address string, _ syscall.RawConn) error {
51-
52-
fmt.Printf("protectedDialerControl: %s, %s\n", network, address)
18+
func protectedDialerControl(_, address string, _ syscall.RawConn) error {
5319

54-
host, _, err := net.SplitHostPort(address)
20+
err := AddressReferencesPrivateIp(address)
5521
if err != nil {
5622
return err
5723
}
58-
// Check for a private IP address or loopback
59-
ip := net.ParseIP(host)
60-
if isPrivateIP(ip) {
61-
return errors.New("private IP address not allowed")
62-
}
24+
6325
return nil
6426
}
6527

@@ -85,18 +47,23 @@ func (t *ValidatingTransport) RoundTrip(req *http.Request) (*http.Response, erro
8547
return t.Transport.RoundTrip(req)
8648
}
8749

88-
// GetProtectedHttpClient returns a new http client with a protected dialer and URL validation
89-
func GetProtectedHttpClient() *http.Client {
50+
// GetHttpClient returns a new http client which uses a protected dialer and URL validation by default
51+
func GetHttpClient(allowPrivateIp bool) *http.Client {
9052

91-
protectedTransport := &http.Transport{
92-
DialContext: (&net.Dialer{
93-
Control: protectedDialerControl,
94-
}).DialContext,
53+
var transport *http.Transport
54+
if !allowPrivateIp {
55+
transport = &http.Transport{
56+
DialContext: (&net.Dialer{
57+
Control: protectedDialerControl,
58+
}).DialContext,
59+
}
60+
} else {
61+
transport = &http.Transport{}
9562
}
9663

9764
client := &http.Client{
9865
Transport: &ValidatingTransport{
99-
Transport: protectedTransport,
66+
Transport: transport,
10067
},
10168
Timeout: HttpTimeout,
10269
}

pkg/networking/utilities.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package networking
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net"
7+
)
8+
9+
const (
10+
// ErrPrivateIpAddress is the error returned when the provided URL redirects to a private IP address
11+
ErrPrivateIpAddress = "the provided registry URL redirects to a private IP address, which is not allowed; " +
12+
"to override this, reset the registry URL using the --allow-private-ip (-p) flag"
13+
)
14+
15+
func init() {
16+
for _, cidr := range []string{
17+
"127.0.0.0/8", // IPv4 loopback
18+
"10.0.0.0/8", // RFC1918
19+
"172.16.0.0/12", // RFC1918
20+
"192.168.0.0/16", // RFC1918
21+
"169.254.0.0/16", // RFC3927 link-local
22+
"::1/128", // IPv6 loopback
23+
"fe80::/10", // IPv6 link-local
24+
"fc00::/7", // IPv6 unique local addr
25+
} {
26+
_, block, err := net.ParseCIDR(cidr)
27+
if err != nil {
28+
panic(fmt.Errorf("parse error on %q: %v", cidr, err))
29+
}
30+
privateIPBlocks = append(privateIPBlocks, block)
31+
}
32+
}
33+
34+
func isPrivateIP(ip net.IP) bool {
35+
if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
36+
return true
37+
}
38+
for _, block := range privateIPBlocks {
39+
if block.Contains(ip) {
40+
return true
41+
}
42+
}
43+
return false
44+
}
45+
46+
// AddressReferencesPrivateIp returns an error if the address references a private IP address
47+
func AddressReferencesPrivateIp(address string) error {
48+
host, _, err := net.SplitHostPort(address)
49+
if err != nil {
50+
return err
51+
}
52+
// Check for a private IP address or loopback
53+
ip := net.ParseIP(host)
54+
if isPrivateIP(ip) {
55+
return errors.New(ErrPrivateIpAddress)
56+
}
57+
58+
return nil
59+
}

pkg/registry/factory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ var (
1515
// NewRegistryProvider creates a new registry provider based on the configuration
1616
func NewRegistryProvider(cfg *config.Config) Provider {
1717
if cfg != nil && len(cfg.RegistryUrl) > 0 {
18-
return NewRemoteRegistryProvider(cfg.RegistryUrl)
18+
return NewRemoteRegistryProvider(cfg.RegistryUrl, cfg.AllowPrivateRegistryIp)
1919
}
2020
return NewEmbeddedRegistryProvider()
2121
}

pkg/registry/provider_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func TestRemoteRegistryProvider(t *testing.T) {
108108
t.Parallel()
109109
// Note: This test would require a mock HTTP server for full testing
110110
// For now, we just test the creation
111-
provider := NewRemoteRegistryProvider("https://example.com/registry.json")
111+
provider := NewRemoteRegistryProvider("https://example.com/registry.json", false)
112112

113113
if provider == nil {
114114
t.Fatal("NewRemoteRegistryProvider() returned nil")

pkg/registry/remote_provider.go

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,25 @@ import (
1212

1313
// RemoteRegistryProvider provides registry data from a remote HTTP endpoint
1414
type RemoteRegistryProvider struct {
15-
registryURL string
16-
registry *Registry
17-
registryOnce sync.Once
18-
registryErr error
15+
registryURL string
16+
allowPrivateIp bool
17+
registry *Registry
18+
registryOnce sync.Once
19+
registryErr error
1920
}
2021

2122
// NewRemoteRegistryProvider creates a new remote registry provider
22-
func NewRemoteRegistryProvider(registryURL string) *RemoteRegistryProvider {
23+
func NewRemoteRegistryProvider(registryURL string, allowPrivateIp bool) *RemoteRegistryProvider {
2324
return &RemoteRegistryProvider{
24-
registryURL: registryURL,
25+
registryURL: registryURL,
26+
allowPrivateIp: allowPrivateIp,
2527
}
2628
}
2729

2830
// GetRegistry returns the remote registry data
2931
func (p *RemoteRegistryProvider) GetRegistry() (*Registry, error) {
3032
p.registryOnce.Do(func() {
31-
client := networking.GetProtectedHttpClient()
33+
client := networking.GetHttpClient(p.allowPrivateIp)
3234
resp, err := client.Get(p.registryURL)
3335
if err != nil {
3436
p.registryErr = fmt.Errorf("failed to fetch registry data from URL %s: %w", p.registryURL, err)

0 commit comments

Comments
 (0)