Skip to content

Commit 84b9c99

Browse files
authored
Allow to set a local registry via file path (#1223)
Signed-off-by: Radoslav Dimitrov <[email protected]>
1 parent 042994b commit 84b9c99

File tree

5 files changed

+127
-20
lines changed

5 files changed

+127
-20
lines changed

pkg/auth/token_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ func writeTestServerCert(t *testing.T, server *httptest.Server) string {
380380
cert := server.Certificate()
381381
if cert == nil {
382382
t.Fatal("Test server has no certificate")
383+
return ""
383384
}
384385

385386
// Create temp file

pkg/config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type Config struct {
2626
Secrets Secrets `yaml:"secrets"`
2727
Clients Clients `yaml:"clients"`
2828
RegistryUrl string `yaml:"registry_url"`
29+
LocalRegistryPath string `yaml:"local_registry_path"`
2930
AllowPrivateRegistryIp bool `yaml:"allow_private_registry_ip"`
3031
CACertificatePath string `yaml:"ca_certificate_path,omitempty"`
3132
OTEL OpenTelemetryConfig `yaml:"otel,omitempty"`

pkg/registry/factory.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ func NewRegistryProvider(cfg *config.Config) Provider {
1717
if cfg != nil && len(cfg.RegistryUrl) > 0 {
1818
return NewRemoteRegistryProvider(cfg.RegistryUrl, cfg.AllowPrivateRegistryIp)
1919
}
20-
return NewEmbeddedRegistryProvider()
20+
if cfg != nil && len(cfg.LocalRegistryPath) > 0 {
21+
return NewLocalRegistryProvider(cfg.LocalRegistryPath)
22+
}
23+
return NewLocalRegistryProvider()
2124
}
2225

2326
// GetDefaultProvider returns the default registry provider instance

pkg/registry/provider_embedded.go renamed to pkg/registry/provider_local.go

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,45 @@ import (
44
"embed"
55
"encoding/json"
66
"fmt"
7+
"os"
78
"strings"
89
)
910

1011
//go:embed data/registry.json
1112
var embeddedRegistryFS embed.FS
1213

13-
// EmbeddedRegistryProvider provides registry data from embedded JSON files
14-
type EmbeddedRegistryProvider struct {
14+
// LocalRegistryProvider provides registry data from embedded JSON files or local files
15+
type LocalRegistryProvider struct {
16+
filePath string
1517
}
1618

17-
// NewEmbeddedRegistryProvider creates a new embedded registry provider
18-
func NewEmbeddedRegistryProvider() *EmbeddedRegistryProvider {
19-
return &EmbeddedRegistryProvider{}
19+
// NewLocalRegistryProvider creates a new local registry provider
20+
// If filePath is provided, it will read from that file; otherwise uses embedded data
21+
func NewLocalRegistryProvider(filePath ...string) *LocalRegistryProvider {
22+
var path string
23+
if len(filePath) > 0 {
24+
path = filePath[0]
25+
}
26+
return &LocalRegistryProvider{filePath: path}
2027
}
2128

22-
// GetRegistry returns the embedded registry data
23-
func (*EmbeddedRegistryProvider) GetRegistry() (*Registry, error) {
24-
data, err := embeddedRegistryFS.ReadFile("data/registry.json")
25-
if err != nil {
26-
return nil, fmt.Errorf("failed to read embedded registry data: %w", err)
29+
// GetRegistry returns the registry data from file path or embedded data
30+
func (p *LocalRegistryProvider) GetRegistry() (*Registry, error) {
31+
var data []byte
32+
var err error
33+
34+
if p.filePath != "" {
35+
// Read from local file
36+
data, err = os.ReadFile(p.filePath)
37+
if err != nil {
38+
return nil, fmt.Errorf("failed to read local registry file %s: %w", p.filePath, err)
39+
}
40+
} else {
41+
// Read from embedded data
42+
data, err = embeddedRegistryFS.ReadFile("data/registry.json")
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to read embedded registry data: %w", err)
45+
}
2746
}
2847

2948
registry, err := parseRegistryData(data)
@@ -40,7 +59,7 @@ func (*EmbeddedRegistryProvider) GetRegistry() (*Registry, error) {
4059
}
4160

4261
// GetServer returns a specific server by name
43-
func (p *EmbeddedRegistryProvider) GetServer(name string) (*ImageMetadata, error) {
62+
func (p *LocalRegistryProvider) GetServer(name string) (*ImageMetadata, error) {
4463
reg, err := p.GetRegistry()
4564
if err != nil {
4665
return nil, err
@@ -55,7 +74,7 @@ func (p *EmbeddedRegistryProvider) GetServer(name string) (*ImageMetadata, error
5574
}
5675

5776
// SearchServers searches for servers matching the query
58-
func (p *EmbeddedRegistryProvider) SearchServers(query string) ([]*ImageMetadata, error) {
77+
func (p *LocalRegistryProvider) SearchServers(query string) ([]*ImageMetadata, error) {
5978
reg, err := p.GetRegistry()
6079
if err != nil {
6180
return nil, err
@@ -90,7 +109,7 @@ func (p *EmbeddedRegistryProvider) SearchServers(query string) ([]*ImageMetadata
90109
}
91110

92111
// ListServers returns all available servers
93-
func (p *EmbeddedRegistryProvider) ListServers() ([]*ImageMetadata, error) {
112+
func (p *LocalRegistryProvider) ListServers() ([]*ImageMetadata, error) {
94113
reg, err := p.GetRegistry()
95114
if err != nil {
96115
return nil, err

pkg/registry/provider_test.go

Lines changed: 89 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package registry
22

33
import (
4+
"os"
5+
"path/filepath"
46
"testing"
57

68
"github.com/stacklok/toolhive/pkg/config"
@@ -16,14 +18,14 @@ func TestNewRegistryProvider(t *testing.T) {
1618
{
1719
name: "nil config returns embedded provider",
1820
config: nil,
19-
expectedType: "*registry.EmbeddedRegistryProvider",
21+
expectedType: "*registry.LocalRegistryProvider",
2022
},
2123
{
2224
name: "empty registry URL returns embedded provider",
2325
config: &config.Config{
2426
RegistryUrl: "",
2527
},
26-
expectedType: "*registry.EmbeddedRegistryProvider",
28+
expectedType: "*registry.LocalRegistryProvider",
2729
},
2830
{
2931
name: "registry URL returns remote provider",
@@ -32,6 +34,21 @@ func TestNewRegistryProvider(t *testing.T) {
3234
},
3335
expectedType: "*registry.RemoteRegistryProvider",
3436
},
37+
{
38+
name: "local registry path returns embedded provider with file path",
39+
config: &config.Config{
40+
LocalRegistryPath: "/path/to/registry.json",
41+
},
42+
expectedType: "*registry.LocalRegistryProvider",
43+
},
44+
{
45+
name: "registry URL takes precedence over local path",
46+
config: &config.Config{
47+
RegistryUrl: "https://example.com/registry.json",
48+
LocalRegistryPath: "/path/to/registry.json",
49+
},
50+
expectedType: "*registry.RemoteRegistryProvider",
51+
},
3552
}
3653

3754
for _, tt := range tests {
@@ -48,9 +65,9 @@ func TestNewRegistryProvider(t *testing.T) {
4865
}
4966
}
5067

51-
func TestEmbeddedRegistryProvider(t *testing.T) {
68+
func TestLocalRegistryProvider(t *testing.T) {
5269
t.Parallel()
53-
provider := NewEmbeddedRegistryProvider()
70+
provider := NewLocalRegistryProvider()
5471

5572
// Test GetRegistry
5673
registry, err := provider.GetRegistry()
@@ -118,11 +135,77 @@ func TestRemoteRegistryProvider(t *testing.T) {
118135
var _ Provider = provider
119136
}
120137

138+
func TestLocalRegistryProviderWithLocalFile(t *testing.T) {
139+
t.Parallel()
140+
141+
// Create a temporary registry file
142+
tempDir := t.TempDir()
143+
registryFile := filepath.Join(tempDir, "test_registry.json")
144+
145+
// Write test registry data
146+
testRegistry := `{
147+
"version": "1.0.0",
148+
"last_updated": "2023-01-01T00:00:00Z",
149+
"servers": {
150+
"test-server": {
151+
"image": "test/image:latest",
152+
"description": "Test server",
153+
"tier": "community",
154+
"status": "active",
155+
"transport": "stdio",
156+
"permissions": {
157+
"allow_local_resource_access": false,
158+
"allow_internet_access": false
159+
},
160+
"tools": ["test-tool"],
161+
"env_vars": [],
162+
"args": []
163+
}
164+
}
165+
}`
166+
167+
err := os.WriteFile(registryFile, []byte(testRegistry), 0644)
168+
if err != nil {
169+
t.Fatalf("Failed to write test registry file: %v", err)
170+
}
171+
172+
// Test with local file path
173+
provider := NewLocalRegistryProvider(registryFile)
174+
175+
// Test GetRegistry
176+
registry, err := provider.GetRegistry()
177+
if err != nil {
178+
t.Fatalf("GetRegistry() error = %v", err)
179+
}
180+
181+
if registry == nil {
182+
t.Fatal("GetRegistry() returned nil registry")
183+
return
184+
}
185+
186+
if len(registry.Servers) != 1 {
187+
t.Errorf("Expected 1 server, got %d", len(registry.Servers))
188+
}
189+
190+
server, exists := registry.Servers["test-server"]
191+
if !exists {
192+
t.Error("Expected test-server to exist in registry")
193+
}
194+
195+
if server.Name != "test-server" {
196+
t.Errorf("Expected server name 'test-server', got '%s'", server.Name)
197+
}
198+
199+
if server.Image != "test/image:latest" {
200+
t.Errorf("Expected image 'test/image:latest', got '%s'", server.Image)
201+
}
202+
}
203+
121204
// getTypeName returns the type name of an interface value
122205
func getTypeName(v interface{}) string {
123206
switch v.(type) {
124-
case *EmbeddedRegistryProvider:
125-
return "*registry.EmbeddedRegistryProvider"
207+
case *LocalRegistryProvider:
208+
return "*registry.LocalRegistryProvider"
126209
case *RemoteRegistryProvider:
127210
return "*registry.RemoteRegistryProvider"
128211
default:

0 commit comments

Comments
 (0)