Skip to content

Commit e89b499

Browse files
committed
Add support for catalog roles
1 parent 31e5182 commit e89b499

File tree

6 files changed

+492
-10
lines changed

6 files changed

+492
-10
lines changed

README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,40 @@ dsn, err := config.FormatDSN()
289289
```go
290290
rows, err := db.Query(query, sql.Named("X-Trino-Client-Tags", "tag1,tag2,tag3"))
291291
```
292+
=======
293+
294+
#### `roles`
295+
296+
```
297+
Type: string
298+
Format: roles=catalog1:role1;catalog2=role2
299+
Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role.
300+
Default: empty
301+
```
302+
The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session.
303+
304+
##### Example
305+
``` go
306+
c := &Config{
307+
ServerURI: "https://foobar@localhost:8090",
308+
SessionProperties: map[string]string{"query_priority": "1"},
309+
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
310+
}
311+
312+
dsn, err := c.FormatDSN()
313+
```
314+
315+
**Query parameter example (overrides DSN roles):**
316+
```go
317+
rows, err := db.Query(
318+
query,
319+
sql.Named("X-Trino-Role", map[string]string{
320+
"catalog1": "role1",
321+
"catalog2": "role2",
322+
}),
323+
)
324+
```
325+
292326
#### Examples
293327

294328
```
@@ -299,6 +333,11 @@ http://user@localhost:8080?source=hello&catalog=default&schema=foobar
299333
https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2
300334
```
301335

336+
337+
```
338+
http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2
339+
```
340+
302341
## Data types
303342

304343
### Query arguments
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
connector.name=hive
2+
hive.metastore=file
3+
hive.metastore.catalog.dir=/tmp/metastore
4+
hive.security=sql-standard

trino/etc/catalog/hive.properties

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
connector.name=hive
2+
hive.metastore=file
3+
hive.metastore.catalog.dir=/tmp/metastore
4+
hive.security=sql-standard
5+
fs.hadoop.enabled=true

trino/integration_test.go

Lines changed: 189 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import (
5050
"github.com/golang-jwt/jwt/v5"
5151
dt "github.com/ory/dockertest/v3"
5252
docker "github.com/ory/dockertest/v3/docker"
53+
"github.com/stretchr/testify/require"
5354
)
5455

5556
const (
@@ -143,8 +144,12 @@ func TestMain(m *testing.M) {
143144
wd + "/etc/password-authenticator.properties:/etc/trino/password-authenticator.properties",
144145
}
145146

147+
version, err := strconv.Atoi(*trinoImageTagFlag)
148+
if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 458) {
149+
mounts = append(mounts, wd+"/etc/catalog/hive-pre-458version.properties:/etc/trino/catalog/hive.properties")
150+
}
151+
146152
if spoolingProtocolSupported {
147-
version, err := strconv.Atoi(*trinoImageTagFlag)
148153
if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 477) {
149154
mounts = append(mounts, wd+"/etc/config-pre-477version.properties:/etc/trino/config.properties")
150155
} else {
@@ -182,6 +187,11 @@ func TestMain(m *testing.M) {
182187

183188
waitForContainerHealth(trinoResource.Container.ID, "trino")
184189

190+
err = grantAdminRoleToTestUser()
191+
if err != nil {
192+
log.Fatalf("Warning: Failed to grant admin role to test user: %s", err)
193+
}
194+
185195
*integrationServerFlag = "http://test@localhost:" + trinoResource.GetPort("8080/tcp")
186196
tlsServer = "https://admin:admin@localhost:" + trinoResource.GetPort("8443/tcp")
187197

@@ -217,6 +227,35 @@ func TestMain(m *testing.M) {
217227
os.Exit(code)
218228
}
219229

230+
func grantAdminRoleToTestUser() error {
231+
grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
232+
233+
execCmd := []string{
234+
"trino",
235+
"--user", "admin",
236+
"--execute", grantSQL,
237+
}
238+
exec, err := pool.Client.CreateExec(docker.CreateExecOptions{
239+
Container: trinoResource.Container.ID,
240+
Cmd: execCmd,
241+
})
242+
if err != nil {
243+
log.Printf("Warning: Failed to create exec for GRANT: %s", err)
244+
} else {
245+
var stdout, stderr bytes.Buffer
246+
err = pool.Client.StartExec(exec.ID, docker.StartExecOptions{
247+
Detach: false,
248+
OutputStream: &stdout,
249+
ErrorStream: &stderr,
250+
})
251+
if err != nil {
252+
log.Printf("Warning: Failed to execute GRANT: %s", err)
253+
}
254+
}
255+
256+
return err
257+
}
258+
220259
func getOrCreateLocalStack(pool *dt.Pool, networkID string) *dt.Resource {
221260
resource, ok := pool.ContainerByName(DockerLocalStackName)
222261
if ok {
@@ -1030,6 +1069,155 @@ func TestIntegrationNoResults(t *testing.T) {
10301069
t.Fatal(err)
10311070
}
10321071
}
1072+
func TestRoleHeaderSupport(t *testing.T) {
1073+
tests := []struct {
1074+
name string
1075+
config Config
1076+
rawDSN string
1077+
query string
1078+
expectError bool
1079+
errorSubstr string
1080+
validateRows func(t *testing.T, rows *sql.Rows)
1081+
}{
1082+
{
1083+
name: "Valid hive admin role via Config",
1084+
config: Config{
1085+
ServerURI: *integrationServerFlag,
1086+
Roles: map[string]string{"hive": "admin"},
1087+
},
1088+
query: "SHOW ROLES FROM hive",
1089+
expectError: false,
1090+
validateRows: func(t *testing.T, rows *sql.Rows) {
1091+
foundAdmin := false
1092+
for rows.Next() {
1093+
var roleName string
1094+
err := rows.Scan(&roleName)
1095+
require.NoError(t, err)
1096+
if roleName == "admin" {
1097+
foundAdmin = true
1098+
}
1099+
}
1100+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1101+
},
1102+
},
1103+
{
1104+
config: Config{
1105+
ServerURI: *integrationServerFlag,
1106+
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
1107+
},
1108+
query: "SELECT 1",
1109+
expectError: false,
1110+
},
1111+
{
1112+
name: "Valid special roles via Config",
1113+
config: Config{
1114+
ServerURI: *integrationServerFlag,
1115+
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
1116+
},
1117+
query: "SELECT 1",
1118+
expectError: false,
1119+
},
1120+
{
1121+
name: "Valid hive admin role via DSN, not encoded url",
1122+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1123+
query: "SHOW ROLES FROM hive",
1124+
expectError: false,
1125+
validateRows: func(t *testing.T, rows *sql.Rows) {
1126+
foundAdmin := false
1127+
for rows.Next() {
1128+
var roleName string
1129+
err := rows.Scan(&roleName)
1130+
require.NoError(t, err)
1131+
if roleName == "admin" {
1132+
foundAdmin = true
1133+
}
1134+
}
1135+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1136+
},
1137+
},
1138+
{
1139+
name: "Valid roles via DSN, url encoded",
1140+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1141+
query: "SHOW ROLES FROM hive",
1142+
expectError: false,
1143+
validateRows: func(t *testing.T, rows *sql.Rows) {
1144+
foundAdmin := false
1145+
for rows.Next() {
1146+
var roleName string
1147+
err := rows.Scan(&roleName)
1148+
require.NoError(t, err)
1149+
if roleName == "admin" {
1150+
foundAdmin = true
1151+
}
1152+
}
1153+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1154+
},
1155+
},
1156+
{
1157+
name: "No role - should fail to show roles",
1158+
config: Config{
1159+
ServerURI: *integrationServerFlag,
1160+
},
1161+
query: "SHOW ROLES FROM hive",
1162+
expectError: true,
1163+
errorSubstr: "Access Denied",
1164+
},
1165+
{
1166+
name: "Wrong role - should fail to show roles",
1167+
config: Config{
1168+
ServerURI: *integrationServerFlag,
1169+
Roles: map[string]string{"hive": "ALL"},
1170+
},
1171+
query: "SHOW ROLES FROM hive",
1172+
expectError: true,
1173+
errorSubstr: "Access Denied",
1174+
},
1175+
{
1176+
name: "Non-existent catalog role",
1177+
config: Config{
1178+
ServerURI: *integrationServerFlag,
1179+
Roles: map[string]string{"not-exist-catalog": "role1"},
1180+
},
1181+
query: "SELECT 1",
1182+
expectError: true,
1183+
errorSubstr: "USER_ERROR: Catalog",
1184+
},
1185+
}
1186+
1187+
for _, tt := range tests {
1188+
t.Run(tt.name, func(t *testing.T) {
1189+
var dns string
1190+
var err error
1191+
1192+
if tt.rawDSN != "" {
1193+
dns = tt.rawDSN
1194+
} else {
1195+
dns, err = tt.config.FormatDSN()
1196+
if err != nil {
1197+
t.Fatal(err)
1198+
}
1199+
}
1200+
1201+
db := integrationOpen(t, dns)
1202+
defer db.Close()
1203+
1204+
rows, err := db.Query(tt.query)
1205+
1206+
if tt.expectError {
1207+
require.Error(t, err)
1208+
if tt.errorSubstr != "" {
1209+
require.Contains(t, err.Error(), tt.errorSubstr)
1210+
}
1211+
} else {
1212+
require.NoError(t, err)
1213+
if tt.validateRows != nil && rows != nil {
1214+
defer rows.Close()
1215+
tt.validateRows(t, rows)
1216+
}
1217+
}
1218+
})
1219+
}
1220+
}
10331221

10341222
func TestIntegrationQueryParametersSelect(t *testing.T) {
10351223
scenarios := []struct {

0 commit comments

Comments
 (0)