Skip to content

Commit 5eef69f

Browse files
committed
Add roles support
1 parent eaa2f66 commit 5eef69f

File tree

3 files changed

+111
-38
lines changed

3 files changed

+111
-38
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,11 @@ c := &Config{
312312
dsn, err := c.FormatDSN()
313313
```
314314

315+
**Query parameter example (overrides DSN roles):**
316+
```go
317+
rows, err := db.Query(query, sql.Named("X-Trino-Role", "catalog1:role1;catalog2:role2;catalog3:ALL"))
318+
```
319+
315320
#### Examples
316321

317322
```

trino/trino.go

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,14 @@ func ParseDSN(dsn string) (*Config, error) {
261261
}
262262
}
263263

264+
if roles := query.Get("roles"); roles != "" {
265+
var err error
266+
config.Roles, err = parseMapParameter(roles, "role", mapEntrySeparator, mapKeySeparator)
267+
if err != nil {
268+
return nil, err
269+
}
270+
}
271+
264272
if clientTags := query.Get("clientTags"); clientTags != "" {
265273
config.ClientTags = strings.Split(clientTags, commaSeparator)
266274
}
@@ -469,6 +477,39 @@ var (
469477
_ driver.ConnPrepareContext = &Conn{}
470478
)
471479

480+
// formatRolesFromMap formats roles from a map into the Trino header format
481+
func formatRolesFromMap(rolesMap map[string]string) string {
482+
var formattedRoles []string
483+
for catalog, role := range rolesMap {
484+
formattedRoles = append(formattedRoles, formatRoleEntry(catalog, role))
485+
}
486+
sort.Strings(formattedRoles)
487+
return strings.Join(formattedRoles, commaSeparator)
488+
}
489+
490+
// parseAndFormatRoles parses roles from DSN format (catalog1:roleA;catalog2:roleB)
491+
// and formats them into the Trino header format
492+
func parseAndFormatRoles(rolesString string) (string, error) {
493+
var formattedRoles []string
494+
for _, entry := range strings.Split(rolesString, mapEntrySeparator) {
495+
parts := strings.SplitN(entry, mapKeySeparator, 2)
496+
if len(parts) != 2 {
497+
return "", fmt.Errorf("invalid role entry: %q", entry)
498+
}
499+
formattedRoles = append(formattedRoles, formatRoleEntry(parts[0], parts[1]))
500+
}
501+
sort.Strings(formattedRoles)
502+
return strings.Join(formattedRoles, commaSeparator), nil
503+
}
504+
505+
// formatRoleEntry formats a single catalog role entry into Trino header format
506+
func formatRoleEntry(catalog, role string) string {
507+
if role == "ALL" || role == "NONE" {
508+
return fmt.Sprintf("%s=%s", catalog, role)
509+
}
510+
return fmt.Sprintf("%s=ROLE{%q}", catalog, role)
511+
}
512+
472513
func newConn(dsn string) (*Conn, error) {
473514
conf, err := ParseDSN(dsn)
474515
if err != nil {
@@ -555,26 +596,11 @@ func newConn(dsn string) (*Conn, error) {
555596
c.httpHeaders.Add(trinoTagsHeader, strings.Join(tags, commaSeparator))
556597
}
557598

558-
var formatedRoles []string
559-
if rolesStr := query.Get("roles"); rolesStr != "" {
560-
splitRoles := strings.Split(rolesStr, mapEntrySeparator)
561-
if len(splitRoles) == 0 {
562-
return nil, fmt.Errorf("trino: roles cannot be empty")
599+
if conf.Roles != nil {
600+
rolesHeader := formatRolesFromMap(conf.Roles)
601+
if rolesHeader != "" {
602+
c.httpHeaders.Add(trinoRoleHeader, rolesHeader)
563603
}
564-
for _, role := range splitRoles {
565-
splitRole := strings.Split(role, ":")
566-
if len(splitRole) != 2 {
567-
return nil, fmt.Errorf("trino: invalid role format: %q", role)
568-
}
569-
if splitRole[1] == "ALL" || splitRole[1] == "NONE" {
570-
formatedRoles = append(formatedRoles, fmt.Sprintf("%s=%s", splitRole[0], splitRole[1]))
571-
continue
572-
}
573-
574-
formatedRoles = append(formatedRoles, fmt.Sprintf("%s=ROLE{%q}", splitRole[0], splitRole[1]))
575-
}
576-
577-
c.httpHeaders.Add(trinoRoleHeader, strings.Join(formatedRoles, commaSeparator))
578604
}
579605

580606
for k, v := range map[string]string{
@@ -1221,6 +1247,15 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
12211247
st.user = headerValue
12221248
}
12231249

1250+
if arg.Name == trinoRoleHeader {
1251+
formattedRoles, err := parseAndFormatRoles(headerValue)
1252+
if err != nil {
1253+
return nil, err
1254+
}
1255+
st.conn.httpHeaders.Set(trinoRoleHeader, formattedRoles)
1256+
headerValue = formattedRoles
1257+
}
1258+
12241259
hs.Add(arg.Name, headerValue)
12251260
} else {
12261261
if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" {

trino/trino_test.go

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ func TestParseDSNToConfig(t *testing.T) {
8383
DisableExplicitPrepare: true,
8484
ForwardAuthorizationHeader: true,
8585
QueryTimeout: &[]time.Duration{5 * time.Minute}[0],
86+
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
8687
},
8788
},
8889
{
@@ -158,7 +159,8 @@ func TestParseDSNToConfigAllFieldsHandled(t *testing.T) {
158159
"accessToken=jwt-token-here&" +
159160
"explicitPrepare=false&" +
160161
"forwardAuthorizationHeader=true&" +
161-
"query_timeout=5m30s"
162+
"query_timeout=5m30s&" +
163+
"roles=catalog1%3Arole1%3Bcatalog2%3Arole2"
162164

163165
config, err := ParseDSN(complexDSN)
164166
require.NoError(t, err)
@@ -207,6 +209,7 @@ func TestParseDSNToConfigAllFieldsHandled(t *testing.T) {
207209
assert.Equal(t, true, config.ForwardAuthorizationHeader)
208210
assert.NotNil(t, config.QueryTimeout)
209211
assert.Equal(t, 5*time.Minute+30*time.Second, *config.QueryTimeout)
212+
assert.Equal(t, map[string]string{"catalog1": "role1", "catalog2": "role2"}, config.Roles)
210213
}
211214

212215
func TestConfigFormatDSNTags(t *testing.T) {
@@ -1344,29 +1347,59 @@ func TestQueryCancellation(t *testing.T) {
13441347
assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error")
13451348
}
13461349

1347-
func TestTrinoRoleHeaderSent(t *testing.T) {
1348-
var receivedHeader string
1350+
func TestTrinoRoleHeader(t *testing.T) {
1351+
tests := []struct {
1352+
name string
1353+
roles map[string]string
1354+
namedArg string
1355+
expectedHeader string
1356+
}{
1357+
{
1358+
name: "Roles from config",
1359+
roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
1360+
namedArg: "",
1361+
expectedHeader: `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`,
1362+
},
1363+
{
1364+
name: "Override roles with named argument",
1365+
roles: map[string]string{"catalog1": "role1"},
1366+
namedArg: `catalog3:role3;catalog4:role4;catalog5:ALL`,
1367+
expectedHeader: `catalog3=ROLE{"role3"},catalog4=ROLE{"role4"},catalog5=ALL`,
1368+
},
1369+
}
13491370

1350-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1351-
receivedHeader = r.Header.Get(trinoRoleHeader)
1352-
}))
1353-
t.Cleanup(ts.Close)
1371+
for _, tt := range tests {
1372+
t.Run(tt.name, func(t *testing.T) {
1373+
var receivedHeader string
1374+
var serverURL string
1375+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1376+
receivedHeader = r.Header.Get(trinoRoleHeader)
1377+
w.Header().Set("Content-Type", "application/json")
1378+
w.WriteHeader(http.StatusOK)
1379+
_, _ = w.Write([]byte(`{"id":"1","nextUri":"` + serverURL + `/1"}`))
1380+
}))
1381+
serverURL = ts.URL
1382+
t.Cleanup(ts.Close)
13541383

1355-
c := &Config{
1356-
ServerURI: ts.URL,
1357-
SessionProperties: map[string]string{"query_priority": "1"},
1358-
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
1359-
}
1384+
c := &Config{
1385+
ServerURI: ts.URL,
1386+
Roles: tt.roles,
1387+
}
13601388

1361-
dsn, err := c.FormatDSN()
1362-
require.NoError(t, err)
1363-
db, err := sql.Open("trino", dsn)
1364-
require.NoError(t, err)
1389+
dsn, err := c.FormatDSN()
1390+
require.NoError(t, err)
1391+
db, err := sql.Open("trino", dsn)
1392+
require.NoError(t, err)
13651393

1366-
_, _ = db.Query("SHOW TABLES")
1367-
require.NoError(t, err)
1394+
if tt.namedArg != "" {
1395+
_, _ = db.Query("SELECT 1", sql.Named("X-Trino-Role", tt.namedArg))
1396+
} else {
1397+
_, _ = db.Query("SELECT 1")
1398+
}
13681399

1369-
assert.Equal(t, `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`, receivedHeader, "expected X-Trino-Role header to be set")
1400+
assert.Equal(t, tt.expectedHeader, receivedHeader, "expected X-Trino-Role header to match")
1401+
})
1402+
}
13701403
}
13711404

13721405
func TestQueryFailure(t *testing.T) {

0 commit comments

Comments
 (0)