Skip to content

Commit b7bde9b

Browse files
committed
Add support for catalog roles
1 parent 31e5182 commit b7bde9b

File tree

5 files changed

+493
-10
lines changed

5 files changed

+493
-10
lines changed

README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,40 @@ dsn, err := config.FormatDSN()
289289
```go
290290
rows, err := db.Query(query, sql.Named("X-Trino-Client-Tags", "tag1,tag2,tag3"))
291291
```
292+
=======
293+
294+
#### `roles`
295+
296+
```
297+
Type: string
298+
Format: roles=catalog1:role1;catalog2=role2
299+
Valid values: A semicolon-separated list of catalog-to-role assignments, where each assignment maps a catalog to a role.
300+
Default: empty
301+
```
302+
The roles parameter defines authorization roles to assume for one or more catalogs during the Trino session.
303+
304+
##### Example
305+
``` go
306+
c := &Config{
307+
ServerURI: "https://foobar@localhost:8090",
308+
SessionProperties: map[string]string{"query_priority": "1"},
309+
Roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
310+
}
311+
312+
dsn, err := c.FormatDSN()
313+
```
314+
315+
**Query parameter example (overrides DSN roles):**
316+
```go
317+
rows, err := db.Query(
318+
query,
319+
sql.Named("X-Trino-Role", map[string]string{
320+
"catalog1": "role1",
321+
"catalog2": "role2",
322+
}),
323+
)
324+
```
325+
292326
#### Examples
293327

294328
```
@@ -299,6 +333,11 @@ http://user@localhost:8080?source=hello&catalog=default&schema=foobar
299333
https://user@localhost:8443?session_properties=query_max_run_time=10m,query_priority=2
300334
```
301335

336+
337+
```
338+
http://user@localhost:8080?source=hello&catalog=default&schema=foobar&roles=catalog1:role1;catalog2:role2
339+
```
340+
302341
## Data types
303342

304343
### Query arguments

trino/etc/catalog/hive.properties

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
connector.name=hive
2+
hive.metastore=file
3+
hive.metastore.catalog.dir=/tmp/metastore
4+
hive.security=sql-standard
5+
fs.hadoop.enabled=true

trino/integration_test.go

Lines changed: 194 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ import (
5050
"github.com/golang-jwt/jwt/v5"
5151
dt "github.com/ory/dockertest/v3"
5252
docker "github.com/ory/dockertest/v3/docker"
53+
"github.com/stretchr/testify/require"
5354
)
5455

5556
const (
@@ -136,11 +137,16 @@ func TestMain(m *testing.M) {
136137
}
137138

138139
mounts := []string{
139-
wd + "/etc/catalog:/etc/trino/catalog",
140140
wd + "/etc/secrets:/etc/trino/secrets",
141141
wd + "/etc/jvm.config:/etc/trino/jvm.config",
142142
wd + "/etc/node.properties:/etc/trino/node.properties",
143143
wd + "/etc/password-authenticator.properties:/etc/trino/password-authenticator.properties",
144+
wd + "/etc/catalog/memory.properties:/etc/trino/catalog/memory.properties",
145+
wd + "/etc/catalog/tpch.properties:/etc/trino/catalog/tpch.properties",
146+
}
147+
version, err := strconv.Atoi(*trinoImageTagFlag)
148+
if (err != nil && *trinoImageTagFlag == "latest") || (err == nil && version >= 458) {
149+
mounts = append(mounts, wd+"/etc/catalog/hive.properties:/etc/trino/catalog/hive.properties")
144150
}
145151

146152
if spoolingProtocolSupported {
@@ -182,6 +188,11 @@ func TestMain(m *testing.M) {
182188

183189
waitForContainerHealth(trinoResource.Container.ID, "trino")
184190

191+
err = grantAdminRoleToTestUser()
192+
if err != nil {
193+
log.Fatalf("Warning: Failed to grant admin role to test user: %s", err)
194+
}
195+
185196
*integrationServerFlag = "http://test@localhost:" + trinoResource.GetPort("8080/tcp")
186197
tlsServer = "https://admin:admin@localhost:" + trinoResource.GetPort("8443/tcp")
187198

@@ -217,6 +228,35 @@ func TestMain(m *testing.M) {
217228
os.Exit(code)
218229
}
219230

231+
func grantAdminRoleToTestUser() error {
232+
grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
233+
234+
execCmd := []string{
235+
"trino",
236+
"--user", "admin",
237+
"--execute", grantSQL,
238+
}
239+
exec, err := pool.Client.CreateExec(docker.CreateExecOptions{
240+
Container: trinoResource.Container.ID,
241+
Cmd: execCmd,
242+
})
243+
if err != nil {
244+
log.Printf("Warning: Failed to create exec for GRANT: %s", err)
245+
} else {
246+
var stdout, stderr bytes.Buffer
247+
err = pool.Client.StartExec(exec.ID, docker.StartExecOptions{
248+
Detach: false,
249+
OutputStream: &stdout,
250+
ErrorStream: &stderr,
251+
})
252+
if err != nil {
253+
log.Printf("Warning: Failed to execute GRANT: %s", err)
254+
}
255+
}
256+
257+
return err
258+
}
259+
220260
func getOrCreateLocalStack(pool *dt.Pool, networkID string) *dt.Resource {
221261
resource, ok := pool.ContainerByName(DockerLocalStackName)
222262
if ok {
@@ -1030,6 +1070,159 @@ func TestIntegrationNoResults(t *testing.T) {
10301070
t.Fatal(err)
10311071
}
10321072
}
1073+
func TestRoleHeaderSupport(t *testing.T) {
1074+
version, err := strconv.Atoi(*trinoImageTagFlag)
1075+
if (err != nil && *trinoImageTagFlag != "latest") || (err == nil && version < 458) {
1076+
t.Skip("Skipping test when not using Trino 458 or later.")
1077+
}
1078+
tests := []struct {
1079+
name string
1080+
config Config
1081+
rawDSN string
1082+
query string
1083+
expectError bool
1084+
errorSubstr string
1085+
validateRows func(t *testing.T, rows *sql.Rows)
1086+
}{
1087+
{
1088+
name: "Valid hive admin role via Config",
1089+
config: Config{
1090+
ServerURI: *integrationServerFlag,
1091+
Roles: map[string]string{"hive": "admin"},
1092+
},
1093+
query: "SHOW ROLES FROM hive",
1094+
expectError: false,
1095+
validateRows: func(t *testing.T, rows *sql.Rows) {
1096+
foundAdmin := false
1097+
for rows.Next() {
1098+
var roleName string
1099+
err := rows.Scan(&roleName)
1100+
require.NoError(t, err)
1101+
if roleName == "admin" {
1102+
foundAdmin = true
1103+
}
1104+
}
1105+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1106+
},
1107+
},
1108+
{
1109+
config: Config{
1110+
ServerURI: *integrationServerFlag,
1111+
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
1112+
},
1113+
query: "SELECT 1",
1114+
expectError: false,
1115+
},
1116+
{
1117+
name: "Valid special roles via Config",
1118+
config: Config{
1119+
ServerURI: *integrationServerFlag,
1120+
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
1121+
},
1122+
query: "SELECT 1",
1123+
expectError: false,
1124+
},
1125+
{
1126+
name: "Valid hive admin role via DSN, not encoded url",
1127+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1128+
query: "SHOW ROLES FROM hive",
1129+
expectError: false,
1130+
validateRows: func(t *testing.T, rows *sql.Rows) {
1131+
foundAdmin := false
1132+
for rows.Next() {
1133+
var roleName string
1134+
err := rows.Scan(&roleName)
1135+
require.NoError(t, err)
1136+
if roleName == "admin" {
1137+
foundAdmin = true
1138+
}
1139+
}
1140+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1141+
},
1142+
},
1143+
{
1144+
name: "Valid roles via DSN, url encoded",
1145+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1146+
query: "SHOW ROLES FROM hive",
1147+
expectError: false,
1148+
validateRows: func(t *testing.T, rows *sql.Rows) {
1149+
foundAdmin := false
1150+
for rows.Next() {
1151+
var roleName string
1152+
err := rows.Scan(&roleName)
1153+
require.NoError(t, err)
1154+
if roleName == "admin" {
1155+
foundAdmin = true
1156+
}
1157+
}
1158+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1159+
},
1160+
},
1161+
{
1162+
name: "No role - should fail to show roles",
1163+
config: Config{
1164+
ServerURI: *integrationServerFlag,
1165+
},
1166+
query: "SHOW ROLES FROM hive",
1167+
expectError: true,
1168+
errorSubstr: "Access Denied",
1169+
},
1170+
{
1171+
name: "Wrong role - should fail to show roles",
1172+
config: Config{
1173+
ServerURI: *integrationServerFlag,
1174+
Roles: map[string]string{"hive": "ALL"},
1175+
},
1176+
query: "SHOW ROLES FROM hive",
1177+
expectError: true,
1178+
errorSubstr: "Access Denied",
1179+
},
1180+
{
1181+
name: "Non-existent catalog role",
1182+
config: Config{
1183+
ServerURI: *integrationServerFlag,
1184+
Roles: map[string]string{"not-exist-catalog": "role1"},
1185+
},
1186+
query: "SELECT 1",
1187+
expectError: true,
1188+
errorSubstr: "USER_ERROR: Catalog",
1189+
},
1190+
}
1191+
1192+
for _, tt := range tests {
1193+
t.Run(tt.name, func(t *testing.T) {
1194+
var dns string
1195+
var err error
1196+
1197+
if tt.rawDSN != "" {
1198+
dns = tt.rawDSN
1199+
} else {
1200+
dns, err = tt.config.FormatDSN()
1201+
if err != nil {
1202+
t.Fatal(err)
1203+
}
1204+
}
1205+
1206+
db := integrationOpen(t, dns)
1207+
defer db.Close()
1208+
1209+
rows, err := db.Query(tt.query)
1210+
1211+
if tt.expectError {
1212+
require.Error(t, err)
1213+
if tt.errorSubstr != "" {
1214+
require.Contains(t, err.Error(), tt.errorSubstr)
1215+
}
1216+
} else {
1217+
require.NoError(t, err)
1218+
if tt.validateRows != nil && rows != nil {
1219+
defer rows.Close()
1220+
tt.validateRows(t, rows)
1221+
}
1222+
}
1223+
})
1224+
}
1225+
}
10331226

10341227
func TestIntegrationQueryParametersSelect(t *testing.T) {
10351228
scenarios := []struct {

0 commit comments

Comments
 (0)