Skip to content

Commit c195b62

Browse files
authored
Create client config file before writing to it (#857)
1 parent d93cc6d commit c195b62

File tree

2 files changed

+110
-59
lines changed

2 files changed

+110
-59
lines changed

pkg/client/config.go

Lines changed: 99 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package client
44

55
import (
6+
"errors"
67
"fmt"
78
"os"
89
"path/filepath"
@@ -59,6 +60,11 @@ type mcpClientConfig struct {
5960
IsTransportTypeFieldSupported bool
6061
}
6162

63+
var (
64+
// ErrConfigFileNotFound is returned when a client configuration file is not found
65+
ErrConfigFileNotFound = fmt.Errorf("client config file not found")
66+
)
67+
6268
var supportedClientIntegrations = []mcpClientConfig{
6369
{
6470
ClientType: RooCode,
@@ -190,18 +196,22 @@ type MCPServerConfig struct {
190196

191197
// FindClientConfig returns the client configuration file for a given client type.
192198
func FindClientConfig(clientType MCPClient) (*ConfigFile, error) {
193-
configFiles, err := FindClientConfigs()
199+
// retrieve the metadata of the config files
200+
configFile, err := retrieveConfigFileMetadata(clientType)
194201
if err != nil {
195-
return nil, fmt.Errorf("failed to fetch client configurations: %w", err)
196-
}
197-
198-
for _, cf := range configFiles {
199-
if cf.ClientType == clientType {
200-
return &cf, nil
202+
if errors.Is(err, ErrConfigFileNotFound) {
203+
// Propagate the error if the file is not found
204+
return nil, fmt.Errorf("%w: for client %s", ErrConfigFileNotFound, clientType)
201205
}
206+
return nil, err
202207
}
203208

204-
return nil, fmt.Errorf("client configuration for %s not found", clientType)
209+
// validate the format of the config files
210+
err = validateConfigFileFormat(configFile)
211+
if err != nil {
212+
return nil, fmt.Errorf("failed to validate config file format: %w", err)
213+
}
214+
return configFile, nil
205215
}
206216

207217
// FindClientConfigs searches for client configuration files in standard locations
@@ -211,25 +221,58 @@ func FindClientConfigs() ([]ConfigFile, error) {
211221
return nil, fmt.Errorf("failed to get client status: %w", err)
212222
}
213223

214-
notInstalledClients := make(map[string]bool)
224+
var configFiles []ConfigFile
215225
for _, clientStatus := range clientStatuses {
216226
if !clientStatus.Installed {
217-
notInstalledClients[string(clientStatus.ClientType)] = true
227+
continue
218228
}
229+
cf, err := FindClientConfig(clientStatus.ClientType)
230+
if err != nil {
231+
return nil, fmt.Errorf("failed to find client config for %s: %w", clientStatus.ClientType, err)
232+
}
233+
configFiles = append(configFiles, *cf)
219234
}
220235

221-
// retrieve the metadata of the config files
222-
configFiles, err := retrieveConfigFilesMetadata(notInstalledClients)
236+
return configFiles, nil
237+
}
238+
239+
// CreateClientConfig creates a new client configuration file for a given client type.
240+
func CreateClientConfig(clientType MCPClient) (*ConfigFile, error) {
241+
// Get home directory
242+
home, err := os.UserHomeDir()
223243
if err != nil {
224-
return nil, fmt.Errorf("failed to retrieve client config metadata: %w", err)
244+
return nil, fmt.Errorf("failed to get home directory: %w", err)
225245
}
226246

227-
// validate the format of the config files
228-
err = validateConfigFilesFormat(configFiles)
247+
// Find the configuration for the requested client type
248+
var clientCfg *mcpClientConfig
249+
for _, cfg := range supportedClientIntegrations {
250+
if cfg.ClientType == clientType {
251+
clientCfg = &cfg
252+
break
253+
}
254+
}
255+
256+
if clientCfg == nil {
257+
return nil, fmt.Errorf("unsupported client type: %s", clientType)
258+
}
259+
260+
// Build the path to the configuration file
261+
path := buildConfigFilePath(clientCfg.SettingsFile, clientCfg.RelPath, clientCfg.PlatformPrefix, []string{home})
262+
263+
// Validate that the file does not already exist
264+
if _, err := os.Stat(path); !os.IsNotExist(err) {
265+
return nil, fmt.Errorf("client config file already exists at %s", path)
266+
}
267+
268+
// Create the file if it does not exist
269+
logger.Infof("Creating new client config file at %s", path)
270+
err = os.WriteFile(path, []byte("{}"), 0600)
229271
if err != nil {
230-
return nil, fmt.Errorf("failed to validate config file format: %w", err)
272+
return nil, fmt.Errorf("failed to create client config file: %w", err)
231273
}
232-
return configFiles, nil
274+
275+
return FindClientConfig(clientType)
233276
}
234277

235278
// Upsert updates/inserts an MCP server in a client configuration file
@@ -265,44 +308,48 @@ func GenerateMCPServerURL(transportType string, host string, port int, container
265308
return ""
266309
}
267310

268-
// retrieveConfigFilesMetadata retrieves the metadata for client configuration files.
269-
// It returns a list of ConfigFile objects, which contain metadata about the file that
270-
// can be used when performing operations on the file.
271-
func retrieveConfigFilesMetadata(filters map[string]bool) ([]ConfigFile, error) {
272-
var configFiles []ConfigFile
273-
311+
// retrieveConfigFileMetadata retrieves the metadata for client configuration files for a given client type.
312+
func retrieveConfigFileMetadata(clientType MCPClient) (*ConfigFile, error) {
274313
// Get home directory
275314
home, err := os.UserHomeDir()
276315
if err != nil {
277316
return nil, fmt.Errorf("failed to get home directory: %w", err)
278317
}
279318

319+
// Find the configuration for the requested client type
320+
var clientCfg *mcpClientConfig
280321
for _, cfg := range supportedClientIntegrations {
281-
if filters[string(cfg.ClientType)] {
282-
continue
322+
if cfg.ClientType == clientType {
323+
clientCfg = &cfg
324+
break
283325
}
326+
}
284327

285-
path := buildConfigFilePath(cfg.SettingsFile, cfg.RelPath, cfg.PlatformPrefix, []string{home})
286-
287-
err := validateConfigFileExists(path)
288-
if err != nil {
289-
logger.Warnf("failed to validate config file: %w", err)
290-
continue
291-
}
328+
if clientCfg == nil {
329+
return nil, fmt.Errorf("unsupported client type: %s", clientType)
330+
}
292331

293-
configUpdater := &JSONConfigUpdater{Path: path, MCPServersPathPrefix: cfg.MCPServersPathPrefix}
332+
// Build the path to the configuration file
333+
path := buildConfigFilePath(clientCfg.SettingsFile, clientCfg.RelPath, clientCfg.PlatformPrefix, []string{home})
294334

295-
clientConfig := ConfigFile{
296-
Path: path,
297-
ConfigUpdater: configUpdater,
298-
ClientType: cfg.ClientType,
299-
Extension: cfg.Extension,
300-
}
335+
// Validate that the file exists
336+
if err := validateConfigFileExists(path); err != nil {
337+
return nil, err
338+
}
301339

302-
configFiles = append(configFiles, clientConfig)
340+
// Create a config updater for this file
341+
configUpdater := &JSONConfigUpdater{
342+
Path: path,
343+
MCPServersPathPrefix: clientCfg.MCPServersPathPrefix,
303344
}
304345

305-
return configFiles, nil
346+
// Return the configuration file metadata
347+
return &ConfigFile{
348+
Path: path,
349+
ConfigUpdater: configUpdater,
350+
ClientType: clientCfg.ClientType,
351+
Extension: clientCfg.Extension,
352+
}, nil
306353
}
307354

308355
func buildConfigFilePath(settingsFile string, relPath []string, platformPrefix map[string][]string, path []string) string {
@@ -317,27 +364,22 @@ func buildConfigFilePath(settingsFile string, relPath []string, platformPrefix m
317364
// validateConfigFileExists validates that a client configuration file exists.
318365
func validateConfigFileExists(path string) error {
319366
if _, err := os.Stat(path); os.IsNotExist(err) {
320-
return fmt.Errorf("file does not exist: %s", path)
367+
return ErrConfigFileNotFound
321368
}
322369
return nil
323370
}
324371

325-
// validateConfigFileFormat validates the format of a client configuration file
326-
// It returns an error if the file is not valid JSON.
327-
func validateConfigFilesFormat(configFiles []ConfigFile) error {
328-
for _, cf := range configFiles {
329-
data, err := os.ReadFile(cf.Path)
330-
if err != nil {
331-
return fmt.Errorf("failed to read file %s: %w", cf.Path, err)
332-
}
333-
334-
// Default to JSON
335-
// we don't care about the contents of the file, we just want to validate that it's valid JSON
336-
_, err = hujson.Parse(data)
337-
if err != nil {
338-
return fmt.Errorf("failed to parse JSON for file %s: %w", cf.Path, err)
339-
}
372+
func validateConfigFileFormat(cf *ConfigFile) error {
373+
data, err := os.ReadFile(cf.Path)
374+
if err != nil {
375+
return fmt.Errorf("failed to read file %s: %w", cf.Path, err)
340376
}
341377

378+
// Default to JSON
379+
// we don't care about the contents of the file, we just want to validate that it's valid JSON
380+
_, err = hujson.Parse(data)
381+
if err != nil {
382+
return fmt.Errorf("failed to parse JSON for file %s: %w", cf.Path, err)
383+
}
342384
return nil
343385
}

pkg/client/manager.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67

78
"github.com/stacklok/toolhive/pkg/config"
@@ -77,7 +78,7 @@ func (m *defaultManager) RegisterClients(ctx context.Context, clients []Client)
7778

7879
// Add currently running MCPs to the newly registered client
7980
if err := m.addRunningMCPsToClient(ctx, client.Name); err != nil {
80-
logger.Warnf("Warning: Failed to add running MCPs to client %s: %v", client.Name, err)
81+
return fmt.Errorf("failed to add running MCPs to client %s: %v", client.Name, err)
8182
}
8283
}
8384
return nil
@@ -107,7 +108,15 @@ func (m *defaultManager) addRunningMCPsToClient(ctx context.Context, clientType
107108
// Find the client configuration for the specified client
108109
clientConfig, err := FindClientConfig(clientType)
109110
if err != nil {
110-
return fmt.Errorf("failed to find client configurations: %w", err)
111+
if errors.Is(err, ErrConfigFileNotFound) {
112+
// Create a new client configuration if it doesn't exist
113+
clientConfig, err = CreateClientConfig(clientType)
114+
if err != nil {
115+
return fmt.Errorf("failed to create client configuration for %s: %w", clientType, err)
116+
}
117+
} else {
118+
return fmt.Errorf("failed to find client configuration: %w", err)
119+
}
111120
}
112121

113122
// For each running container, add it to the client configuration

0 commit comments

Comments
 (0)