Skip to content

Commit bb43d20

Browse files
committed
spelling
1 parent 9f9d240 commit bb43d20

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
// Package oauthdevice implements the OAuth 2.0 Device Authorization Grant (RFC 8628)
2+
// for authenticating with Sourcegraph instances.
3+
package oauthdevice
4+
5+
import (
6+
"context"
7+
"encoding/json"
8+
"fmt"
9+
"io"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
"testing"
14+
"time"
15+
16+
"github.com/sourcegraph/sourcegraph/lib/errors"
17+
)
18+
19+
const (
20+
ClientID = "sgo_cid_sourcegraph-cli"
21+
22+
wellKnownPath = "/.well-known/openid-configuration"
23+
24+
GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code"
25+
26+
ScopeOpenID string = "openid"
27+
ScopeProfile string = "profile"
28+
ScopeEmail string = "email"
29+
ScopeOfflineAccess string = "offline_access"
30+
ScopeUserAll string = "user:all"
31+
)
32+
33+
var defaultScopes = []string{ScopeEmail, ScopeOfflineAccess, ScopeOpenID, ScopeProfile, ScopeUserAll}
34+
35+
// OIDCConfiguration represents the relevant fields from the OpenID Connect
36+
// Discovery document at /.well-known/openid-configuration
37+
type OIDCConfiguration struct {
38+
Issuer string `json:"issuer,omitempty"`
39+
TokenEndpoint string `json:"token_endpoint,omitempty"`
40+
DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint,omitempty"`
41+
}
42+
43+
type DeviceAuthResponse struct {
44+
DeviceCode string `json:"device_code"`
45+
UserCode string `json:"user_code"`
46+
VerificationURI string `json:"verification_uri"`
47+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
48+
ExpiresIn int `json:"expires_in"`
49+
Interval int `json:"interval"`
50+
}
51+
52+
type TokenResponse struct {
53+
AccessToken string `json:"access_token"`
54+
TokenType string `json:"token_type"`
55+
ExpiresIn int `json:"expires_in,omitempty"`
56+
Scope string `json:"scope,omitempty"`
57+
}
58+
59+
type ErrorResponse struct {
60+
Error string `json:"error"`
61+
ErrorDescription string `json:"error_description,omitempty"`
62+
}
63+
64+
type Client interface {
65+
Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error)
66+
Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error)
67+
Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error)
68+
}
69+
70+
type httpClient struct {
71+
client *http.Client
72+
// cached OIDC configuration per endpoint
73+
configCache map[string]*OIDCConfiguration
74+
}
75+
76+
func NewClient() Client {
77+
return &httpClient{
78+
client: &http.Client{
79+
Timeout: 30 * time.Second,
80+
},
81+
configCache: make(map[string]*OIDCConfiguration),
82+
}
83+
}
84+
85+
func NewClientWithHTTPClient(c *http.Client) Client {
86+
return &httpClient{
87+
client: c,
88+
configCache: make(map[string]*OIDCConfiguration),
89+
}
90+
}
91+
92+
// Discover fetches the openid-configuration which contains all the routes a client should
93+
// use for authorization, device flows, tokens etc.
94+
//
95+
// Before making any requests, the configCache is checked and if there is a cache hit, the
96+
// cached config is returned.
97+
func (c *httpClient) Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error) {
98+
endpoint = strings.TrimRight(endpoint, "/")
99+
100+
if config, ok := c.configCache[endpoint]; ok {
101+
return config, nil
102+
}
103+
104+
reqURL := endpoint + wellKnownPath
105+
106+
req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
107+
if err != nil {
108+
return nil, errors.Wrap(err, "creating discovery request")
109+
}
110+
req.Header.Set("Accept", "application/json")
111+
112+
resp, err := c.client.Do(req)
113+
if err != nil {
114+
return nil, errors.Wrap(err, "discovery request failed")
115+
}
116+
defer resp.Body.Close()
117+
118+
body, err := io.ReadAll(resp.Body)
119+
if err != nil {
120+
return nil, errors.Wrap(err, "reading discovery response")
121+
}
122+
123+
if resp.StatusCode != http.StatusOK {
124+
return nil, errors.Newf("discovery failed with status %d: %s", resp.StatusCode, string(body))
125+
}
126+
127+
var config OIDCConfiguration
128+
if err := json.Unmarshal(body, &config); err != nil {
129+
return nil, errors.Wrap(err, "parsing discovery response")
130+
}
131+
132+
c.configCache[endpoint] = &config
133+
134+
return &config, nil
135+
}
136+
137+
// Start starts the OAuth device flow with the given endpoint. If no scopes are given the default scopes are used.
138+
//
139+
// Default Scopes: "openid" "profile" "email" "offline_access" "user:all"
140+
func (c *httpClient) Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error) {
141+
endpoint = strings.TrimRight(endpoint, "/")
142+
143+
// Discover OIDC configuration
144+
config, err := c.Discover(ctx, endpoint)
145+
if err != nil {
146+
return nil, errors.Wrap(err, "OIDC discovery failed")
147+
}
148+
149+
if config.DeviceAuthorizationEndpoint == "" {
150+
return nil, errors.New("device authorization endpoint not found in OIDC configuration; the server may not support device flow")
151+
}
152+
153+
data := url.Values{}
154+
data.Set("client_id", ClientID)
155+
if len(scopes) > 0 {
156+
data.Set("scope", strings.Join(scopes, " "))
157+
} else {
158+
data.Set("scope", strings.Join(defaultScopes, " "))
159+
}
160+
161+
req, err := http.NewRequestWithContext(ctx, "POST", config.DeviceAuthorizationEndpoint, strings.NewReader(data.Encode()))
162+
if err != nil {
163+
return nil, errors.Wrap(err, "creating device auth request")
164+
}
165+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
166+
req.Header.Set("Accept", "application/json")
167+
168+
resp, err := c.client.Do(req)
169+
if err != nil {
170+
return nil, errors.Wrap(err, "device auth request failed")
171+
}
172+
defer resp.Body.Close()
173+
174+
body, err := io.ReadAll(resp.Body)
175+
if err != nil {
176+
return nil, errors.Wrap(err, "reading device auth response")
177+
}
178+
179+
if resp.StatusCode != http.StatusOK {
180+
var errResp ErrorResponse
181+
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
182+
return nil, errors.Newf("device auth failed: %s: %s", errResp.Error, errResp.ErrorDescription)
183+
}
184+
return nil, errors.Newf("device auth failed with status %d: %s", resp.StatusCode, string(body))
185+
}
186+
187+
var authResp DeviceAuthResponse
188+
if err := json.Unmarshal(body, &authResp); err != nil {
189+
return nil, errors.Wrap(err, "parsing device auth response")
190+
}
191+
192+
return &authResp, nil
193+
}
194+
195+
// Poll polls the OAuth token endpoint until the device has been authorized or not
196+
//
197+
// We poll as long as the authorization is pending. If the server tells us to slow down, we will wait 5 secs extra.
198+
//
199+
// Polling will stop when:
200+
// - Device is authorized, and a token is returned
201+
// - Device code has expried
202+
// - User denied authorization
203+
func (c *httpClient) Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error) {
204+
endpoint = strings.TrimRight(endpoint, "/")
205+
206+
// Discover OIDC configuration (should be cached from Start)
207+
config, err := c.Discover(ctx, endpoint)
208+
if err != nil {
209+
return nil, errors.Wrap(err, "OIDC discovery failed")
210+
}
211+
212+
if config.TokenEndpoint == "" {
213+
return nil, errors.New("token endpoint not found in OIDC configuration")
214+
}
215+
216+
deadline := time.Now().Add(time.Duration(expiresIn) * time.Second)
217+
218+
for {
219+
if time.Now().After(deadline) {
220+
return nil, errors.New("device code expired")
221+
}
222+
223+
if !testing.Testing() {
224+
select {
225+
case <-ctx.Done():
226+
return nil, ctx.Err()
227+
case <-time.After(interval):
228+
}
229+
}
230+
231+
tokenResp, err := c.pollOnce(ctx, config.TokenEndpoint, deviceCode)
232+
if err != nil {
233+
var pollErr *PollError
234+
if errors.As(err, &pollErr) {
235+
switch pollErr.Code {
236+
case "authorization_pending":
237+
continue
238+
case "slow_down":
239+
interval += 5 * time.Second
240+
continue
241+
case "expired_token":
242+
return nil, errors.New("device code expired")
243+
case "access_denied":
244+
return nil, errors.New("authorization was denied by the user")
245+
}
246+
}
247+
return nil, err
248+
}
249+
250+
return tokenResp, nil
251+
}
252+
}
253+
254+
type PollError struct {
255+
Code string
256+
Description string
257+
}
258+
259+
func (e *PollError) Error() string {
260+
if e.Description != "" {
261+
return fmt.Sprintf("%s: %s", e.Code, e.Description)
262+
}
263+
return e.Code
264+
}
265+
266+
func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode string) (*TokenResponse, error) {
267+
data := url.Values{}
268+
data.Set("client_id", ClientID)
269+
data.Set("device_code", deviceCode)
270+
data.Set("grant_type", GrantTypeDeviceCode)
271+
272+
req, err := http.NewRequestWithContext(ctx, "POST", tokenEndpoint, strings.NewReader(data.Encode()))
273+
if err != nil {
274+
return nil, errors.Wrap(err, "creating token request")
275+
}
276+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
277+
req.Header.Set("Accept", "application/json")
278+
279+
resp, err := c.client.Do(req)
280+
if err != nil {
281+
return nil, errors.Wrap(err, "token request failed")
282+
}
283+
defer resp.Body.Close()
284+
285+
body, err := io.ReadAll(resp.Body)
286+
if err != nil {
287+
return nil, errors.Wrap(err, "reading token response")
288+
}
289+
290+
if resp.StatusCode != http.StatusOK {
291+
var errResp ErrorResponse
292+
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != "" {
293+
return nil, &PollError{Code: errResp.Error, Description: errResp.ErrorDescription}
294+
}
295+
return nil, errors.Newf("token request failed with status %d: %s", resp.StatusCode, string(body))
296+
}
297+
298+
var tokenResp TokenResponse
299+
if err := json.Unmarshal(body, &tokenResp); err != nil {
300+
return nil, errors.Wrap(err, "parsing token response")
301+
}
302+
303+
return &tokenResp, nil
304+
}

0 commit comments

Comments
 (0)