Skip to content

Commit e33fe91

Browse files
committed
validate keyring oauth creds and re-auth if needed
1 parent 9120eea commit e33fe91

3 files changed

Lines changed: 150 additions & 15 deletions

File tree

cmd/src/login_oauth.go

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,40 @@ import (
1515
"github.com/sourcegraph/src-cli/internal/oauth"
1616
)
1717

18-
var loadStoredOAuthToken = oauth.LoadToken
18+
var (
19+
loadStoredOAuthToken = oauth.LoadToken
20+
storeOAuthToken = oauth.StoreToken
21+
)
1922

2023
func runOAuthLogin(ctx context.Context, p loginParams) error {
21-
client, err := oauthLoginClient(ctx, p)
24+
client, loadedFromStore, err := oauthLoginClient(ctx, p)
2225
if err != nil {
2326
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
2427
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
2528
return cmderrors.ExitCode1
2629
}
2730

31+
if loadedFromStore {
32+
username, validateErr := currentUsername(ctx, client)
33+
if validateErr == nil && username != "" {
34+
printAuthenticatedUser(p.out, username, p.cfg.endpointURL)
35+
fmt.Fprintln(p.out)
36+
fmt.Fprint(p.out, "✔︎ Authenticated with OAuth credentials")
37+
fmt.Fprintln(p.out)
38+
return nil
39+
}
40+
41+
fmt.Fprintln(p.out)
42+
fmt.Fprintln(p.out, "⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.")
43+
44+
client, err = newOAuthLoginClient(ctx, p)
45+
if err != nil {
46+
printLoginProblem(p.out, fmt.Sprintf("OAuth Device flow authentication failed: %s", err))
47+
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
48+
return cmderrors.ExitCode1
49+
}
50+
}
51+
2852
if err := validateCurrentUser(ctx, client, p.out, p.cfg.endpointURL); err != nil {
2953
return err
3054
}
@@ -38,18 +62,23 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
3862
// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
3963
// and use it if one is present.
4064
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
41-
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
42-
// if we have a stored token, used it. Otherwise run the device flow
65+
func oauthLoginClient(ctx context.Context, p loginParams) (api.Client, bool, error) {
66+
// if we have a stored token, use it. Otherwise run the device flow
4367
if token, err := loadStoredOAuthToken(ctx, p.cfg.endpointURL); err == nil {
44-
return newOAuthAPIClient(p, token), nil
68+
return newOAuthAPIClient(p, token), true, nil
4569
}
4670

71+
client, err := newOAuthLoginClient(ctx, p)
72+
return client, false, err
73+
}
74+
75+
func newOAuthLoginClient(ctx context.Context, p loginParams) (api.Client, error) {
4776
token, err := runOAuthDeviceFlow(ctx, p.cfg.endpointURL, p.out, p.oauthClient)
4877
if err != nil {
4978
return nil, err
5079
}
5180

52-
if err := oauth.StoreToken(ctx, token); err != nil {
81+
if err := storeOAuthToken(ctx, token); err != nil {
5382
fmt.Fprintln(p.out)
5483
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
5584
}

cmd/src/login_test.go

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,80 @@ func TestLogin(t *testing.T) {
137137
t.Errorf("got output %q, want %q", gotOut, wantOut)
138138
}
139139
})
140+
141+
t.Run("invalid stored oauth token restarts device flow", func(t *testing.T) {
142+
var authHeaders []string
143+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
144+
authHeaders = append(authHeaders, r.Header.Get("Authorization"))
145+
if r.Header.Get("Authorization") != "Bearer new-oauth-token" {
146+
http.Error(w, "", http.StatusUnauthorized)
147+
return
148+
}
149+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
150+
}))
151+
defer s.Close()
152+
153+
restoreStoredOAuthLoader(t, func(_ context.Context, _ *url.URL) (*oauth.Token, error) {
154+
return &oauth.Token{
155+
Endpoint: s.URL,
156+
ClientID: oauth.DefaultClientID,
157+
AccessToken: "old-oauth-token",
158+
ExpiresAt: time.Now().Add(time.Hour),
159+
}, nil
160+
})
161+
restoreOAuthTokenStore(t, func(context.Context, *oauth.Token) error { return nil })
162+
163+
u, _ := url.ParseRequestURI(s.URL)
164+
startCalled := false
165+
pollCalled := false
166+
var out bytes.Buffer
167+
err := loginCmd(context.Background(), loginParams{
168+
cfg: &config{endpointURL: u},
169+
client: (&config{endpointURL: u}).apiClient(nil, io.Discard),
170+
out: &out,
171+
oauthClient: fakeOAuthClient{
172+
startCalled: &startCalled,
173+
deviceResp: &oauth.DeviceAuthResponse{
174+
DeviceCode: "device-code",
175+
ExpiresIn: 60,
176+
},
177+
pollCalled: &pollCalled,
178+
pollResp: &oauth.TokenResponse{
179+
AccessToken: "new-oauth-token",
180+
ExpiresIn: 3600,
181+
TokenType: "Bearer",
182+
},
183+
},
184+
})
185+
if err != nil {
186+
t.Fatal(err)
187+
}
188+
if !startCalled || !pollCalled {
189+
t.Fatal("expected invalid stored oauth token to restart device flow")
190+
}
191+
if len(authHeaders) != 2 || authHeaders[0] != "Bearer old-oauth-token" || authHeaders[1] != "Bearer new-oauth-token" {
192+
t.Fatalf("Authorization headers = %q, want old token then new token", authHeaders)
193+
}
194+
gotOut := out.String()
195+
for _, want := range []string{
196+
"⚠️ Warning: Stored OAuth credentials could not be verified. Starting a new OAuth device flow.",
197+
"Waiting for authorization... DONE",
198+
"✔︎ Authenticated as alice on " + s.URL,
199+
"✔︎ Authenticated with OAuth credentials",
200+
} {
201+
if !strings.Contains(gotOut, want) {
202+
t.Errorf("got output %q, want it to contain %q", gotOut, want)
203+
}
204+
}
205+
})
140206
}
141207

142208
type fakeOAuthClient struct {
143209
startErr error
144210
startCalled *bool
211+
deviceResp *oauth.DeviceAuthResponse
212+
pollCalled *bool
213+
pollResp *oauth.TokenResponse
145214
}
146215

147216
func (f fakeOAuthClient) ClientID() string {
@@ -156,10 +225,22 @@ func (f fakeOAuthClient) Start(context.Context, *url.URL, []string) (*oauth.Devi
156225
if f.startCalled != nil {
157226
*f.startCalled = true
158227
}
159-
return nil, f.startErr
228+
if f.startErr != nil {
229+
return nil, f.startErr
230+
}
231+
if f.deviceResp != nil {
232+
return f.deviceResp, nil
233+
}
234+
return nil, fmt.Errorf("unexpected call to Start")
160235
}
161236

162237
func (f fakeOAuthClient) Poll(context.Context, *url.URL, string, time.Duration, int) (*oauth.TokenResponse, error) {
238+
if f.pollCalled != nil {
239+
*f.pollCalled = true
240+
}
241+
if f.pollResp != nil {
242+
return f.pollResp, nil
243+
}
163244
return nil, fmt.Errorf("unexpected call to Poll")
164245
}
165246

@@ -242,3 +323,13 @@ func restoreStoredOAuthLoader(t *testing.T, loader func(context.Context, *url.UR
242323
loadStoredOAuthToken = prev
243324
})
244325
}
326+
327+
func restoreOAuthTokenStore(t *testing.T, store func(context.Context, *oauth.Token) error) {
328+
t.Helper()
329+
330+
prev := storeOAuthToken
331+
storeOAuthToken = store
332+
t.Cleanup(func() {
333+
storeOAuthToken = prev
334+
})
335+
}

cmd/src/login_validate.go

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,8 @@ func runValidatedLogin(ctx context.Context, p loginParams) error {
1616
}
1717

1818
func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer, endpointURL *url.URL) error {
19-
query := `query CurrentUser { currentUser { username } }`
20-
var result struct {
21-
CurrentUser *struct{ Username string }
22-
}
23-
if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil {
19+
username, err := currentUsername(ctx, client)
20+
if err != nil {
2421
if strings.HasPrefix(err.Error(), "error: 401 Unauthorized") || strings.HasPrefix(err.Error(), "error: 403 Forbidden") {
2522
printLoginProblem(out, "Invalid access token.")
2623
} else {
@@ -31,14 +28,32 @@ func validateCurrentUser(ctx context.Context, client api.Client, out io.Writer,
3128
return cmderrors.ExitCode1
3229
}
3330

34-
if result.CurrentUser == nil {
31+
if username == "" {
3532
// This should never happen; we verified there is an access token, so there should always be
3633
// a user.
3734
printLoginProblem(out, fmt.Sprintf("Unable to determine user on %s.", endpointURL))
3835
return cmderrors.ExitCode1
3936
}
37+
printAuthenticatedUser(out, username, endpointURL)
38+
return nil
39+
}
40+
41+
func printAuthenticatedUser(out io.Writer, username string, endpointURL *url.URL) {
4042
fmt.Fprintln(out)
41-
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", result.CurrentUser.Username, endpointURL)
43+
fmt.Fprintf(out, "✔︎ Authenticated as %s on %s\n", username, endpointURL)
4244
fmt.Fprintln(out)
43-
return nil
45+
}
46+
47+
func currentUsername(ctx context.Context, client api.Client) (string, error) {
48+
query := `query CurrentUser { currentUser { username } }`
49+
var result struct {
50+
CurrentUser *struct{ Username string }
51+
}
52+
if _, err := client.NewRequest(query, nil).Do(ctx, &result); err != nil {
53+
return "", err
54+
}
55+
if result.CurrentUser == nil {
56+
return "", nil
57+
}
58+
return result.CurrentUser.Username, nil
4459
}

0 commit comments

Comments
 (0)