Skip to content

Commit 15f6ff4

Browse files
gkeesh7gauravkuber
andauthored
Add unit tests for cluster_client.go Locations function (#539)
* Add unit tests for cluster_client.go Locations function - Add table-driven tests for blobclient.Locations() - Test cases: empty cluster, single node, multiple nodes, failover, all fail - Use t.Cleanup() for automatic test server cleanup - Add mock hostlist.List implementation for testing * Simplify slice comparison: replace reflect.DeepEqual with require.Equal --------- Co-authored-by: gauravk <gauravk@uber.com>
1 parent 812f744 commit 15f6ff4

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package blobclient_test
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/uber/kraken/core"
12+
"github.com/uber/kraken/lib/hostlist"
13+
"github.com/uber/kraken/origin/blobclient"
14+
"github.com/uber/kraken/utils/stringset"
15+
)
16+
17+
type mockList struct {
18+
addrs []string
19+
}
20+
21+
func (m mockList) Resolve() stringset.Set {
22+
return stringset.New(m.addrs...)
23+
}
24+
25+
func getMockList(addrs ...string) hostlist.List {
26+
return mockList{addrs}
27+
}
28+
29+
// stripHTTPPrefix removes the http:// prefix from a URL.
30+
func stripHTTPPrefix(url string) string {
31+
return strings.TrimPrefix(url, "http://")
32+
}
33+
34+
// testClusterServer creates a test HTTP server and returns the stripped address.
35+
func testClusterServer(t *testing.T, handler http.HandlerFunc) string {
36+
t.Helper()
37+
server := httptest.NewServer(handler)
38+
t.Cleanup(server.Close)
39+
return stripHTTPPrefix(server.URL)
40+
}
41+
42+
func TestClusterLocations(t *testing.T) {
43+
tests := []struct {
44+
name string
45+
setupServers func(t *testing.T) []string // returns server addresses for cluster
46+
want []string
47+
wantErr bool
48+
errContains string
49+
}{
50+
{
51+
name: "empty cluster",
52+
setupServers: func(t *testing.T) []string {
53+
return []string{} // no servers
54+
},
55+
wantErr: true,
56+
errContains: "cluster is empty",
57+
},
58+
{
59+
name: "single node cluster returns locations",
60+
setupServers: func(t *testing.T) []string {
61+
addr := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
62+
w.Header().Set("Origin-Locations", "origin1:8080,origin2:8080")
63+
w.WriteHeader(http.StatusOK)
64+
})
65+
return []string{addr}
66+
},
67+
want: []string{"origin1:8080", "origin2:8080"},
68+
wantErr: false,
69+
},
70+
{
71+
name: "multiple nodes - first succeeds",
72+
setupServers: func(t *testing.T) []string {
73+
addr1 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
74+
w.Header().Set("Origin-Locations", "origin1:8080")
75+
w.WriteHeader(http.StatusOK)
76+
})
77+
addr2 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
78+
w.WriteHeader(http.StatusInternalServerError)
79+
})
80+
return []string{addr1, addr2}
81+
},
82+
want: []string{"origin1:8080"},
83+
wantErr: false,
84+
},
85+
{
86+
name: "first node fails - second succeeds",
87+
setupServers: func(t *testing.T) []string {
88+
addr1 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
89+
w.WriteHeader(http.StatusInternalServerError)
90+
})
91+
addr2 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
92+
w.Header().Set("Origin-Locations", "origin2:8080")
93+
w.WriteHeader(http.StatusOK)
94+
})
95+
return []string{addr1, addr2}
96+
},
97+
want: []string{"origin2:8080"},
98+
wantErr: false,
99+
},
100+
{
101+
name: "all nodes fail",
102+
setupServers: func(t *testing.T) []string {
103+
addr1 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
104+
w.WriteHeader(http.StatusInternalServerError)
105+
})
106+
addr2 := testClusterServer(t, func(w http.ResponseWriter, r *http.Request) {
107+
w.WriteHeader(http.StatusInternalServerError)
108+
})
109+
return []string{addr1, addr2}
110+
},
111+
wantErr: true,
112+
},
113+
}
114+
115+
for _, tt := range tests {
116+
t.Run(tt.name, func(t *testing.T) {
117+
require := require.New(t)
118+
119+
addrs := tt.setupServers(t)
120+
p := blobclient.NewProvider()
121+
cluster := getMockList(addrs...)
122+
d := core.DigestFixture()
123+
124+
got, err := blobclient.Locations(p, cluster, d)
125+
126+
if tt.wantErr {
127+
require.Error(err)
128+
if tt.errContains != "" {
129+
require.Contains(err.Error(), tt.errContains)
130+
}
131+
return
132+
}
133+
134+
require.NoError(err)
135+
require.Equal(tt.want, got)
136+
})
137+
}
138+
}

0 commit comments

Comments
 (0)