Skip to content

Commit 2cc2c38

Browse files
feat: add oauth2 and ecommerce test apis (#17)
OAuth2 -------- This change adds a new OAuth2 middleware and token endpoint that can be used to enforce OAuth2 on any test endpoints. Test code can opt into OAuth2 enforcement by setting the `x-require-oauth2: 1` header on any SDK methods. This is usually done with a custom http client, per-method custom headers or a SDK-wide hook depending on the language and test. The following OAuth2 flows are supported: - Client Credentials - Authorization Code - Resource Owner Password Credentials - Refresh Token When requesting an access token against `POST /oauth2/token`, the following credentials are expected based on the flow used: - Client ID: `beezy` - Client Secret: `super-secret` - Username: `testuser` - Password: `testpassword` - Code: `secret-auth-code` - Refresh Token: `secret-refresh-token` Test code can look for the `x-oauth2: pass` header to determine if the test service detected and validated OAuth2 access tokens in the response of any non-auth endpoints E-commerce API ----------------- This is a new set of test APIs that are inspired by real world e-commerce APIs. These APIs can enforce OAuth2 scopes if incoming requests have passed the OAuth2 middleware.
1 parent ea86c3a commit 2cc2c38

File tree

7 files changed

+712
-3
lines changed

7 files changed

+712
-3
lines changed

cmd/server/main.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package main
22

33
import (
4+
"context"
45
"flag"
56
"log"
67
"net/http"
78

89
"github.com/speakeasy-api/speakeasy-api-test-service/internal/acceptHeaders"
910
"github.com/speakeasy-api/speakeasy-api-test-service/internal/clientcredentials"
11+
"github.com/speakeasy-api/speakeasy-api-test-service/internal/ecommerce"
1012
"github.com/speakeasy-api/speakeasy-api-test-service/internal/errors"
1113
"github.com/speakeasy-api/speakeasy-api-test-service/internal/eventstreams"
1214
"github.com/speakeasy-api/speakeasy-api-test-service/internal/method"
@@ -28,11 +30,12 @@ func main() {
2830
flag.Parse()
2931

3032
r := mux.NewRouter()
33+
r.HandleFunc("/oauth2/token", auth.HandleOAuth2).Methods(http.MethodPost)
34+
r.HandleFunc("/auth", auth.HandleAuth).Methods(http.MethodPost)
35+
r.HandleFunc("/auth/customsecurity/{customSchemeType}", auth.HandleCustomAuth).Methods(http.MethodGet)
3136
r.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
3237
_, _ = w.Write([]byte("pong"))
3338
}).Methods(http.MethodGet)
34-
r.HandleFunc("/auth", auth.HandleAuth).Methods(http.MethodPost)
35-
r.HandleFunc("/auth/customsecurity/{customSchemeType}", auth.HandleCustomAuth).Methods(http.MethodGet)
3639
r.HandleFunc("/requestbody", requestbody.HandleRequestBody).Methods(http.MethodPost)
3740
r.HandleFunc("/vendorjson", responseHeaders.HandleVendorJsonResponseHeaders).Methods(http.MethodGet)
3841
r.HandleFunc("/pagination/limitoffset/page", pagination.HandleLimitOffsetPage).Methods(http.MethodGet, http.MethodPut)
@@ -69,6 +72,14 @@ func main() {
6972
r.HandleFunc("/method/put", method.HandlePut).Methods(http.MethodPut)
7073
r.HandleFunc("/method/trace", method.HandleTrace).Methods(http.MethodTrace)
7174

75+
oauth2router := r.NewRoute().Subrouter()
76+
oauth2router.Use(middleware.OAuth2)
77+
oauth2router.HandleFunc("/ecommerce/products", ecommerce.HandleListProducts).Methods(http.MethodGet)
78+
oauth2router.HandleFunc("/ecommerce/products", ecommerce.HandleCreateProduct).Methods(http.MethodPost)
79+
oauth2router.HandleFunc("/ecommerce/products/{id}", ecommerce.HandleFetchProduct).Methods(http.MethodGet)
80+
oauth2router.HandleFunc("/ecommerce/products/{id}", ecommerce.HandleDeleteProduct).Methods(http.MethodDelete)
81+
oauth2router.HandleFunc("/ecommerce/products/{id}/inventory", ecommerce.HandleUpdateProductStock).Methods(http.MethodPut)
82+
7283
handler := middleware.Fault(r)
7384
handler = middleware.Teapot(handler)
7485

@@ -77,6 +88,10 @@ func main() {
7788
bind = *bindArg
7889
}
7990

91+
ctx, cancel := context.WithCancel(context.Background())
92+
defer cancel()
93+
go auth.StartTokenDBCompaction(ctx)
94+
8095
log.Printf("Listening on %s\n", bind)
8196
if err := http.ListenAndServe(bind, handler); err != nil {
8297
log.Fatal(err)

go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
module github.com/speakeasy-api/speakeasy-api-test-service
22

3-
go 1.22
3+
go 1.23
44

55
require (
6+
github.com/brianvoe/gofakeit/v7 v7.0.4
7+
github.com/golang-jwt/jwt/v5 v5.2.1
68
github.com/gorilla/mux v1.8.0
79
github.com/lingrino/go-fault v1.0.2
810
)

go.sum

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
1+
github.com/brianvoe/gofakeit/v7 v7.0.4 h1:Mkxwz9jYg8Ad8NvT9HA27pCMZGFQo08MK6jD0QTKEww=
2+
github.com/brianvoe/gofakeit/v7 v7.0.4/go.mod h1:QXuPeBw164PJCzCUZVmgpgHJ3Llj49jSLVkKPMtxtxA=
13
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
24
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
5+
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
6+
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
37
github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI=
48
github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
59
github.com/lingrino/go-fault v1.0.2 h1:I7gj2vsxw0wdOwQIX7AZ7kdZRXPX2AgVvRyFXCqSvLA=

internal/auth/oauth2.go

Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
package auth
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"log"
8+
"net/http"
9+
"sync"
10+
"sync/atomic"
11+
"time"
12+
13+
"github.com/brianvoe/gofakeit/v7"
14+
"github.com/golang-jwt/jwt/v5"
15+
)
16+
17+
const (
18+
oauth2JWTSigningSecret = "fancy-jwt-signing-secret"
19+
)
20+
21+
type OAuth2ErrorCode string
22+
23+
const (
24+
ErrCodeInvalidRequest OAuth2ErrorCode = "invalid_request"
25+
ErrCodeInvalidClient OAuth2ErrorCode = "invalid_client"
26+
ErrCodeInvalidGrant OAuth2ErrorCode = "invalid_grant"
27+
ErrCodeUnauthorizedClient OAuth2ErrorCode = "unauthorized_client"
28+
ErrCodeUnsupportedGrantType OAuth2ErrorCode = "unsupported_grant_type"
29+
ErrCodeInvalidScope OAuth2ErrorCode = "invalid_scope"
30+
)
31+
32+
func (e OAuth2ErrorCode) Error() string {
33+
return string(e)
34+
}
35+
36+
type oauth2Error struct {
37+
Code OAuth2ErrorCode `json:"error"`
38+
Message string `json:"error_description,omitempty"`
39+
}
40+
41+
func SendOAuth2Error(w http.ResponseWriter, code OAuth2ErrorCode, description string) {
42+
payload := oauth2Error{
43+
Code: code,
44+
Message: description,
45+
}
46+
47+
enc := json.NewEncoder(w)
48+
enc.SetIndent("", " ")
49+
w.Header().Set("Content-Type", "application/json")
50+
w.WriteHeader(http.StatusBadRequest)
51+
if err := enc.Encode(payload); err != nil {
52+
log.Println(err)
53+
54+
http.Error(w, `{"error": "failed to encode response"}`, http.StatusInternalServerError)
55+
return
56+
}
57+
}
58+
59+
type TokenForm struct {
60+
GrantType string `json:"grant_type"`
61+
ClientID string `json:"client_id"`
62+
ClientSecret string `json:"client_secret"`
63+
Username string `json:"username"`
64+
Password string `json:"password"`
65+
Code string `json:"code"`
66+
RedirectURI string `json:"redirect_uri"`
67+
RefreshToken string `json:"refresh_token"`
68+
Scope string `json:"scope"`
69+
}
70+
71+
type OAuth2TokenResponse struct {
72+
AccessToken string `json:"access_token"`
73+
RefreshToken string `json:"refresh_token,omitempty"`
74+
TokenType string `json:"token_type"`
75+
ExpiresIn int `json:"expires_in"`
76+
}
77+
78+
func HandleOAuth2(w http.ResponseWriter, r *http.Request) {
79+
enc := json.NewEncoder(w)
80+
enc.SetIndent("", " ")
81+
82+
w.Header().Set("Content-Type", "application/json")
83+
84+
defer r.Body.Close()
85+
86+
if err := r.ParseForm(); err != nil {
87+
log.Println(err)
88+
SendOAuth2Error(w, ErrCodeInvalidRequest, "cannot parse url-encoded request body")
89+
return
90+
}
91+
92+
form := TokenForm{
93+
GrantType: r.PostForm.Get("grant_type"),
94+
ClientID: r.PostForm.Get("client_id"),
95+
ClientSecret: r.PostForm.Get("client_secret"),
96+
Username: r.PostForm.Get("username"),
97+
Password: r.PostForm.Get("password"),
98+
Code: r.PostForm.Get("code"),
99+
RedirectURI: r.PostForm.Get("redirect_uri"),
100+
RefreshToken: r.PostForm.Get("refresh_token"),
101+
Scope: r.PostForm.Get("scope"),
102+
}
103+
104+
switch form.GrantType {
105+
case "client_credentials":
106+
if form.ClientID == "" || form.ClientSecret == "" {
107+
SendOAuth2Error(w, ErrCodeInvalidRequest, "missing client credentials")
108+
return
109+
}
110+
111+
if !validateClientCredentials(r, form) {
112+
SendOAuth2Error(w, ErrCodeInvalidClient, "invalid client id or secret")
113+
return
114+
}
115+
case "password":
116+
if form.ClientID == "" || form.ClientSecret == "" || form.Username == "" || form.Password == "" {
117+
SendOAuth2Error(w, ErrCodeInvalidRequest, "missing resource owner password credentials")
118+
return
119+
}
120+
if !validateClientCredentials(r, form) {
121+
SendOAuth2Error(w, ErrCodeInvalidClient, "invalid client id or secret")
122+
return
123+
}
124+
if form.Username != "testuser" || form.Password != "testpassword" {
125+
SendOAuth2Error(w, ErrCodeInvalidGrant, "invalid username or password")
126+
return
127+
}
128+
case "authorization_code":
129+
if form.ClientID == "" || form.Code == "" {
130+
SendOAuth2Error(w, ErrCodeInvalidRequest, "missing authorization code credentials")
131+
return
132+
}
133+
if form.ClientID != "beezy" {
134+
SendOAuth2Error(w, ErrCodeInvalidClient, "invalid client id")
135+
return
136+
}
137+
if form.Code != "secret-auth-code" {
138+
SendOAuth2Error(w, ErrCodeInvalidGrant, "")
139+
return
140+
}
141+
case "refresh_token":
142+
if form.RefreshToken == "" {
143+
SendOAuth2Error(w, ErrCodeInvalidRequest, "missing refresh token")
144+
return
145+
}
146+
147+
if !validateClientCredentials(r, form) {
148+
SendOAuth2Error(w, ErrCodeInvalidGrant, "invalid client id or secret")
149+
return
150+
}
151+
152+
rt, err := ParseToken(form.RefreshToken)
153+
if err != nil {
154+
SendOAuth2Error(w, ErrCodeInvalidRequest, "invalid refresh token")
155+
return
156+
}
157+
if IsTokenExpired(rt) {
158+
SendOAuth2Error(w, ErrCodeInvalidGrant, "refresh token has expired")
159+
return
160+
}
161+
default:
162+
SendOAuth2Error(w, ErrCodeUnsupportedGrantType, "unsupported grant type")
163+
return
164+
}
165+
166+
now := time.Now()
167+
expires := now.Add(time.Hour)
168+
169+
accessTokenID := gofakeit.UUID()
170+
accessTokenClaims := jwt.MapClaims{
171+
"exp": float64(expires.Unix()),
172+
"id": accessTokenID,
173+
"grantType": form.GrantType,
174+
"clientID": form.ClientID,
175+
"username": form.Username,
176+
"scope": form.Scope,
177+
}
178+
accessToken := jwt.NewWithClaims(jwt.SigningMethodHS256, accessTokenClaims)
179+
signedAccessToken, err := accessToken.SignedString([]byte(oauth2JWTSigningSecret))
180+
if err != nil {
181+
log.Println(err)
182+
SendOAuth2Error(w, ErrCodeInvalidRequest, err.Error())
183+
return
184+
}
185+
186+
refreshTokenClaims := jwt.MapClaims{
187+
"exp": float64(now.Add(60 * 24 * time.Hour).Unix()),
188+
"id": gofakeit.UUID(),
189+
"sub": accessTokenID,
190+
"grantType": "refresh_token",
191+
"clientID": form.ClientID,
192+
"username": form.Username,
193+
"scope": form.Scope,
194+
}
195+
refreshToken := jwt.NewWithClaims(jwt.SigningMethodHS256, refreshTokenClaims)
196+
signedRefreshToken, err := refreshToken.SignedString([]byte(oauth2JWTSigningSecret))
197+
if err != nil {
198+
log.Println(err)
199+
http.Error(w, `{"error": "failed to sign refresh token"}`, http.StatusInternalServerError)
200+
return
201+
}
202+
203+
res := OAuth2TokenResponse{
204+
AccessToken: signedAccessToken,
205+
RefreshToken: signedRefreshToken,
206+
TokenType: "Bearer",
207+
ExpiresIn: max(int(expires.Sub(now).Seconds()), 0),
208+
}
209+
210+
RegisterToken(accessTokenClaims)
211+
RegisterToken(refreshTokenClaims)
212+
213+
if err := enc.Encode(res); err != nil {
214+
http.Error(w, `{"error": "failed to encode response"}`, http.StatusInternalServerError)
215+
return
216+
}
217+
}
218+
219+
func validateClientCredentials(r *http.Request, form TokenForm) bool {
220+
clientID := form.ClientID
221+
clientSecret := form.ClientSecret
222+
if clientID == "" && clientSecret == "" {
223+
clientID, clientSecret, _ = r.BasicAuth()
224+
}
225+
226+
return clientID == "beezy" && clientSecret == "super-secret"
227+
}
228+
229+
var tokenDB sync.Map
230+
var tokenDBLastAccess atomic.Value
231+
232+
func RegisterToken(tokenClaims jwt.MapClaims) {
233+
tokenDBLastAccess.Store(time.Now())
234+
235+
tokenID := tokenClaims["id"].(string)
236+
expiry, err := tokenClaims.GetExpirationTime()
237+
if err != nil {
238+
panic(err)
239+
}
240+
tokenDB.Store(tokenID, expiry.Time)
241+
}
242+
243+
func RefreshToken(refreshClaims jwt.MapClaims) {
244+
tokenDBLastAccess.Store(time.Now())
245+
246+
tokenID := refreshClaims["sub"].(string)
247+
expiry := time.Now().Add(time.Hour)
248+
tokenDB.Store(tokenID, expiry)
249+
}
250+
251+
func IsTokenExpired(tokenClaims jwt.MapClaims) bool {
252+
tokenDBLastAccess.Store(time.Now())
253+
254+
tokenID := tokenClaims["id"].(string)
255+
expiryClaim, err := tokenClaims.GetExpirationTime()
256+
if err != nil {
257+
panic(err)
258+
}
259+
260+
exp, found := tokenDB.Load(tokenID)
261+
if !found {
262+
RegisterToken(tokenClaims)
263+
exp = expiryClaim.Time
264+
}
265+
266+
expiry := exp.(time.Time)
267+
268+
return expiry.Before(time.Now())
269+
}
270+
271+
func StartTokenDBCompaction(ctx context.Context) {
272+
ticker := time.NewTicker(time.Minute)
273+
274+
for {
275+
select {
276+
case <-ctx.Done():
277+
return
278+
case <-ticker.C:
279+
lastAccess, ok := tokenDBLastAccess.Load().(time.Time)
280+
now := time.Now()
281+
if !ok {
282+
lastAccess = now
283+
}
284+
285+
delta := now.Sub(lastAccess)
286+
if delta > 5*time.Minute {
287+
tokenDB.Clear()
288+
}
289+
}
290+
}
291+
}
292+
293+
func ParseToken(encodedToken string) (jwt.MapClaims, error) {
294+
parser := jwt.NewParser(jwt.WithoutClaimsValidation(), jwt.WithValidMethods([]string{"HS256"}))
295+
token, err := parser.Parse(encodedToken, func(token *jwt.Token) (any, error) {
296+
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
297+
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
298+
}
299+
300+
return []byte(oauth2JWTSigningSecret), nil
301+
})
302+
if err != nil {
303+
return nil, err
304+
}
305+
306+
claims, ok := token.Claims.(jwt.MapClaims)
307+
if !ok {
308+
return nil, fmt.Errorf("invalid access token string")
309+
}
310+
311+
return claims, nil
312+
}

0 commit comments

Comments
 (0)