Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,30 @@ Default: nil

The `queryTimeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used.


#### `roles`

```
Type: string
Format: roles=catalog1:role1;catalog2=role2
Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role.
Default: empty
```
The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session.

You can assign roles either as a map of catalog-to-role pairs or a string direcly in the dns connection.

##### Example
``` go
c := &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
}

dsn, err := c.FormatDSN()
```

#### Examples

```
Expand All @@ -259,6 +283,11 @@ http://user@localhost:8080?source=hello&catalog=default&schema=foobar
https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2
```


```
http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2
```

## Data types

### Query arguments
Expand Down
72 changes: 72 additions & 0 deletions trino/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/golang-jwt/jwt/v5"
dt "github.com/ory/dockertest/v3"
docker "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/require"
)

const (
Expand Down Expand Up @@ -1024,6 +1025,77 @@ func TestIntegrationNoResults(t *testing.T) {
t.Fatal(err)
}
}
func TestRoleHeaderSupport(t *testing.T) {
tests := []struct {
name string
config Config
rawDSN string
expectError bool
errorSubstr string
}{
{
name: "Valid roles via Config",
config: Config{
ServerURI: *integrationServerFlag,
Roles: map[string]string{"tpch": "role1", "memory": "role2"},
},
expectError: false,
},
{
name: "Valid roles via DSN, not encoded url",
rawDSN: *integrationServerFlag + "?roles=tpch:role1;memory:role2",
expectError: false,
},
{
name: "Valid roles via DSN, url encoded",
rawDSN: *integrationServerFlag + "?roles%3Dtpch%3Arole1%3Bmemory%3Arole2",
expectError: false,
},
{
name: "Non-existent catalog role",
config: Config{
ServerURI: *integrationServerFlag,
Roles: map[string]string{"not-exist-catalog": "role1"},
},
expectError: true,
errorSubstr: "USER_ERROR: Catalog",
},
{
name: "Invalid role format missing ROLE{}",
rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1",
expectError: true,
errorSubstr: "Invalid role format: catolog=role1",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var dns string
var err error

if tt.rawDSN != "" {
dns = tt.rawDSN
} else {
dns, err = tt.config.FormatDSN()
if err != nil {
t.Fatal(err)
}
}

db := integrationOpen(t, dns)
_, err = db.Query("SELECT 1")

if tt.expectError {
require.Error(t, err)
if tt.errorSubstr != "" {
require.Contains(t, err.Error(), tt.errorSubstr)
}
} else {
require.NoError(t, err)
}
})
}
}

func TestIntegrationQueryParametersSelect(t *testing.T) {
scenarios := []struct {
Expand Down
38 changes: 38 additions & 0 deletions trino/trino.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ const (
trinoSetSessionHeader = trinoHeaderPrefix + `Set-Session`
trinoClearSessionHeader = trinoHeaderPrefix + `Clear-Session`
trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role`
trinoRoleHeader = trinoHeaderPrefix + `Role`
trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential`

trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback`
Expand Down Expand Up @@ -153,6 +154,7 @@ const (

mapKeySeparator = ":"
mapEntrySeparator = ";"
mapCommaSeparator = ","
)

var (
Expand Down Expand Up @@ -194,6 +196,7 @@ type Config struct {
AccessToken string // An access token (JWT) for authentication (optional)
ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional)
QueryTimeout *time.Duration // Configurable timeout for query (optional)
Roles map[string]string // Roles (optional)
}

// FormatDSN returns a DSN string from the configuration.
Expand All @@ -214,6 +217,14 @@ func (c *Config) FormatDSN() (string, error) {
credkv = append(credkv, k+mapKeySeparator+v)
}
}

var roles []string
if c.Roles != nil {
for k, v := range c.Roles {
roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", k, v))
}
}

source := c.Source
if source == "" {
source = "trino-go-client"
Expand Down Expand Up @@ -284,6 +295,7 @@ func (c *Config) FormatDSN() (string, error) {
"extra_credentials": strings.Join(credkv, mapEntrySeparator),
"custom_client": c.CustomClientName,
accessTokenConfig: c.AccessToken,
"roles": strings.Join(roles, mapCommaSeparator),
} {
if v != "" {
query[k] = []string{v}
Expand All @@ -307,6 +319,7 @@ type Conn struct {
useExplicitPrepare bool
forwardAuthorizationHeader bool
queryTimeout *time.Duration
Roles string
}

var (
Expand Down Expand Up @@ -390,6 +403,26 @@ func newConn(dsn string) (*Conn, error) {
queryTimeout = &d
}

var formatedRoles string
if rolesStr := query.Get("roles"); rolesStr != "" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would we allow overriding roles in query parameters, instead of setting them for the whole connection? I think this should be only possible using SQL statements, and we actually should handle the X-Set-Role header coming back from the server (and remove it from unsupportedResponseHeaders).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if we want to support it via DNS and use the JDBC format, we’ll need to find a way to convert it to the correct format. So, are you suggesting it should only be sent as an SQL statement? I’ll make the changes. 👍🏻
I also going to take a look to X-Set-Role thank you !

Copy link
Member

@nineinchnick nineinchnick May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I got confused with the query variable - I thought it's for the user's SQL query, not the query part of the DSN URL. The check for =ROLE{ is also weird, this would mean we're adding support for some undocumented legacy format.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AAAh I know what I have done wrong !
When I use the config and call the formatDNS function to setup the dns I am putting the right format already! And is not suppose

So here I am validation if is not on the right format already.

I will change it, I will convert it always to convert to jdbc format and then when it reach to this point I will convert it to the format that trino is waiting for: catalog=ROLE{role1},catalog=ROLE{rol2},

if !strings.Contains(rolesStr, "=ROLE{") {
roles := []string{}
rolesToFormat := strings.Split(rolesStr, ";")

for _, role := range rolesToFormat {
parts := strings.Split(role, ":")
if len(parts) != 2 {
return nil, fmt.Errorf("Invalid role format: %s", role)
}
roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", parts[0], parts[1]))
}

formatedRoles = strings.Join(roles, mapCommaSeparator)
} else {
formatedRoles = rolesStr
}
}

c := &Conn{
baseURL: serverURL.Scheme + "://" + serverURL.Host,
httpClient: *httpClient,
Expand All @@ -400,6 +433,7 @@ func newConn(dsn string) (*Conn, error) {
useExplicitPrepare: useExplicitPrepare,
forwardAuthorizationHeader: forwardAuthorizationHeader,
queryTimeout: queryTimeout,
Roles: formatedRoles,
}

var user string
Expand Down Expand Up @@ -931,6 +965,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
// Ensure the server returns timestamps preserving their precision, without truncating them to timestamp(3).
hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME")

if st.conn.Roles != "" {
hs.Add(trinoRoleHeader, st.conn.Roles)
}

if len(args) > 0 {
var ss []string
for _, arg := range args {
Expand Down
65 changes: 65 additions & 0 deletions trino/trino_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,46 @@ func TestKerberosConfig(t *testing.T) {
assert.Equal(t, want, dsn)
}

func TestFormatDSNWithRoles(t *testing.T) {
tests := []struct {
name string
config *Config
wantDSN string
expectError bool
}{
{
name: "Multiple catalog roles",
config: &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
},
wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client",
},
{
name: "Single catalog role",
config: &Config{
ServerURI: "https://foobar@localhost:8090",
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1"},
},
wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dsn, err := tt.config.FormatDSN()
if tt.expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.wantDSN, dsn)
}
})
}
}

func TestInvalidKerberosConfig(t *testing.T) {
c := &Config{
ServerURI: "http://foobar@localhost:8090",
Expand Down Expand Up @@ -1098,6 +1138,31 @@ func TestQueryCancellation(t *testing.T) {
assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error")
}

func TestTrinoRoleHeaderSent(t *testing.T) {
var receivedHeader string

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeader = r.Header.Get(trinoRoleHeader)
}))
t.Cleanup(ts.Close)

c := &Config{
ServerURI: ts.URL,
SessionProperties: map[string]string{"query_priority": "1"},
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
}

dsn, err := c.FormatDSN()
require.NoError(t, err)
db, err := sql.Open("trino", dsn)
require.NoError(t, err)

_, _ = db.Query("SHOW TABLES")
require.NoError(t, err)

assert.Equal(t, `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`, receivedHeader, "expected X-Trino-Role header to be set")
}

func TestQueryFailure(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
Expand Down
Loading