Skip to content

Commit 7adc7f5

Browse files
client-credentials-generic
Summary: - Support for configurable `oauth2` client credentials pattern.
1 parent 1f02226 commit 7adc7f5

File tree

7 files changed

+140
-16
lines changed

7 files changed

+140
-16
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ require (
2121
github.com/spf13/cobra v1.4.0
2222
github.com/spf13/pflag v1.0.5
2323
github.com/spf13/viper v1.10.1
24-
github.com/stackql/any-sdk v0.0.3-beta17
24+
github.com/stackql/any-sdk v0.0.3-beta20
2525
github.com/stackql/go-suffix-map v0.0.1-alpha01
2626
github.com/stackql/psql-wire v0.1.1-alpha07
2727
github.com/stackql/stackql-parser v0.0.14-alpha04

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
471471
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
472472
github.com/spf13/viper v1.10.1 h1:nuJZuYpG7gTj/XqiUwg8bA0cp1+M2mC3J4g5luUYBKk=
473473
github.com/spf13/viper v1.10.1/go.mod h1:IGlFPqhNAPKRxohIzWpI5QEy4kuI7tcl5WvR+8qy1rU=
474-
github.com/stackql/any-sdk v0.0.3-beta17 h1:eajJfNOZBvWwaUD9WdsS81PAKHFxyHndmpdOo3P867A=
475-
github.com/stackql/any-sdk v0.0.3-beta17/go.mod h1:CIMFo3fC2ScpqzkzeCkzUQQuzYA1VuqpG0p1EZXN+wY=
474+
github.com/stackql/any-sdk v0.0.3-beta20 h1:7zHdJp0gM9G8vr5IDN/e8H74gqWI2MkWXGrfkQp6irA=
475+
github.com/stackql/any-sdk v0.0.3-beta20/go.mod h1:CIMFo3fC2ScpqzkzeCkzUQQuzYA1VuqpG0p1EZXN+wY=
476476
github.com/stackql/go-suffix-map v0.0.1-alpha01 h1:TDUDS8bySu41Oo9p0eniUeCm43mnRM6zFEd6j6VUaz8=
477477
github.com/stackql/go-suffix-map v0.0.1-alpha01/go.mod h1:QAi+SKukOyf4dBtWy8UMy+hsXXV+yyEE4vmBkji2V7g=
478478
github.com/stackql/psql-wire v0.1.1-alpha07 h1:LQWVUlx4Bougk6dztDNG5tmXxpIVeeTSsInTj801xCs=

internal/stackql/dto/auth_ctx.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dto
33
import (
44
"encoding/base64"
55
"fmt"
6+
"net/url"
67
"os"
78
"strings"
89
)
@@ -42,6 +43,14 @@ type AuthCtx struct {
4243
Active bool `json:"-" yaml:"-"`
4344
Location string `json:"location" yaml:"location"`
4445
Name string `json:"name" yaml:"name"`
46+
TokenURL string `json:"token_url" yaml:"token_url"`
47+
GrantType string `json:"grant_type" yaml:"grant_type"`
48+
ClientID string `json:"client_id" yaml:"client_id"`
49+
ClientSecret string `json:"client_secret" yaml:"client_secret"`
50+
ClientIDEnvVar string `json:"client_id_env_var" yaml:"client_id_env_var"`
51+
ClientSecretEnvVar string `json:"client_secret_env_var" yaml:"client_secret_env_var"`
52+
Values url.Values `json:"values" yaml:"values"`
53+
AuthStyle int `json:"auth_style" yaml:"auth_style"`
4554
}
4655

4756
func (ac *AuthCtx) GetSQLCfg() (SQLBackendCfg, bool) {
@@ -78,10 +87,26 @@ func (ac *AuthCtx) Clone() *AuthCtx {
7887
EncodedBasicCredentials: ac.EncodedBasicCredentials,
7988
Location: ac.Location,
8089
Name: ac.Name,
90+
Subject: ac.Subject,
91+
TokenURL: ac.TokenURL,
92+
GrantType: ac.GrantType,
93+
ClientID: ac.ClientID,
94+
ClientSecret: ac.ClientSecret,
95+
ClientIDEnvVar: ac.ClientIDEnvVar,
96+
ClientSecretEnvVar: ac.ClientSecretEnvVar,
97+
Values: ac.Values,
98+
AuthStyle: ac.AuthStyle,
8199
}
82100
return rv
83101
}
84102

103+
func (ac *AuthCtx) GetValues() url.Values {
104+
if ac.Values == nil {
105+
return url.Values{}
106+
}
107+
return ac.Values
108+
}
109+
85110
func (ac *AuthCtx) GetSuccessor() (*AuthCtx, bool) {
86111
if ac.Successor != nil {
87112
return ac.Successor, true
@@ -188,6 +213,46 @@ func (ac *AuthCtx) GetCredentialsBytes() ([]byte, error) {
188213
return nil, fmt.Errorf("no credentials found")
189214
}
190215

216+
func (ac *AuthCtx) GetClientID() (string, error) {
217+
if ac.ClientIDEnvVar != "" {
218+
rv := os.Getenv(ac.ClientIDEnvVar)
219+
if rv == "" {
220+
return "", fmt.Errorf("client_id_env_var references empty string")
221+
}
222+
return rv, nil
223+
}
224+
if ac.ClientID == "" {
225+
return "", fmt.Errorf("client_id is empty")
226+
}
227+
return ac.ClientID, nil
228+
}
229+
230+
func (ac *AuthCtx) GetClientSecret() (string, error) {
231+
if ac.ClientSecretEnvVar != "" {
232+
rv := os.Getenv(ac.ClientSecretEnvVar)
233+
if rv == "" {
234+
return "", fmt.Errorf("client_secret_env_var references empty string")
235+
}
236+
return rv, nil
237+
}
238+
if ac.ClientSecret == "" {
239+
return "", fmt.Errorf("client_secret is empty")
240+
}
241+
return ac.ClientSecret, nil
242+
}
243+
244+
func (ac *AuthCtx) GetGrantType() string {
245+
return ac.GrantType
246+
}
247+
248+
func (ac *AuthCtx) GetTokenURL() string {
249+
return ac.TokenURL
250+
}
251+
252+
func (ac *AuthCtx) GetAuthStyle() int {
253+
return ac.AuthStyle
254+
}
255+
191256
func (ac *AuthCtx) GetCredentialsSourceDescriptorString() string {
192257
if ac.KeyEnvVar != "" {
193258
return fmt.Sprintf("credentialsenvvar:%s", ac.KeyEnvVar)

internal/stackql/dto/dto.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ const (
1616
APIRequestTimeoutKey string = "apirequesttimeout"
1717
CacheKeyCountKey string = "cachekeycount"
1818
CacheTTLKey string = "metadatattl"
19+
ClientCredentialsStr string = "client_credentials"
1920
ColorSchemeKey string = "colorscheme" // deprecated
2021
ConfigFilePathKey string = "configfile"
2122
CPUProfileKey string = "cpuprofile"
@@ -37,6 +38,7 @@ const (
3738
AllowInsecureKey string = "tls.allowInsecure"
3839
InfilePathKey string = "infile"
3940
LogLevelStrKey string = "loglevel"
41+
OAuth2Str string = "oauth2"
4042
OutfilePathKey string = "outfile"
4143
OutputFormatKey string = "output"
4244
ApplicationFilesRootPathKey string = "approot"

internal/stackql/handler/handler.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,14 @@ func transformOpenapiStackqlAuthToLocal(authDTO anysdk.AuthDTO) *dto.AuthCtx {
575575
EncodedBasicCredentials: authDTO.GetInlineBasicCredentials(),
576576
Location: authDTO.GetLocation(),
577577
Name: authDTO.GetName(),
578+
TokenURL: authDTO.GetTokenURL(),
579+
GrantType: authDTO.GetGrantType(),
580+
ClientID: authDTO.GetClientID(),
581+
ClientSecret: authDTO.GetClientSecret(),
582+
ClientIDEnvVar: authDTO.GetClientIDEnvVar(),
583+
ClientSecretEnvVar: authDTO.GetClientSecretEnvVar(),
584+
Values: authDTO.GetValues(),
585+
AuthStyle: authDTO.GetAuthStyle(),
578586
}
579587
successor, successorExists := authDTO.GetSuccessor()
580588
currentParent := rv

internal/stackql/provider/auth_util.go

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"regexp"
1616

1717
"golang.org/x/oauth2"
18+
"golang.org/x/oauth2/clientcredentials"
1819
"golang.org/x/oauth2/google"
1920
"golang.org/x/oauth2/jwt"
2021
)
@@ -161,11 +162,16 @@ func parseServiceAccountFile(ac *dto.AuthCtx) (serviceAccount, error) {
161162
return c, json.Unmarshal(b, &c)
162163
}
163164

164-
func getJWTConfig(provider string, credentialsBytes []byte, scopes []string, subject string) (*jwt.Config, error) {
165+
func getGoogleJWTConfig(provider string, credentialsBytes []byte, scopes []string, subject string) (*jwt.Config, error) {
165166
switch provider {
166167
case "google", "googleads", "googleanalytics",
167168
"googledevelopers", "googlemybusiness", "googleworkspace",
168169
"youtube", "googleadmin":
170+
if scopes == nil {
171+
scopes = []string{
172+
"https://www.googleapis.com/auth/cloud-platform",
173+
}
174+
}
169175
rv, err := google.JWTConfigFromJSON(credentialsBytes, scopes...)
170176
if err != nil {
171177
return nil, err
@@ -179,7 +185,31 @@ func getJWTConfig(provider string, credentialsBytes []byte, scopes []string, sub
179185
}
180186
}
181187

182-
func oauthServiceAccount(
188+
func getGenericClientCredentialsConfig(authCtx *dto.AuthCtx, scopes []string) (*clientcredentials.Config, error) {
189+
clientID, clientIDErr := authCtx.GetClientID()
190+
if clientIDErr != nil {
191+
return nil, clientIDErr
192+
}
193+
clientSecret, secretErr := authCtx.GetClientSecret()
194+
if secretErr != nil {
195+
return nil, secretErr
196+
}
197+
rv := &clientcredentials.Config{
198+
ClientID: clientID,
199+
ClientSecret: clientSecret,
200+
Scopes: scopes,
201+
TokenURL: authCtx.GetTokenURL(),
202+
}
203+
if len(authCtx.GetValues()) > 0 {
204+
rv.EndpointParams = authCtx.GetValues()
205+
}
206+
if authCtx.GetAuthStyle() > 0 {
207+
rv.AuthStyle = oauth2.AuthStyle(authCtx.GetAuthStyle())
208+
}
209+
return rv, nil
210+
}
211+
212+
func googleOauthServiceAccount(
183213
provider string,
184214
authCtx *dto.AuthCtx,
185215
scopes []string,
@@ -189,14 +219,27 @@ func oauthServiceAccount(
189219
if err != nil {
190220
return nil, fmt.Errorf("service account credentials error: %w", err)
191221
}
192-
config, errToken := getJWTConfig(provider, b, scopes, authCtx.Subject)
222+
config, errToken := getGoogleJWTConfig(provider, b, scopes, authCtx.Subject)
193223
if errToken != nil {
194224
return nil, errToken
195225
}
196226
activateAuth(authCtx, "", dto.AuthServiceAccountStr)
197227
httpClient := netutils.GetHTTPClient(runtimeCtx, http.DefaultClient)
198-
//nolint:staticcheck // TODO: fix this
199-
return config.Client(context.WithValue(oauth2.NoContext, oauth2.HTTPClient, httpClient)), nil
228+
return config.Client(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)), nil
229+
}
230+
231+
func genericOauthClientCredentials(
232+
authCtx *dto.AuthCtx,
233+
scopes []string,
234+
runtimeCtx dto.RuntimeCtx,
235+
) (*http.Client, error) {
236+
config, errToken := getGenericClientCredentialsConfig(authCtx, scopes)
237+
if errToken != nil {
238+
return nil, errToken
239+
}
240+
activateAuth(authCtx, "", dto.ClientCredentialsStr)
241+
httpClient := netutils.GetHTTPClient(runtimeCtx, http.DefaultClient)
242+
return config.Client(context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)), nil
200243
}
201244

202245
func apiTokenAuth(authCtx *dto.AuthCtx, runtimeCtx dto.RuntimeCtx, enforceBearer bool) (*http.Client, error) {

internal/stackql/provider/generic.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ func (gp *GenericProvider) inferAuthType(authCtx dto.AuthCtx, authTypeRequested
8989
return dto.AuthAWSSigningv4Str
9090
case dto.AuthCustomStr:
9191
return dto.AuthCustomStr
92+
case dto.OAuth2Str:
93+
return dto.OAuth2Str
9294
}
9395
if authCtx.KeyFilePath != "" || authCtx.KeyEnvVar != "" {
9496
return dto.AuthServiceAccountStr
@@ -109,7 +111,11 @@ func (gp *GenericProvider) Auth(
109111
case dto.AuthBearerStr:
110112
return gp.apiTokenFileAuth(authCtx, true)
111113
case dto.AuthServiceAccountStr:
112-
return gp.keyFileAuth(authCtx)
114+
return gp.googleKeyFileAuth(authCtx)
115+
case dto.OAuth2Str:
116+
if authCtx.GrantType == dto.ClientCredentialsStr {
117+
return gp.clientCredentialsAuth(authCtx)
118+
}
113119
case dto.AuthBasicStr:
114120
return gp.basicAuth(authCtx)
115121
case dto.AuthCustomStr:
@@ -269,14 +275,14 @@ func (gp *GenericProvider) oAuth(authCtx *dto.AuthCtx, enforceRevokeFirst bool)
269275
return client, nil
270276
}
271277

272-
func (gp *GenericProvider) keyFileAuth(authCtx *dto.AuthCtx) (*http.Client, error) {
278+
func (gp *GenericProvider) googleKeyFileAuth(authCtx *dto.AuthCtx) (*http.Client, error) {
273279
scopes := authCtx.Scopes
274-
if scopes == nil {
275-
scopes = []string{
276-
"https://www.googleapis.com/auth/cloud-platform",
277-
}
278-
}
279-
return oauthServiceAccount(gp.GetProviderString(), authCtx, scopes, gp.runtimeCtx)
280+
return googleOauthServiceAccount(gp.GetProviderString(), authCtx, scopes, gp.runtimeCtx)
281+
}
282+
283+
func (gp *GenericProvider) clientCredentialsAuth(authCtx *dto.AuthCtx) (*http.Client, error) {
284+
scopes := authCtx.Scopes
285+
return genericOauthClientCredentials(authCtx, scopes, gp.runtimeCtx)
280286
}
281287

282288
func (gp *GenericProvider) apiTokenFileAuth(authCtx *dto.AuthCtx, enforceBearer bool) (*http.Client, error) {

0 commit comments

Comments
 (0)