From 081cd851210c2b23115735cf50626ef49d74d48d Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Thu, 15 May 2025 19:42:53 +0100 Subject: [PATCH 1/6] add support for catalog roles --- README.md | 40 +++++++++++++++++++++++++++++++++ trino/trino.go | 25 +++++++++++++++++++++ trino/trino_test.go | 55 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+) diff --git a/README.md b/README.md index ee81056..f6c6edd 100644 --- a/README.md +++ b/README.md @@ -249,6 +249,46 @@ Default: nil The `queryTimeout` parameter sets a timeout for the query. If the query takes longer than the timeout, it will be cancelled. If it is not set the default context timeout will be used. + +#### `roles` + +``` +Type: string +Format: roles=catalog1=role1;catalog2=role2 +Valid values: semicolon-separated list of catalog-to-role assignments +Default: empty +``` +The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session. + +You can assign roles either as a map of catalog-to-role pairs or a string. When a string is used, it applies the role to the `system` catalog by default. + +##### Example + +``` go +c := &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, +} + +dsn, err := c.FormatDSN() +// Result: https://foobar@localhost:8090?roles=catalog1%3Drole1%3Bcatalog2%3Drole2&session_properties=query_priority%3A1 +``` + +**Example using a string (applies to system catalog)** + +``` go +c := &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: "admin", // equivalent to map[string]string{"system": "admin"} +} + +dsn, err := c.FormatDSN() +// Result: https://foobar@localhost:8090?roles=system%3Dadmin&session_properties=query_priority%3A1 +``` + + #### Examples ``` diff --git a/trino/trino.go b/trino/trino.go index 3941d31..b61b829 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -126,6 +126,7 @@ const ( trinoSetSessionHeader = trinoHeaderPrefix + `Set-Session` trinoClearSessionHeader = trinoHeaderPrefix + `Clear-Session` trinoSetRoleHeader = trinoHeaderPrefix + `Set-Role` + trinoRoleHeader = trinoHeaderPrefix + `Role` trinoExtraCredentialHeader = trinoHeaderPrefix + `Extra-Credential` trinoProgressCallbackParam = trinoHeaderPrefix + `Progress-Callback` @@ -153,6 +154,8 @@ const ( mapKeySeparator = ":" mapEntrySeparator = ";" + mapRolesSeparator = "=" + sistemRole = "system" ) var ( @@ -194,6 +197,7 @@ type Config struct { AccessToken string // An access token (JWT) for authentication (optional) ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional) QueryTimeout *time.Duration // Configurable timeout for query (optional) + Roles interface{} // Roles (optional) } // FormatDSN returns a DSN string from the configuration. @@ -214,6 +218,20 @@ func (c *Config) FormatDSN() (string, error) { credkv = append(credkv, k+mapKeySeparator+v) } } + + var roles []string + if c.Roles != nil { + if v, ok := c.Roles.(string); ok { + roles = append(roles, sistemRole+mapRolesSeparator+v) + } else if v, ok := c.Roles.(map[string]string); ok { + for k, v := range v { + roles = append(roles, k+mapRolesSeparator+v) + } + } else { + return "", fmt.Errorf("Invalid roles type %T", c.Roles) + } + } + source := c.Source if source == "" { source = "trino-go-client" @@ -284,6 +302,7 @@ func (c *Config) FormatDSN() (string, error) { "extra_credentials": strings.Join(credkv, mapEntrySeparator), "custom_client": c.CustomClientName, accessTokenConfig: c.AccessToken, + "roles": strings.Join(roles, mapEntrySeparator), } { if v != "" { query[k] = []string{v} @@ -307,6 +326,7 @@ type Conn struct { useExplicitPrepare bool forwardAuthorizationHeader bool queryTimeout *time.Duration + Roles string } var ( @@ -400,6 +420,7 @@ func newConn(dsn string) (*Conn, error) { useExplicitPrepare: useExplicitPrepare, forwardAuthorizationHeader: forwardAuthorizationHeader, queryTimeout: queryTimeout, + Roles: query.Get("roles"), } var user string @@ -931,6 +952,10 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt // Ensure the server returns timestamps preserving their precision, without truncating them to timestamp(3). hs.Add("X-Trino-Client-Capabilities", "PARAMETRIC_DATETIME") + if st.conn.Roles != "" { + hs.Add(trinoRoleHeader, st.conn.Roles) + } + if len(args) > 0 { var ss []string for _, arg := range args { diff --git a/trino/trino_test.go b/trino/trino_test.go index 694571e..b38b37d 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -203,6 +203,36 @@ func TestKerberosConfig(t *testing.T) { assert.Equal(t, want, dsn) } +func TestRolesConfig(t *testing.T) { + c := &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + + want := "https://foobar@localhost:8090?roles=catalog1%3Drole1%3Bcatalog2%3Drole2&session_properties=query_priority%3A1&source=trino-go-client" + + assert.Equal(t, want, dsn) +} + +func TestDefaultRoleConfig(t *testing.T) { + c := &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: "role1", + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + + want := "https://foobar@localhost:8090?roles=system%3Drole1&session_properties=query_priority%3A1&source=trino-go-client" + + assert.Equal(t, want, dsn) +} + func TestInvalidKerberosConfig(t *testing.T) { c := &Config{ ServerURI: "http://foobar@localhost:8090", @@ -1098,6 +1128,31 @@ func TestQueryCancellation(t *testing.T) { assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error") } +func TestTrinoRoleHeaderSent(t *testing.T) { + var receivedHeader string + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeader = r.Header.Get(trinoRoleHeader) + })) + t.Cleanup(ts.Close) + + c := &Config{ + ServerURI: ts.URL, + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, + } + + dsn, err := c.FormatDSN() + require.NoError(t, err) + db, err := sql.Open("trino", dsn) + require.NoError(t, err) + + _, _ = db.Query("SHOW TABLES") + require.NoError(t, err) + + assert.Equal(t, "catalog1=role1;catalog2=role2", receivedHeader, "expected X-Trino-Role header to be set") +} + func TestQueryFailure(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) From 6ce2261380aa88c7d2befa6b937d08397dcd22b4 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Sat, 17 May 2025 16:42:39 +0100 Subject: [PATCH 2/6] wip --- trino/integration_test.go | 61 ++++++++++++++++++++++++++------------- trino/trino.go | 4 +-- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/trino/integration_test.go b/trino/integration_test.go index a871cab..95adec9 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -187,26 +187,26 @@ func TestMain(m *testing.M) { code := m.Run() - if !*noCleanup && pool != nil { - if trinoResource != nil { - if err := pool.Purge(trinoResource); err != nil { - log.Fatalf("Could not purge resource: %s", err) - } - } - - if localStackResource != nil { - if err := pool.Purge(localStackResource); err != nil { - log.Fatalf("Could not purge LocalStack resource: %s", err) - } - } - - networkExists, networkID, err := networkExists(pool, TrinoNetwork) - if err == nil && networkExists { - if err := pool.Client.RemoveNetwork(networkID); err != nil { - log.Fatalf("Could not remove Docker network: %s", err) - } - } - } + // if !*noCleanup && pool != nil { + // if trinoResource != nil { + // if err := pool.Purge(trinoResource); err != nil { + // log.Fatalf("Could not purge resource: %s", err) + // } + // } + + // if localStackResource != nil { + // if err := pool.Purge(localStackResource); err != nil { + // log.Fatalf("Could not purge LocalStack resource: %s", err) + // } + // } + + // networkExists, networkID, err := networkExists(pool, TrinoNetwork) + // if err == nil && networkExists { + // if err := pool.Client.RemoveNetwork(networkID); err != nil { + // log.Fatalf("Could not remove Docker network: %s", err) + // } + // } + // } os.Exit(code) } @@ -1025,6 +1025,27 @@ func TestIntegrationNoResults(t *testing.T) { } } +func TestRoleSupport(t *testing.T) { + config := Config{ + ServerURI: *integrationServerFlag, + Roles: map[string]string{"tpch": "role1"}, + } + + dns, err := config.FormatDSN() + if err != nil { + t.Fatal(err) + } + + db := integrationOpen(t, dns) + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + for rows.Next() { + t.Fatal(errors.New("Rows returned")) + } +} + func TestIntegrationQueryParametersSelect(t *testing.T) { scenarios := []struct { name string diff --git a/trino/trino.go b/trino/trino.go index b61b829..542add5 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -222,10 +222,10 @@ func (c *Config) FormatDSN() (string, error) { var roles []string if c.Roles != nil { if v, ok := c.Roles.(string); ok { - roles = append(roles, sistemRole+mapRolesSeparator+v) + roles = append(roles, sistemRole+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v)) } else if v, ok := c.Roles.(map[string]string); ok { for k, v := range v { - roles = append(roles, k+mapRolesSeparator+v) + roles = append(roles, k+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v)) } } else { return "", fmt.Errorf("Invalid roles type %T", c.Roles) From 66fb9b95e5a9f95848f524a48199620bbf7ebd59 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Sat, 17 May 2025 17:49:58 +0100 Subject: [PATCH 3/6] add integration tests, change x-role header format --- trino/integration_test.go | 90 ++++++++++++++++++++++++++++++++------- trino/trino.go | 3 +- trino/trino_test.go | 60 +++++++++++++++----------- 3 files changed, 111 insertions(+), 42 deletions(-) diff --git a/trino/integration_test.go b/trino/integration_test.go index 95adec9..8aa5f19 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -49,6 +49,7 @@ import ( "github.com/golang-jwt/jwt/v5" dt "github.com/ory/dockertest/v3" docker "github.com/ory/dockertest/v3/docker" + "github.com/stretchr/testify/require" ) const ( @@ -1024,25 +1025,82 @@ func TestIntegrationNoResults(t *testing.T) { t.Fatal(err) } } - -func TestRoleSupport(t *testing.T) { - config := Config{ - ServerURI: *integrationServerFlag, - Roles: map[string]string{"tpch": "role1"}, +func TestRoleHeaderSupport(t *testing.T) { + tests := []struct { + name string + config Config + rawDSN string + expectError bool + errorSubstr string + }{ + { + name: "Valid roles via Config", + config: Config{ + ServerURI: *integrationServerFlag, + Roles: map[string]string{"tpch": "role1", "memory": "role2"}, + }, + expectError: false, + }, + { + name: "Valid single role via DSN", + rawDSN: *integrationServerFlag + "?roles=tpch%3DROLE%7Brole1%7D", + expectError: false, + }, + { + name: "Non-existent catalog role", + config: Config{ + ServerURI: *integrationServerFlag, + Roles: map[string]string{"not-exist-catalog": "role1"}, + }, + expectError: true, + errorSubstr: "USER_ERROR: Catalog 'not-exist-catalog' not found", + }, + { + name: "Invalid role format with colon", + rawDSN: *integrationServerFlag + "?roles=not-exist-catalog%3Arole1", + expectError: true, + errorSubstr: "Invalid X-Trino-Role header", + }, + { + name: "Invalid role format missing ROLE{}", + rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1", + expectError: true, + errorSubstr: "Invalid X-Trino-Role header", + }, + { + name: "Invalid role format missing ROLE{}", + rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1", + expectError: true, + errorSubstr: "Invalid X-Trino-Role header", + }, } - dns, err := config.FormatDSN() - if err != nil { - t.Fatal(err) - } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var dns string + var err error - db := integrationOpen(t, dns) - rows, err := db.Query("SELECT 1") - if err != nil { - t.Fatal(err) - } - for rows.Next() { - t.Fatal(errors.New("Rows returned")) + if tt.rawDSN != "" { + dns = tt.rawDSN + } else { + dns, err = tt.config.FormatDSN() + if err != nil { + t.Fatal(err) + } + } + + db := integrationOpen(t, dns) + _, err = db.Query("SELECT 1") + + if tt.expectError { + require.Error(t, err) + if tt.errorSubstr != "" { + require.Contains(t, err.Error(), tt.errorSubstr) + } + } else { + require.NoError(t, err) + } + }) } } diff --git a/trino/trino.go b/trino/trino.go index 542add5..25e2cd0 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -154,6 +154,7 @@ const ( mapKeySeparator = ":" mapEntrySeparator = ";" + mapCommaSeparator = "," mapRolesSeparator = "=" sistemRole = "system" ) @@ -302,7 +303,7 @@ func (c *Config) FormatDSN() (string, error) { "extra_credentials": strings.Join(credkv, mapEntrySeparator), "custom_client": c.CustomClientName, accessTokenConfig: c.AccessToken, - "roles": strings.Join(roles, mapEntrySeparator), + "roles": strings.Join(roles, mapCommaSeparator), } { if v != "" { query[k] = []string{v} diff --git a/trino/trino_test.go b/trino/trino_test.go index b38b37d..7d6dd4c 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -203,34 +203,44 @@ func TestKerberosConfig(t *testing.T) { assert.Equal(t, want, dsn) } -func TestRolesConfig(t *testing.T) { - c := &Config{ - ServerURI: "https://foobar@localhost:8090", - SessionProperties: map[string]string{"query_priority": "1"}, - Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, +func TestFormatDSNWithRoles(t *testing.T) { + tests := []struct { + name string + config *Config + wantDSN string + expectError bool + }{ + { + name: "Multiple catalog roles", + config: &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"}, + }, + wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client", + }, + { + name: "Default system role as string", + config: &Config{ + ServerURI: "https://foobar@localhost:8090", + SessionProperties: map[string]string{"query_priority": "1"}, + Roles: "role1", + }, + wantDSN: "https://foobar@localhost:8090?roles=system%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client", + }, } - dsn, err := c.FormatDSN() - require.NoError(t, err) - - want := "https://foobar@localhost:8090?roles=catalog1%3Drole1%3Bcatalog2%3Drole2&session_properties=query_priority%3A1&source=trino-go-client" - - assert.Equal(t, want, dsn) -} - -func TestDefaultRoleConfig(t *testing.T) { - c := &Config{ - ServerURI: "https://foobar@localhost:8090", - SessionProperties: map[string]string{"query_priority": "1"}, - Roles: "role1", + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dsn, err := tt.config.FormatDSN() + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantDSN, dsn) + } + }) } - - dsn, err := c.FormatDSN() - require.NoError(t, err) - - want := "https://foobar@localhost:8090?roles=system%3Drole1&session_properties=query_priority%3A1&source=trino-go-client" - - assert.Equal(t, want, dsn) } func TestInvalidKerberosConfig(t *testing.T) { From 6bfe79cff669418b308fb5125f9d73beebd16760 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Sat, 17 May 2025 18:08:26 +0100 Subject: [PATCH 4/6] change readme; fix unit test --- README.md | 8 ++++---- trino/trino_test.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f6c6edd..c6b7b49 100644 --- a/README.md +++ b/README.md @@ -254,8 +254,8 @@ The `queryTimeout` parameter sets a timeout for the query. If the query takes lo ``` Type: string -Format: roles=catalog1=role1;catalog2=role2 -Valid values: semicolon-separated list of catalog-to-role assignments +Format: roles=catalog1=ROLE{role1},catalog2=ROLE{role2} +Valid values: A comma-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role. Default: empty ``` The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session. @@ -272,7 +272,7 @@ c := &Config{ } dsn, err := c.FormatDSN() -// Result: https://foobar@localhost:8090?roles=catalog1%3Drole1%3Bcatalog2%3Drole2&session_properties=query_priority%3A1 +// Result: https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client ``` **Example using a string (applies to system catalog)** @@ -285,7 +285,7 @@ c := &Config{ } dsn, err := c.FormatDSN() -// Result: https://foobar@localhost:8090?roles=system%3Dadmin&session_properties=query_priority%3A1 +// Result: https://foobar@localhost:8090?roles=system%3DROLE%7B%22admin%22%7D&session_properties=query_priority%3A1&source=trino-go-client ``` diff --git a/trino/trino_test.go b/trino/trino_test.go index 7d6dd4c..4fbbf3b 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -1160,7 +1160,7 @@ func TestTrinoRoleHeaderSent(t *testing.T) { _, _ = db.Query("SHOW TABLES") require.NoError(t, err) - assert.Equal(t, "catalog1=role1;catalog2=role2", receivedHeader, "expected X-Trino-Role header to be set") + assert.Equal(t, `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`, receivedHeader, "expected X-Trino-Role header to be set") } func TestQueryFailure(t *testing.T) { From 2c413ad903404290bca1f254ee48293530ac5e85 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Sat, 17 May 2025 18:09:47 +0100 Subject: [PATCH 5/6] uncomment purge dockers --- trino/integration_test.go | 40 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/trino/integration_test.go b/trino/integration_test.go index 8aa5f19..3dde954 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -188,26 +188,26 @@ func TestMain(m *testing.M) { code := m.Run() - // if !*noCleanup && pool != nil { - // if trinoResource != nil { - // if err := pool.Purge(trinoResource); err != nil { - // log.Fatalf("Could not purge resource: %s", err) - // } - // } - - // if localStackResource != nil { - // if err := pool.Purge(localStackResource); err != nil { - // log.Fatalf("Could not purge LocalStack resource: %s", err) - // } - // } - - // networkExists, networkID, err := networkExists(pool, TrinoNetwork) - // if err == nil && networkExists { - // if err := pool.Client.RemoveNetwork(networkID); err != nil { - // log.Fatalf("Could not remove Docker network: %s", err) - // } - // } - // } + if !*noCleanup && pool != nil { + if trinoResource != nil { + if err := pool.Purge(trinoResource); err != nil { + log.Fatalf("Could not purge resource: %s", err) + } + } + + if localStackResource != nil { + if err := pool.Purge(localStackResource); err != nil { + log.Fatalf("Could not purge LocalStack resource: %s", err) + } + } + + networkExists, networkID, err := networkExists(pool, TrinoNetwork) + if err == nil && networkExists { + if err := pool.Client.RemoveNetwork(networkID); err != nil { + log.Fatalf("Could not remove Docker network: %s", err) + } + } + } os.Exit(code) } From 79b3a4c8e3043e99fd821411cf5037ffd1fa3212 Mon Sep 17 00:00:00 2001 From: "joao.folgado" Date: Sun, 18 May 2025 15:01:24 +0100 Subject: [PATCH 6/6] use jdbc sintax on roles query parameter --- README.md | 29 +++++++++-------------------- trino/integration_test.go | 25 +++++++++---------------- trino/trino.go | 36 ++++++++++++++++++++++++------------ trino/trino_test.go | 6 +++--- 4 files changed, 45 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index c6b7b49..bb6baee 100644 --- a/README.md +++ b/README.md @@ -253,17 +253,16 @@ The `queryTimeout` parameter sets a timeout for the query. If the query takes lo #### `roles` ``` -Type: string -Format: roles=catalog1=ROLE{role1},catalog2=ROLE{role2} -Valid values: A comma-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role. +Type: string +Format: roles=catalog1:role1;catalog2=role2 +Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role. Default: empty ``` The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session. -You can assign roles either as a map of catalog-to-role pairs or a string. When a string is used, it applies the role to the `system` catalog by default. +You can assign roles either as a map of catalog-to-role pairs or a string direcly in the dns connection. ##### Example - ``` go c := &Config{ ServerURI: "https://foobar@localhost:8090", @@ -272,23 +271,8 @@ c := &Config{ } dsn, err := c.FormatDSN() -// Result: https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client ``` -**Example using a string (applies to system catalog)** - -``` go -c := &Config{ - ServerURI: "https://foobar@localhost:8090", - SessionProperties: map[string]string{"query_priority": "1"}, - Roles: "admin", // equivalent to map[string]string{"system": "admin"} -} - -dsn, err := c.FormatDSN() -// Result: https://foobar@localhost:8090?roles=system%3DROLE%7B%22admin%22%7D&session_properties=query_priority%3A1&source=trino-go-client -``` - - #### Examples ``` @@ -299,6 +283,11 @@ http://user@localhost:8080?source=hello&catalog=default&schema=foobar https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2 ``` + +``` +http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2 +``` + ## Data types ### Query arguments diff --git a/trino/integration_test.go b/trino/integration_test.go index 3dde954..b32ee40 100644 --- a/trino/integration_test.go +++ b/trino/integration_test.go @@ -1042,8 +1042,13 @@ func TestRoleHeaderSupport(t *testing.T) { expectError: false, }, { - name: "Valid single role via DSN", - rawDSN: *integrationServerFlag + "?roles=tpch%3DROLE%7Brole1%7D", + name: "Valid roles via DSN, not encoded url", + rawDSN: *integrationServerFlag + "?roles=tpch:role1;memory:role2", + expectError: false, + }, + { + name: "Valid roles via DSN, url encoded", + rawDSN: *integrationServerFlag + "?roles%3Dtpch%3Arole1%3Bmemory%3Arole2", expectError: false, }, { @@ -1053,25 +1058,13 @@ func TestRoleHeaderSupport(t *testing.T) { Roles: map[string]string{"not-exist-catalog": "role1"}, }, expectError: true, - errorSubstr: "USER_ERROR: Catalog 'not-exist-catalog' not found", - }, - { - name: "Invalid role format with colon", - rawDSN: *integrationServerFlag + "?roles=not-exist-catalog%3Arole1", - expectError: true, - errorSubstr: "Invalid X-Trino-Role header", - }, - { - name: "Invalid role format missing ROLE{}", - rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1", - expectError: true, - errorSubstr: "Invalid X-Trino-Role header", + errorSubstr: "USER_ERROR: Catalog", }, { name: "Invalid role format missing ROLE{}", rawDSN: *integrationServerFlag + "?roles=catolog%3Drole1", expectError: true, - errorSubstr: "Invalid X-Trino-Role header", + errorSubstr: "Invalid role format: catolog=role1", }, } diff --git a/trino/trino.go b/trino/trino.go index 25e2cd0..46f5ca7 100644 --- a/trino/trino.go +++ b/trino/trino.go @@ -155,8 +155,6 @@ const ( mapKeySeparator = ":" mapEntrySeparator = ";" mapCommaSeparator = "," - mapRolesSeparator = "=" - sistemRole = "system" ) var ( @@ -198,7 +196,7 @@ type Config struct { AccessToken string // An access token (JWT) for authentication (optional) ForwardAuthorizationHeader bool // Allow forwarding the `accessToken` named query parameter in the authorization header, overwriting the `AccessToken` option, if set (optional) QueryTimeout *time.Duration // Configurable timeout for query (optional) - Roles interface{} // Roles (optional) + Roles map[string]string // Roles (optional) } // FormatDSN returns a DSN string from the configuration. @@ -222,14 +220,8 @@ func (c *Config) FormatDSN() (string, error) { var roles []string if c.Roles != nil { - if v, ok := c.Roles.(string); ok { - roles = append(roles, sistemRole+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v)) - } else if v, ok := c.Roles.(map[string]string); ok { - for k, v := range v { - roles = append(roles, k+mapRolesSeparator+fmt.Sprintf("ROLE{%q}", v)) - } - } else { - return "", fmt.Errorf("Invalid roles type %T", c.Roles) + for k, v := range c.Roles { + roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", k, v)) } } @@ -411,6 +403,26 @@ func newConn(dsn string) (*Conn, error) { queryTimeout = &d } + var formatedRoles string + if rolesStr := query.Get("roles"); rolesStr != "" { + if !strings.Contains(rolesStr, "=ROLE{") { + roles := []string{} + rolesToFormat := strings.Split(rolesStr, ";") + + for _, role := range rolesToFormat { + parts := strings.Split(role, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("Invalid role format: %s", role) + } + roles = append(roles, fmt.Sprintf("%s=ROLE{%q}", parts[0], parts[1])) + } + + formatedRoles = strings.Join(roles, mapCommaSeparator) + } else { + formatedRoles = rolesStr + } + } + c := &Conn{ baseURL: serverURL.Scheme + "://" + serverURL.Host, httpClient: *httpClient, @@ -421,7 +433,7 @@ func newConn(dsn string) (*Conn, error) { useExplicitPrepare: useExplicitPrepare, forwardAuthorizationHeader: forwardAuthorizationHeader, queryTimeout: queryTimeout, - Roles: query.Get("roles"), + Roles: formatedRoles, } var user string diff --git a/trino/trino_test.go b/trino/trino_test.go index 4fbbf3b..c2e4e65 100644 --- a/trino/trino_test.go +++ b/trino/trino_test.go @@ -220,13 +220,13 @@ func TestFormatDSNWithRoles(t *testing.T) { wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D%2Ccatalog2%3DROLE%7B%22role2%22%7D&session_properties=query_priority%3A1&source=trino-go-client", }, { - name: "Default system role as string", + name: "Single catalog role", config: &Config{ ServerURI: "https://foobar@localhost:8090", SessionProperties: map[string]string{"query_priority": "1"}, - Roles: "role1", + Roles: map[string]string{"catalog1": "role1"}, }, - wantDSN: "https://foobar@localhost:8090?roles=system%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client", + wantDSN: "https://foobar@localhost:8090?roles=catalog1%3DROLE%7B%22role1%22%7D&session_properties=query_priority%3A1&source=trino-go-client", }, }