-
Notifications
You must be signed in to change notification settings - Fork 73
Add support for catalog roles #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
081cd85
6ce2261
66fb9b9
6bfe79c
2c413ad
79b3a4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -126,6 +126,7 @@ const ( | |
trinoSetSessionHeader = trinoHeaderPrefix + `Set-Session` | ||
trinoClearSessionHeader = trinoHeaderPrefix + `Clear-Session` | ||
trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role` | ||
trinoRoleHeader = trinoHeaderPrefix + `Role` | ||
trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential` | ||
|
||
trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback` | ||
|
@@ -153,6 +154,7 @@ const ( | |
|
||
mapKeySeparator = ":" | ||
mapEntrySeparator = ";" | ||
mapCommaSeparator = "," | ||
) | ||
|
||
var ( | ||
|
@@ -194,6 +196,7 @@ type Config struct { | |
AccessToken string // An access token (JWT) for authentication (optional) | ||
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional) | ||
QueryTimeout *time.Duration // Configurable timeout for query (optional) | ||
Roles map[string]string // Roles (optional) | ||
} | ||
|
||
// FormatDSN returns a DSN string from the configuration. | ||
|
@@ -214,6 +217,14 @@ func (c *Config) FormatDSN() (string, error) { | |
credkv = append(credkv, k+mapKeySeparator+v) | ||
} | ||
} | ||
|
||
var roles []string | ||
if c.Roles != nil { | ||
for k, v := range c.Roles { | ||
roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", k, v)) | ||
} | ||
} | ||
|
||
source := c.Source | ||
if source == "" { | ||
source = "trino-go-client" | ||
|
@@ -284,6 +295,7 @@ func (c *Config) FormatDSN() (string, error) { | |
"extra_credentials": strings.Join(credkv, mapEntrySeparator), | ||
"custom_client": c.CustomClientName, | ||
accessTokenConfig: c.AccessToken, | ||
"roles": strings.Join(roles, mapCommaSeparator), | ||
} { | ||
if v != "" { | ||
query[k] = []string{v} | ||
|
@@ -307,6 +319,7 @@ type Conn struct { | |
useExplicitPrepare bool | ||
forwardAuthorizationHeader bool | ||
queryTimeout *time.Duration | ||
Roles string | ||
} | ||
|
||
var ( | ||
|
@@ -390,6 +403,26 @@ func newConn(dsn string) (*Conn, error) { | |
queryTimeout = &d | ||
} | ||
|
||
var formatedRoles string | ||
if rolesStr := query.Get("roles"); rolesStr != "" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would we allow overriding roles in query parameters, instead of setting them for the whole connection? I think this should be only possible using SQL statements, and we actually should handle the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, if we want to support it via DNS and use the JDBC format, we’ll need to find a way to convert it to the correct format. So, are you suggesting it should only be sent as an SQL statement? I’ll make the changes. 👍🏻 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I got confused with the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AAAh I know what I have done wrong ! So here I am validation if is not on the right format already. I will change it, I will convert it always to convert to jdbc format and then when it reach to this point I will convert it to the format that trino is waiting for: |
||
if !strings.Contains(rolesStr, "=ROLE{") { | ||
roles := []string{} | ||
rolesToFormat := strings.Split(rolesStr, ";") | ||
|
||
for _, role := range rolesToFormat { | ||
parts := strings.Split(role, ":") | ||
if len(parts) != 2 { | ||
return nil, fmt.Errorf("Invalid role format: %s", role) | ||
} | ||
roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", parts[0], parts[1])) | ||
} | ||
|
||
formatedRoles = strings.Join(roles, mapCommaSeparator) | ||
} else { | ||
formatedRoles = rolesStr | ||
} | ||
} | ||
|
||
c := &Conn{ | ||
baseURL: serverURL.Scheme + "://" + serverURL.Host, | ||
httpClient: *httpClient, | ||
|
@@ -400,6 +433,7 @@ func newConn(dsn string) (*Conn, error) { | |
useExplicitPrepare: useExplicitPrepare, | ||
forwardAuthorizationHeader: forwardAuthorizationHeader, | ||
queryTimeout: queryTimeout, | ||
Roles: formatedRoles, | ||
} | ||
|
||
var user string | ||
|
@@ -931,6 +965,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt | |
// Ensure the server returns timestamps preserving their precision, without truncating them to timestamp(3). | ||
hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME") | ||
|
||
if st.conn.Roles != "" { | ||
hs.Add(trinoRoleHeader, st.conn.Roles) | ||
} | ||
|
||
if len(args) > 0 { | ||
var ss []string | ||
for _, arg := range args { | ||
|
Uh oh!
There was an error while loading. Please reload this page.