Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 12 additions & 25 deletions memory/mysql/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"crypto/sha256"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sort"
"strings"
Expand Down Expand Up @@ -69,33 +68,21 @@ func NewService(options ...ServiceOpt) (*Service, error) {
option(&opts)
}

// Create MySQL client
builder := storage.GetClientBuilder()
var db storage.Client
var err error

builderOpts := []storage.ClientBuilderOpt{
storage.WithClientBuilderDSN(opts.dsn),
storage.WithExtraOptions(opts.extraOptions...),
}
// Priority: dsn > instanceName.
if opts.dsn != "" {
// Method 1: Use DSN directly (recommended).
db, err = builder(
storage.WithClientBuilderDSN(opts.dsn),
storage.WithExtraOptions(opts.extraOptions...),
)
if err != nil {
return nil, fmt.Errorf("create mysql client from dsn failed: %w", err)
}
} else if opts.instanceName != "" {
// Method 2: Use pre-registered MySQL instance.
builderOpts, ok := storage.GetMySQLInstance(opts.instanceName)
if !ok {
if opts.dsn == "" && opts.instanceName != "" {
var ok bool
if builderOpts, ok = storage.GetMySQLInstance(opts.instanceName); !ok {
return nil, fmt.Errorf("mysql instance %s not found", opts.instanceName)
}
db, err = builder(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create mysql client from instance name failed: %w", err)
}
} else {
return nil, errors.New("either dsn or instance name must be provided")
}

db, err := storage.GetClientBuilder()(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create mysql client failed: %w", err)
}

s := &Service{
Expand Down
8 changes: 0 additions & 8 deletions memory/mysql/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,6 @@ func TestNewService_DSNPriority(t *testing.T) {
assert.NoError(t, mock.ExpectationsWereMet())
}

// TestNewService_MissingDSNAndInstance tests that service creation fails when neither DSN nor instanceName is provided.
func TestNewService_MissingDSNAndInstance(t *testing.T) {
service, err := NewService()
require.Error(t, err)
assert.Nil(t, service)
assert.Contains(t, err.Error(), "either dsn or instance name must be provided")
}

// TestNewService_WithSkipDBInit tests that skipDBInit option skips database initialization.
func TestNewService_WithSkipDBInit(t *testing.T) {
mockDB, mock := setupMockDB(t)
Expand Down
34 changes: 11 additions & 23 deletions memory/postgres/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,36 +69,24 @@ func NewService(options ...ServiceOpt) (*Service, error) {
option(&opts)
}

var db storage.Client
var err error
builder := storage.GetClientBuilder()

builderOpts := []storage.ClientBuilderOpt{
storage.WithClientConnString(buildConnString(opts)),
storage.WithExtraOptions(opts.extraOptions...),
}
// Priority: direct connection settings > instance name
// If direct connection settings are provided, use them.
if opts.host != "" {
connString := buildConnString(opts)
db, err = builder(
context.Background(),
storage.WithClientConnString(connString),
storage.WithExtraOptions(opts.extraOptions...),
)
if err != nil {
return nil, fmt.Errorf("create postgres client from connection settings failed: %w", err)
}
} else if opts.instanceName != "" {
if opts.host == "" && opts.instanceName != "" {
// Otherwise, use instance name if provided.
builderOpts, ok := storage.GetPostgresInstance(opts.instanceName)
if !ok {
var ok bool
if builderOpts, ok = storage.GetPostgresInstance(opts.instanceName); !ok {
return nil, fmt.Errorf("postgres instance %s not found", opts.instanceName)
}
db, err = builder(context.Background(), builderOpts...)
if err != nil {
return nil, fmt.Errorf("create postgres client from instance name failed: %w", err)
}
} else {
return nil, fmt.Errorf("either connection settings (host, port, etc.) or instance name must be provided")
}

db, err := storage.GetClientBuilder()(context.Background(), builderOpts...)
if err != nil {
return nil, fmt.Errorf("create postgres client failed: %w", err)
}
// Build full table name with schema.
fullTableName := sqldb.BuildTableNameWithSchema(opts.schema, "", opts.tableName)

Expand Down
11 changes: 2 additions & 9 deletions memory/postgres/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1370,7 +1370,7 @@ func TestNewService_ConnectionSettingsBuilderError(t *testing.T) {

_, err := NewService(WithHost("localhost"), WithPort(5432), WithDatabase("testdb"))
require.Error(t, err)
assert.Contains(t, err.Error(), "create postgres client from connection settings failed")
assert.Contains(t, err.Error(), "create postgres client failed")
}

func TestNewService_InstanceNameBuilderError(t *testing.T) {
Expand All @@ -1387,7 +1387,7 @@ func TestNewService_InstanceNameBuilderError(t *testing.T) {

_, err := NewService(WithPostgresInstance("test-instance"))
require.Error(t, err)
assert.Contains(t, err.Error(), "create postgres client from instance name failed")
assert.Contains(t, err.Error(), "create postgres client failed")
}

func TestNewService_ConnectionSettingsPriority(t *testing.T) {
Expand Down Expand Up @@ -1789,13 +1789,6 @@ func TestBuildCreateIndexSQL(t *testing.T) {
}
}

// Test NewService with missing connection settings
func TestNewService_MissingConnectionSettings(t *testing.T) {
_, err := NewService()
require.Error(t, err)
assert.Contains(t, err.Error(), "either connection settings")
}

// Test NewService with skipDBInit
func TestNewService_WithSkipDBInit(t *testing.T) {
db, mock := setupMockDB(t)
Expand Down
37 changes: 8 additions & 29 deletions memory/redis/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,43 +63,22 @@ func NewService(options ...ServiceOpt) (*Service, error) {
option(&opts)
}

builder := storage.GetClientBuilder()
var (
redisClient redis.UniversalClient
err error
)
builderOpts := []storage.ClientBuilderOpt{
storage.WithClientBuilderURL(opts.url),
storage.WithExtraOptions(opts.extraOptions...),
}

// if instance name set, and url not set, use instance name to create redis client
if opts.url == "" && opts.instanceName != "" {
builderOpts, ok := storage.GetRedisInstance(opts.instanceName)
if !ok {
var ok bool
if builderOpts, ok = storage.GetRedisInstance(opts.instanceName); !ok {
return nil, fmt.Errorf("redis instance %s not found", opts.instanceName)
}
redisClient, err = builder(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create redis client from instance name failed: %w", err)
}

// Test connection with Ping to ensure Redis is accessible.
ctx, cancel := context.WithTimeout(context.Background(), defaultConnectionTimeout)
defer cancel()
if err := redisClient.Ping(ctx).Err(); err != nil {
return nil, fmt.Errorf("redis connection test failed: %w", err)
}

return &Service{
opts: opts,
redisClient: redisClient,
cachedTools: make(map[string]tool.Tool),
}, nil
}

redisClient, err = builder(
storage.WithClientBuilderURL(opts.url),
storage.WithExtraOptions(opts.extraOptions...),
)
redisClient, err := storage.GetClientBuilder()(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create redis client from url failed: %w", err)
return nil, fmt.Errorf("create redis client failed: %w", err)
}

// Test connection with Ping to ensure Redis is accessible.
Expand Down
35 changes: 12 additions & 23 deletions session/mysql/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,32 +102,21 @@ func NewService(options ...ServiceOpt) (*Service, error) {
}

// Create MySQL client
builder := storage.GetClientBuilder()
var mysqlClient storage.Client
var err error

// Priority: dsn > instanceName
if opts.dsn != "" {
// Method 1: Use DSN directly (recommended)
mysqlClient, err = builder(
storage.WithClientBuilderDSN(opts.dsn),
storage.WithExtraOptions(opts.extraOptions...),
)
if err != nil {
return nil, fmt.Errorf("create mysql client from dsn failed: %w", err)
}
} else if opts.instanceName != "" {
builderOpts := []storage.ClientBuilderOpt{
storage.WithClientBuilderDSN(opts.dsn),
storage.WithExtraOptions(opts.extraOptions...),
}
if opts.dsn == "" && opts.instanceName != "" {
// Method 2: Use pre-registered MySQL instance
builderOpts, ok := storage.GetMySQLInstance(opts.instanceName)
if !ok {
var ok bool
if builderOpts, ok = storage.GetMySQLInstance(opts.instanceName); !ok {
return nil, fmt.Errorf("mysql instance %s not found", opts.instanceName)
}
mysqlClient, err = builder(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create mysql client from instance name failed: %w", err)
}
} else {
return nil, fmt.Errorf("either dsn or instance name must be provided")
}

mysqlClient, err := storage.GetClientBuilder()(builderOpts...)
if err != nil {
return nil, fmt.Errorf("create mysql client failed: %w", err)
}

// Build table names with prefix
Expand Down
6 changes: 3 additions & 3 deletions session/mysql/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,7 @@ func TestNewService_MissingDSNAndInstance(t *testing.T) {
svc, err := NewService()
assert.Error(t, err)
assert.Nil(t, svc)
assert.Contains(t, err.Error(), "either dsn or instance name must be provided")
assert.Contains(t, err.Error(), "create mysql client failed")
}

func TestNewService_WithInstance_Success(t *testing.T) {
Expand Down Expand Up @@ -1725,7 +1725,7 @@ func TestNewService_ClientBuilderError(t *testing.T) {
)
assert.Error(t, err)
assert.Nil(t, svc)
assert.Contains(t, err.Error(), "create mysql client from dsn failed")
assert.Contains(t, err.Error(), "create mysql client failed")

// Test with instance name
storage.RegisterMySQLInstance("test-error-instance",
Expand All @@ -1737,7 +1737,7 @@ func TestNewService_ClientBuilderError(t *testing.T) {
)
assert.Error(t, err)
assert.Nil(t, svc)
assert.Contains(t, err.Error(), "create mysql client from instance name failed")
assert.Contains(t, err.Error(), "create mysql client failed")
}

func TestNewService_DBInitFailure(t *testing.T) {
Expand Down
34 changes: 11 additions & 23 deletions session/postgres/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,22 @@ func NewService(options ...ServiceOpt) (*Service, error) {
}
}

var pgClient storage.Client
var err error
builder := storage.GetClientBuilder()

builderOpts := []storage.ClientBuilderOpt{
storage.WithClientConnString(buildConnString(opts)),
storage.WithExtraOptions(opts.extraOptions...),
}
// Priority: direct connection settings > instance name
// If direct connection settings are provided, use them
if opts.host != "" {
connString := buildConnString(opts)
pgClient, err = builder(
context.Background(),
storage.WithClientConnString(connString),
storage.WithExtraOptions(opts.extraOptions...),
)
if err != nil {
return nil, fmt.Errorf("create postgres client from connection settings failed: %w", err)
}
} else if opts.instanceName != "" {
if opts.host == "" && opts.instanceName != "" {
// Otherwise, use instance name if provided
builderOpts, ok := storage.GetPostgresInstance(opts.instanceName)
if !ok {
var ok bool
if builderOpts, ok = storage.GetPostgresInstance(opts.instanceName); !ok {
return nil, fmt.Errorf("postgres instance %s not found", opts.instanceName)
}
pgClient, err = builder(context.Background(), builderOpts...)
if err != nil {
return nil, fmt.Errorf("create postgres client from instance name failed: %w", err)
}
} else {
return nil, fmt.Errorf("either connection settings (host, port, etc.) or instance name must be provided")
}
pgClient, err := storage.GetClientBuilder()(context.Background(), builderOpts...)
if err != nil {
return nil, fmt.Errorf("create postgres client failed: %w", err)
}

s := &Service{
Expand Down
37 changes: 37 additions & 0 deletions session/postgres/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2305,3 +2305,40 @@ func TestClose_Multiple(t *testing.T) {
err = s.Close()
require.NoError(t, err)
}

func TestNewService_MissingDSNAndInstance(t *testing.T) {
svc, err := NewService()
assert.Error(t, err)
assert.Nil(t, svc)
assert.Contains(t, err.Error(), "create postgres client failed")
}

func TestNewService_WithInstance_Success(t *testing.T) {
db, mock, err := sqlmock.New(sqlmock.MonitorPingsOption(true))
require.NoError(t, err)
defer db.Close()

originalBuilder := storage.GetClientBuilder()
defer storage.SetClientBuilder(originalBuilder)

storage.SetClientBuilder(func(ctx context.Context, builderOpts ...storage.ClientBuilderOpt) (storage.Client, error) {
return &mockPostgresClient{db: db}, nil
})

// Register instance
instanceName := "test-instance-success"
storage.RegisterPostgresInstance(instanceName,
storage.WithClientConnString("test:test@tcp(localhost:3306)/testdb"),
)

svc, err := NewService(
WithPostgresInstance(instanceName),
WithSkipDBInit(true),
)
require.NoError(t, err)
require.NotNil(t, svc)

err = svc.Close()
assert.NoError(t, err)
assert.NoError(t, mock.ExpectationsWereMet())
}
Loading
Loading