Skip to content

Commit 75bf1b4

Browse files
committed
make NewClient take ClientID as param
1 parent 3c7d643 commit 75bf1b4

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

internal/oauthdevice/device_flow.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ import (
1717
)
1818

1919
const (
20-
ClientID = "sgo_cid_sourcegraph-cli"
20+
// DefaultClientID is a predefined Client ID built into Sourcegraph
21+
DefaultClientID = "sgo_cid_sourcegraph-cli"
2122

23+
// wellKnownPath is the path on the sourcegraph server where clients can discover OAuth configuration
2224
wellKnownPath = "/.well-known/openid-configuration"
2325

2426
GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code"
@@ -69,13 +71,15 @@ type Client interface {
6971
}
7072

7173
type httpClient struct {
72-
client *http.Client
74+
clientID string
75+
client *http.Client
7376
// cached OIDC configuration per endpoint
7477
configCache map[string]*OIDCConfiguration
7578
}
7679

77-
func NewClient() Client {
80+
func NewClient(clientID string) Client {
7881
return &httpClient{
82+
clientID: clientID,
7983
client: &http.Client{
8084
Timeout: 30 * time.Second,
8185
},
@@ -152,7 +156,7 @@ func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string
152156
}
153157

154158
data := url.Values{}
155-
data.Set("client_id", ClientID)
159+
data.Set("client_id", DefaultClientID)
156160
if len(scopes) > 0 {
157161
data.Set("scope", strings.Join(scopes, " "))
158162
} else {
@@ -266,7 +270,7 @@ func (e *PollError) Error() string {
266270

267271
func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) {
268272
data := url.Values{}
269-
data.Set("client_id", ClientID)
273+
data.Set("client_id", DefaultClientID)
270274
data.Set("device_code", deviceCode)
271275
data.Set("grant_type", GrantTypeDeviceCode)
272276

internal/oauthdevice/device_flow_test.go

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func TestDiscover_Success(t *testing.T) {
5050
server := newTestServer(t, testServerOptions{})
5151
defer server.Close()
5252

53-
client := NewClient()
53+
client := NewClient(DefaultClientID)
5454
config, err := client.Discover(context.Background(), server.URL)
5555
if err != nil {
5656
t.Fatalf("Discover() error = %v", err)
@@ -78,7 +78,7 @@ func TestDiscover_Caching(t *testing.T) {
7878
})
7979
defer server.Close()
8080

81-
client := NewClient()
81+
client := NewClient(DefaultClientID)
8282

8383
// Populate the cache
8484
_, err := client.Discover(context.Background(), server.URL)
@@ -105,7 +105,7 @@ func TestDiscover_Error(t *testing.T) {
105105
})
106106
defer server.Close()
107107

108-
client := NewClient()
108+
client := NewClient(DefaultClientID)
109109
_, err := client.Discover(context.Background(), server.URL)
110110
if err == nil {
111111
t.Fatal("Discover() expected error, got nil")
@@ -141,8 +141,8 @@ func TestStart_Success(t *testing.T) {
141141
return
142142
}
143143

144-
if got := r.FormValue("client_id"); got != ClientID {
145-
t.Errorf("unexpected client_id: got %q, want %q", got, ClientID)
144+
if got := r.FormValue("client_id"); got != DefaultClientID {
145+
t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID)
146146
}
147147

148148
w.Header().Set("Content-Type", "application/json")
@@ -152,7 +152,7 @@ func TestStart_Success(t *testing.T) {
152152
})
153153
defer server.Close()
154154

155-
client := NewClient()
155+
client := NewClient(DefaultClientID)
156156
resp, err := client.Start(context.Background(), server.URL, nil)
157157
if err != nil {
158158
t.Fatalf("Start() error = %v", err)
@@ -204,7 +204,7 @@ func TestStart_WithScopes(t *testing.T) {
204204
})
205205
defer server.Close()
206206

207-
client := NewClient()
207+
client := NewClient(DefaultClientID)
208208
_, err := client.Start(context.Background(), server.URL, []string{"read", "write"})
209209
if err != nil {
210210
t.Fatalf("Start() error = %v", err)
@@ -230,7 +230,7 @@ func TestStart_Error(t *testing.T) {
230230
})
231231
defer server.Close()
232232

233-
client := NewClient()
233+
client := NewClient(DefaultClientID)
234234
_, err := client.Start(context.Background(), server.URL, nil)
235235
if err == nil {
236236
t.Fatal("Start() expected error, got nil")
@@ -253,7 +253,7 @@ func TestStart_NoDeviceEndpoint(t *testing.T) {
253253
})
254254
defer server.Close()
255255

256-
client := NewClient()
256+
client := NewClient(DefaultClientID)
257257
_, err := client.Start(context.Background(), server.URL, nil)
258258
if err == nil {
259259
t.Fatal("Start() expected error, got nil")
@@ -287,8 +287,8 @@ func TestPoll_Success(t *testing.T) {
287287
return
288288
}
289289

290-
if got := r.FormValue("client_id"); got != ClientID {
291-
t.Errorf("unexpected client_id: got %q, want %q", got, ClientID)
290+
if got := r.FormValue("client_id"); got != DefaultClientID {
291+
t.Errorf("unexpected client_id: got %q, want %q", got, DefaultClientID)
292292
}
293293
if got := r.FormValue("grant_type"); got != GrantTypeDeviceCode {
294294
t.Errorf("unexpected grant_type: got %q", got)
@@ -301,7 +301,7 @@ func TestPoll_Success(t *testing.T) {
301301
})
302302
defer server.Close()
303303

304-
client := NewClient().(*httpClient)
304+
client := NewClient(DefaultClientID).(*httpClient)
305305
resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60)
306306
if err != nil {
307307
t.Fatalf("Poll() error = %v", err)
@@ -343,7 +343,7 @@ func TestPoll_AuthorizationPending(t *testing.T) {
343343
})
344344
defer server.Close()
345345

346-
client := NewClient().(*httpClient)
346+
client := NewClient(DefaultClientID).(*httpClient)
347347
resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60)
348348
if err != nil {
349349
t.Fatalf("Poll() error = %v", err)
@@ -385,7 +385,7 @@ func TestPoll_SlowDown(t *testing.T) {
385385
})
386386
defer server.Close()
387387

388-
client := NewClient().(*httpClient)
388+
client := NewClient(DefaultClientID).(*httpClient)
389389
resp, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60)
390390
if err != nil {
391391
t.Fatalf("Poll() error = %v", err)
@@ -415,7 +415,7 @@ func TestPoll_ExpiredToken(t *testing.T) {
415415
})
416416
defer server.Close()
417417

418-
client := NewClient().(*httpClient)
418+
client := NewClient(DefaultClientID).(*httpClient)
419419
_, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60)
420420
if err == nil {
421421
t.Fatal("Poll() expected error, got nil")
@@ -442,7 +442,7 @@ func TestPoll_AccessDenied(t *testing.T) {
442442
})
443443
defer server.Close()
444444

445-
client := NewClient().(*httpClient)
445+
client := NewClient(DefaultClientID).(*httpClient)
446446
_, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 60)
447447
if err == nil {
448448
t.Fatal("Poll() expected error, got nil")
@@ -468,7 +468,7 @@ func TestPoll_Timeout(t *testing.T) {
468468
})
469469
defer server.Close()
470470

471-
client := NewClient().(*httpClient)
471+
client := NewClient(DefaultClientID).(*httpClient)
472472
_, err := client.Poll(context.Background(), server.URL, "test-device-code", 10*time.Millisecond, 0)
473473
if err == nil {
474474
t.Fatal("Poll() expected error, got nil")
@@ -497,7 +497,7 @@ func TestPoll_ContextCancellation(t *testing.T) {
497497
ctx, cancel := context.WithCancel(context.Background())
498498
cancel()
499499

500-
client := NewClient().(*httpClient)
500+
client := NewClient(DefaultClientID).(*httpClient)
501501
_, err := client.Poll(ctx, server.URL, "test-device-code", 10*time.Millisecond, 3600)
502502
if err == nil {
503503
t.Fatal("Poll() expected error, got nil")

0 commit comments

Comments
 (0)