Skip to content

Commit 358ca2c

Browse files
committed
More refactoring
1 parent e6d8a94 commit 358ca2c

File tree

2 files changed

+73
-42
lines changed

2 files changed

+73
-42
lines changed

cmd/migrate_from_qdrant.go

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,20 @@ type MigrateFromQdrantCmd struct {
3939
targetTLS bool
4040
}
4141

42+
func getPort(u *url.URL) (int, error) {
43+
if u.Port() != "" {
44+
sourcePort, err := strconv.Atoi(u.Port())
45+
if err != nil {
46+
return 0, fmt.Errorf("failed to parse source port: %w", err)
47+
}
48+
return sourcePort, nil
49+
} else if u.Scheme == "https" {
50+
return 443, nil
51+
}
52+
53+
return 80, nil
54+
}
55+
4256
func (r *MigrateFromQdrantCmd) Parse() error {
4357
sourceUrl, err := url.Parse(r.SourceUrl)
4458
if err != nil {
@@ -47,16 +61,7 @@ func (r *MigrateFromQdrantCmd) Parse() error {
4761

4862
r.sourceHost = sourceUrl.Hostname()
4963
r.sourceTLS = sourceUrl.Scheme == "https"
50-
if sourceUrl.Port() != "" {
51-
r.sourcePort, err = strconv.Atoi(sourceUrl.Port())
52-
if err != nil {
53-
return fmt.Errorf("failed to parse source port: %w", err)
54-
}
55-
} else if r.sourceTLS {
56-
r.sourcePort = 443
57-
} else {
58-
r.sourcePort = 80
59-
}
64+
r.sourcePort, err = getPort(sourceUrl)
6065

6166
targetUrl, err := url.Parse(r.TargetUrl)
6267
if err != nil {
@@ -65,16 +70,7 @@ func (r *MigrateFromQdrantCmd) Parse() error {
6570

6671
r.targetHost = targetUrl.Hostname()
6772
r.targetTLS = targetUrl.Scheme == "https"
68-
if targetUrl.Port() != "" {
69-
r.targetPort, err = strconv.Atoi(targetUrl.Port())
70-
if err != nil {
71-
return fmt.Errorf("failed to parse target port: %w", err)
72-
}
73-
} else if r.targetTLS {
74-
r.targetPort = 443
75-
} else {
76-
r.targetPort = 80
77-
}
73+
r.targetPort, err = getPort(targetUrl)
7874

7975
return nil
8076
}
@@ -98,9 +94,13 @@ func (r *MigrateFromQdrantCmd) Run(globals *Globals) error {
9894
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
9995
defer stop()
10096

101-
sourceClient, targetClient, err := r.connect(globals)
97+
sourceClient, err := r.connect(globals, r.sourceHost, r.sourcePort, r.SourceAPIKey, r.sourceTLS)
10298
if err != nil {
103-
return fmt.Errorf("failed to connect to source or target: %w", err)
99+
return fmt.Errorf("failed to connect to source: %w", err)
100+
}
101+
targetClient, err := r.connect(globals, r.targetHost, r.targetPort, r.TargetAPIKey, r.targetTLS)
102+
if err != nil {
103+
return fmt.Errorf("failed to connect to target: %w", err)
104104
}
105105

106106
sourcePointCount, err := sourceClient.Count(ctx, &qdrant.CountPoints{
@@ -166,7 +166,7 @@ func (r *MigrateFromQdrantCmd) Run(globals *Globals) error {
166166
return nil
167167
}
168168

169-
func (r *MigrateFromQdrantCmd) connect(globals *Globals) (*qdrant.Client, *qdrant.Client, error) {
169+
func (r *MigrateFromQdrantCmd) connect(globals *Globals, host string, port int, apiKey string, useTLS bool) (*qdrant.Client, error) {
170170
debugLogger := logging.LoggerFunc(func(ctx context.Context, lvl logging.Level, msg string, fields ...any) {
171171
pterm.Debug.Printf(msg, fields...)
172172
})
@@ -190,31 +190,19 @@ func (r *MigrateFromQdrantCmd) connect(globals *Globals) (*qdrant.Client, *qdran
190190
InsecureSkipVerify: true,
191191
}
192192

193-
sourceClient, err := qdrant.NewClient(&qdrant.Config{
194-
Host: r.sourceHost,
195-
Port: r.sourcePort,
196-
APIKey: r.SourceAPIKey,
197-
UseTLS: r.sourceTLS,
198-
TLSConfig: &tlsConfig,
199-
GrpcOptions: grpcOptions,
200-
})
201-
if err != nil {
202-
return nil, nil, fmt.Errorf("failed to create source client: %w", err)
203-
}
204-
205-
targetClient, err := qdrant.NewClient(&qdrant.Config{
206-
Host: r.targetHost,
207-
Port: r.targetPort,
208-
APIKey: r.TargetAPIKey,
209-
UseTLS: r.targetTLS,
193+
client, err := qdrant.NewClient(&qdrant.Config{
194+
Host: host,
195+
Port: port,
196+
APIKey: apiKey,
197+
UseTLS: useTLS,
210198
TLSConfig: &tlsConfig,
211199
GrpcOptions: grpcOptions,
212200
})
213201
if err != nil {
214-
return nil, nil, fmt.Errorf("failed to create target client: %w", err)
202+
return nil, fmt.Errorf("failed to create client: %w", err)
215203
}
216204

217-
return sourceClient, targetClient, nil
205+
return client, nil
218206
}
219207

220208
func (r *MigrateFromQdrantCmd) perpareTargetCollection(ctx context.Context, sourceClient *qdrant.Client, sourceCollection string, targetClient *qdrant.Client, targetCollection string) (error, *uint64) {

cmd/migrate_from_qdrant_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package cmd
2+
3+
import (
4+
"net/url"
5+
"testing"
6+
)
7+
8+
func Test_getPort(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
url *url.URL
12+
expected int
13+
}{
14+
{
15+
name: "tls enabled, custom port",
16+
url: &url.URL{Scheme: "https", Host: "localhost:6334"},
17+
expected: 6334,
18+
},
19+
{
20+
name: "tls enabled, default port",
21+
url: &url.URL{Scheme: "https", Host: "localhost"},
22+
expected: 443,
23+
},
24+
{
25+
name: "tls disabled, default port",
26+
url: &url.URL{Scheme: "http", Host: "localhost"},
27+
expected: 80,
28+
},
29+
{
30+
name: "tls disabled, custom port",
31+
url: &url.URL{Scheme: "http", Host: "localhost:6334"},
32+
expected: 6334,
33+
},
34+
}
35+
for _, tt := range tests {
36+
t.Run(tt.name, func(t *testing.T) {
37+
got, _ := getPort(tt.url)
38+
if got != tt.expected {
39+
t.Errorf("getPort() got = %v, expected %v", got, tt.expected)
40+
}
41+
})
42+
}
43+
}

0 commit comments

Comments
 (0)