diff --git a/README.md b/README.md index ee81056..bb6baee 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,30 @@ Default: nil The `queryTimeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used. + +#### `roles` + +``` +Type: string +Format: roles=catalog1:role1;catalog2=role2 +Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role. +Default: empty +``` +The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session. + +You can assign roles either as a map of catalog-to-role pairs or a string direcly in the dns connection. + +##### Example +``` go +c := &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, +} + +dsn, err := c.FormatDSN() +``` + #### Examples ``` @@ -259,6 +283,11 @@ http://user@localhost:8080?source=hello&catalog=default&schema=foobar https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2 ``` + +``` +http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2 +``` + ## Data types ### Query arguments diff --git a/trino/integration_test.go b/trino/integration_test.go index a871cab..b32ee40 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -49,6 +49,7 @@ import ( "github.com/golang-jwt/jwt/v5" dt "github.com/ory/dockertest/v3" docker "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/require" ) const ( @@ -1024,6 +1025,77 @@ func TestIntegrationNoResults(t *testing.T) { t.Fatal(err) } } +func TestRoleHeaderSupport(t *testing.T) { + tests := []struct { + name string + config Config + rawDSN string + expectError bool + errorSubstr string + }{ + { + name: "Valid roles via Config", + config: Config{ + ServerURI: *integrationServerFlag, + Roles: map[string]string{"tpch": "role1", "memory": "role2"}, + }, + expectError: false, + }, + { + name: "Valid roles via DSN, not encoded url", + rawDSN: *integrationServerFlag + "?roles=tpch:role1;memory:role2", + expectError: false, + }, + { + name: "Valid roles via DSN, url encoded", + rawDSN: *integrationServerFlag + "?roles%3Dtpch%3Arole1%3Bmemory%3Arole2", + expectError: false, + }, + { + name: "Non-existent catalog role", + config: Config{ + ServerURI: *integrationServerFlag, + Roles: map[string]string{"not-exist-catalog": "role1"}, + }, + expectError: true, + errorSubstr: "USER_ERROR: Catalog", + }, + { + name: "Invalid role format missing ROLE{}", + rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1", + expectError: true, + errorSubstr: "Invalid role format: catolog=role1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var dns string + var err error + + if tt.rawDSN != "" { + dns = tt.rawDSN + } else { + dns, err = tt.config.FormatDSN() + if err != nil { + t.Fatal(err) + } + } + + db := integrationOpen(t, dns) + _, err = db.Query("SELECT 1") + + if tt.expectError { + require.Error(t, err) + if tt.errorSubstr != "" { + require.Contains(t, err.Error(), tt.errorSubstr) + } + } else { + require.NoError(t, err) + } + }) + } +} func TestIntegrationQueryParametersSelect(t *testing.T) { scenarios := []struct { diff --git a/trino/trino.go b/trino/trino.go index 3941d31..46f5ca7 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -126,6 +126,7 @@ const ( trinoSetSessionHeader = trinoHeaderPrefix + `Set-Session` trinoClearSessionHeader = trinoHeaderPrefix + `Clear-Session` trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role` + trinoRoleHeader = trinoHeaderPrefix + `Role` trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential` trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback` @@ -153,6 +154,7 @@ const ( mapKeySeparator = ":" mapEntrySeparator = ";" + mapCommaSeparator = "," ) var ( @@ -194,6 +196,7 @@ type Config struct { AccessToken string // An access token (JWT) for authentication (optional) ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional) QueryTimeout *time.Duration // Configurable timeout for query (optional) + Roles map[string]string // Roles (optional) } // FormatDSN returns a DSN string from the configuration. @@ -214,6 +217,14 @@ func (c *Config) FormatDSN() (string, error) { credkv = append(credkv, k+mapKeySeparator+v) } } + + var roles []string + if c.Roles != nil { + for k, v := range c.Roles { + roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", k, v)) + } + } + source := c.Source if source == "" { source = "trino-go-client" @@ -284,6 +295,7 @@ func (c *Config) FormatDSN() (string, error) { "extra_credentials": strings.Join(credkv, mapEntrySeparator), "custom_client": c.CustomClientName, accessTokenConfig: c.AccessToken, + "roles": strings.Join(roles, mapCommaSeparator), } { if v != "" { query[k] = []string{v} @@ -307,6 +319,7 @@ type Conn struct { useExplicitPrepare bool forwardAuthorizationHeader bool queryTimeout *time.Duration + Roles string } var ( @@ -390,6 +403,26 @@ func newConn(dsn string) (*Conn, error) { queryTimeout = &d } + var formatedRoles string + if rolesStr := query.Get("roles"); rolesStr != "" { + if !strings.Contains(rolesStr, "=ROLE{") { + roles := []string{} + rolesToFormat := strings.Split(rolesStr, ";") + + for _, role := range rolesToFormat { + parts := strings.Split(role, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("Invalid role format: %s", role) + } + roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", parts[0], parts[1])) + } + + formatedRoles = strings.Join(roles, mapCommaSeparator) + } else { + formatedRoles = rolesStr + } + } + c := &Conn{ baseURL: serverURL.Scheme + "://" + serverURL.Host, httpClient: *httpClient, @@ -400,6 +433,7 @@ func newConn(dsn string) (*Conn, error) { useExplicitPrepare: useExplicitPrepare, forwardAuthorizationHeader: forwardAuthorizationHeader, queryTimeout: queryTimeout, + Roles: formatedRoles, } var user string @@ -931,6 +965,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt // Ensure the server returns timestamps preserving their precision, without truncating them to timestamp(3). hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME") + if st.conn.Roles != "" { + hs.Add(trinoRoleHeader, st.conn.Roles) + } + if len(args) > 0 { var ss []string for _, arg := range args { diff --git a/trino/trino_test.go b/trino/trino_test.go index 694571e..c2e4e65 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -203,6 +203,46 @@ func TestKerberosConfig(t *testing.T) { assert.Equal(t, want, dsn) } +func TestFormatDSNWithRoles(t *testing.T) { + tests := []struct { + name string + config *Config + wantDSN string + expectError bool + }{ + { + name: "Multiple catalog roles", + config: &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, + }, + wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client", + }, + { + name: "Single catalog role", + config: &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1"}, + }, + wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dsn, err := tt.config.FormatDSN() + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantDSN, dsn) + } + }) + } +} + func TestInvalidKerberosConfig(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8090", @@ -1098,6 +1138,31 @@ func TestQueryCancellation(t *testing.T) { assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error") } +func TestTrinoRoleHeaderSent(t *testing.T) { + var receivedHeader string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeader = r.Header.Get(trinoRoleHeader) + })) + t.Cleanup(ts.Close) + + c := &Config{ + ServerURI: ts.URL, + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + db, err := sql.Open("trino", dsn) + require.NoError(t, err) + + _, _ = db.Query("SHOW TABLES") + require.NoError(t, err) + + assert.Equal(t, `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`, receivedHeader, "expected X-Trino-Role header to be set") +} + func TestQueryFailure(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError)