Skip to content

Commit fe4ab99

Browse files
committed
Add hive catalog, Add proper integration tests for roles
1 parent 5eef69f commit fe4ab99

File tree

5 files changed

+175
-50
lines changed

5 files changed

+175
-50
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,13 @@ dsn, err := c.FormatDSN()
314314

315315
**Query parameter example (overrides DSN roles):**
316316
```go
317-
rows, err := db.Query(query, sql.Named("X-Trino-Role", "catalog1:role1;catalog2:role2;catalog3:ALL"))
317+
rows, err := db.Query(
318+
query,
319+
sql.Named("X-Trino-Role", map[string]string{
320+
"catalog1": "role1",
321+
"catalog2": "role2",
322+
}),
323+
)
318324
```
319325

320326
#### Examples

trino/etc/catalog/hive.properties

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
connector.name=hive
2+
hive.metastore=file
3+
hive.metastore.catalog.dir=/tmp/metastore
4+
hive.security=sql-standard
5+
fs.hadoop.enabled=true

trino/integration_test.go

Lines changed: 121 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ func TestMain(m *testing.M) {
183183

184184
waitForContainerHealth(trinoResource.Container.ID, "trino")
185185

186+
err = grantAdminRoleToTestUser()
187+
if err != nil {
188+
log.Printf("Warning: Failed to grant admin role to test user: %s", err)
189+
}
190+
186191
*integrationServerFlag = "http://test@localhost:" + trinoResource.GetPort("8080/tcp")
187192
tlsServer = "https://admin:admin@localhost:" + trinoResource.GetPort("8443/tcp")
188193

@@ -218,6 +223,35 @@ func TestMain(m *testing.M) {
218223
os.Exit(code)
219224
}
220225

226+
func grantAdminRoleToTestUser() error {
227+
grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
228+
229+
execCmd := []string{
230+
"trino",
231+
"--user", "admin",
232+
"--execute", grantSQL,
233+
}
234+
exec, err := pool.Client.CreateExec(docker.CreateExecOptions{
235+
Container: trinoResource.Container.ID,
236+
Cmd: execCmd,
237+
})
238+
if err != nil {
239+
log.Printf("Warning: Failed to create exec for GRANT: %s", err)
240+
} else {
241+
var stdout, stderr bytes.Buffer
242+
err = pool.Client.StartExec(exec.ID, docker.StartExecOptions{
243+
Detach: false,
244+
OutputStream: &stdout,
245+
ErrorStream: &stderr,
246+
})
247+
if err != nil {
248+
log.Printf("Warning: Failed to execute GRANT: %s", err)
249+
}
250+
}
251+
252+
return err
253+
}
254+
221255
func getOrCreateLocalStack(pool *dt.Pool, networkID string) *dt.Resource {
222256
resource, ok := pool.ContainerByName(DockerLocalStackName)
223257
if ok {
@@ -1033,18 +1067,41 @@ func TestIntegrationNoResults(t *testing.T) {
10331067
}
10341068
func TestRoleHeaderSupport(t *testing.T) {
10351069
tests := []struct {
1036-
name string
1037-
config Config
1038-
rawDSN string
1039-
expectError bool
1040-
errorSubstr string
1070+
name string
1071+
config Config
1072+
rawDSN string
1073+
query string
1074+
expectError bool
1075+
errorSubstr string
1076+
validateRows func(t *testing.T, rows *sql.Rows)
10411077
}{
10421078
{
1043-
name: "Valid roles via Config",
1079+
name: "Valid hive admin role via Config",
1080+
config: Config{
1081+
ServerURI: *integrationServerFlag,
1082+
Roles: map[string]string{"hive": "admin"},
1083+
},
1084+
query: "SHOW ROLES FROM hive",
1085+
expectError: false,
1086+
validateRows: func(t *testing.T, rows *sql.Rows) {
1087+
foundAdmin := false
1088+
for rows.Next() {
1089+
var roleName string
1090+
err := rows.Scan(&roleName)
1091+
require.NoError(t, err)
1092+
if roleName == "admin" {
1093+
foundAdmin = true
1094+
}
1095+
}
1096+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1097+
},
1098+
},
1099+
{
10441100
config: Config{
10451101
ServerURI: *integrationServerFlag,
1046-
Roles: map[string]string{"tpch": "role1", "memory": "role2"},
1102+
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
10471103
},
1104+
query: "SELECT 1",
10481105
expectError: false,
10491106
},
10501107
{
@@ -1053,24 +1110,71 @@ func TestRoleHeaderSupport(t *testing.T) {
10531110
ServerURI: *integrationServerFlag,
10541111
Roles: map[string]string{"tpch": "NONE", "memory": "ALL"},
10551112
},
1113+
query: "SELECT 1",
10561114
expectError: false,
10571115
},
10581116
{
1059-
name: "Valid roles via DSN, not encoded url",
1060-
rawDSN: *integrationServerFlag + "?roles=tpch:role1;memory:role2",
1117+
name: "Valid hive admin role via DSN, not encoded url",
1118+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1119+
query: "SHOW ROLES FROM hive",
10611120
expectError: false,
1121+
validateRows: func(t *testing.T, rows *sql.Rows) {
1122+
foundAdmin := false
1123+
for rows.Next() {
1124+
var roleName string
1125+
err := rows.Scan(&roleName)
1126+
require.NoError(t, err)
1127+
if roleName == "admin" {
1128+
foundAdmin = true
1129+
}
1130+
}
1131+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1132+
},
10621133
},
10631134
{
10641135
name: "Valid roles via DSN, url encoded",
1065-
rawDSN: *integrationServerFlag + "?roles%3Dtpch%3Arole1%3Bmemory%3Arole2",
1136+
rawDSN: *integrationServerFlag + "?roles=hive:admin",
1137+
query: "SHOW ROLES FROM hive",
10661138
expectError: false,
1139+
validateRows: func(t *testing.T, rows *sql.Rows) {
1140+
foundAdmin := false
1141+
for rows.Next() {
1142+
var roleName string
1143+
err := rows.Scan(&roleName)
1144+
require.NoError(t, err)
1145+
if roleName == "admin" {
1146+
foundAdmin = true
1147+
}
1148+
}
1149+
require.True(t, foundAdmin, "Expected to find 'admin' role in SHOW ROLES output")
1150+
},
1151+
},
1152+
{
1153+
name: "No role - should fail to show roles",
1154+
config: Config{
1155+
ServerURI: *integrationServerFlag,
1156+
},
1157+
query: "SHOW ROLES FROM hive",
1158+
expectError: true,
1159+
errorSubstr: "Access Denied",
1160+
},
1161+
{
1162+
name: "Wrong role - should fail to show roles",
1163+
config: Config{
1164+
ServerURI: *integrationServerFlag,
1165+
Roles: map[string]string{"hive": "ALL"},
1166+
},
1167+
query: "SHOW ROLES FROM hive",
1168+
expectError: true,
1169+
errorSubstr: "Access Denied",
10671170
},
10681171
{
10691172
name: "Non-existent catalog role",
10701173
config: Config{
10711174
ServerURI: *integrationServerFlag,
10721175
Roles: map[string]string{"not-exist-catalog": "role1"},
10731176
},
1177+
query: "SELECT 1",
10741178
expectError: true,
10751179
errorSubstr: "USER_ERROR: Catalog",
10761180
},
@@ -1091,7 +1195,9 @@ func TestRoleHeaderSupport(t *testing.T) {
10911195
}
10921196

10931197
db := integrationOpen(t, dns)
1094-
_, err = db.Query("SELECT 1")
1198+
defer db.Close()
1199+
1200+
rows, err := db.Query(tt.query)
10951201

10961202
if tt.expectError {
10971203
require.Error(t, err)
@@ -1100,6 +1206,10 @@ func TestRoleHeaderSupport(t *testing.T) {
11001206
}
11011207
} else {
11021208
require.NoError(t, err)
1209+
if tt.validateRows != nil && rows != nil {
1210+
defer rows.Close()
1211+
tt.validateRows(t, rows)
1212+
}
11031213
}
11041214
})
11051215
}

trino/trino.go

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -487,27 +487,29 @@ func formatRolesFromMap(rolesMap map[string]string) string {
487487
return strings.Join(formattedRoles, commaSeparator)
488488
}
489489

490-
// parseAndFormatRoles parses roles from DSN format (catalog1:roleA;catalog2:roleB)
491-
// and formats them into the Trino header format
492-
func parseAndFormatRoles(rolesString string) (string, error) {
493-
var formattedRoles []string
494-
for _, entry := range strings.Split(rolesString, mapEntrySeparator) {
495-
parts := strings.SplitN(entry, mapKeySeparator, 2)
496-
if len(parts) != 2 {
497-
return "", fmt.Errorf("invalid role entry: %q", entry)
498-
}
499-
formattedRoles = append(formattedRoles, formatRoleEntry(parts[0], parts[1]))
500-
}
501-
sort.Strings(formattedRoles)
502-
return strings.Join(formattedRoles, commaSeparator), nil
503-
}
504-
505490
// formatRoleEntry formats a single catalog role entry into Trino header format
506491
func formatRoleEntry(catalog, role string) string {
507492
if role == "ALL" || role == "NONE" {
508493
return fmt.Sprintf("%s=%s", catalog, role)
509494
}
510-
return fmt.Sprintf("%s=ROLE{%q}", catalog, role)
495+
return fmt.Sprintf("%s=ROLE{%s}", catalog, role)
496+
}
497+
498+
// formatHeaderValue converts a named argument value to a string suitable for HTTP headers.
499+
func formatHeaderValue(headerName string, value interface{}) (string, error) {
500+
if headerName == trinoRoleHeader {
501+
rolesMap, ok := value.(map[string]string)
502+
if !ok {
503+
return "", fmt.Errorf("%s must be a map[string]string, got %T", trinoRoleHeader, value)
504+
}
505+
return formatRolesFromMap(rolesMap), nil
506+
}
507+
508+
headerValue, ok := value.(string)
509+
if !ok {
510+
return "", fmt.Errorf("%s must be a string, got %T", headerName, value)
511+
}
512+
return headerValue, nil
511513
}
512514

513515
func newConn(dsn string) (*Conn, error) {
@@ -1041,6 +1043,10 @@ func (st *driverStmt) CheckNamedValue(arg *driver.NamedValue) error {
10411043
return nil
10421044
}
10431045

1046+
if arg.Name == trinoRoleHeader {
1047+
return nil
1048+
}
1049+
10441050
if arg.Name == trinoProgressCallbackParam {
10451051
return nil
10461052
}
@@ -1235,29 +1241,27 @@ func (st *driverStmt) exec(ctx context.Context, args []driver.NamedValue) (*stmt
12351241
continue
12361242
}
12371243

1238-
s, err := Serial(arg.Value)
1239-
if err != nil {
1240-
return nil, err
1241-
}
1242-
12431244
if strings.HasPrefix(arg.Name, trinoHeaderPrefix) {
1244-
headerValue := arg.Value.(string)
1245+
headerValue, err := formatHeaderValue(arg.Name, arg.Value)
1246+
if err != nil {
1247+
return nil, err
1248+
}
12451249

12461250
if arg.Name == trinoUserHeader {
12471251
st.user = headerValue
12481252
}
12491253

12501254
if arg.Name == trinoRoleHeader {
1251-
formattedRoles, err := parseAndFormatRoles(headerValue)
1252-
if err != nil {
1253-
return nil, err
1254-
}
1255-
st.conn.httpHeaders.Set(trinoRoleHeader, formattedRoles)
1256-
headerValue = formattedRoles
1255+
st.conn.httpHeaders.Set(trinoRoleHeader, headerValue)
12571256
}
12581257

12591258
hs.Add(arg.Name, headerValue)
12601259
} else {
1260+
s, err := Serial(arg.Value)
1261+
if err != nil {
1262+
return nil, err
1263+
}
1264+
12611265
if st.conn.useExplicitPrepare && hs.Get(preparedStatementHeader) == "" {
12621266
for _, v := range st.conn.httpHeaders.Values(preparedStatementHeader) {
12631267
hs.Add(preparedStatementHeader, v)

trino/trino_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,24 +1347,24 @@ func TestQueryCancellation(t *testing.T) {
13471347
assert.EqualError(t, err, ErrQueryCancelled.Error(), "unexpected error")
13481348
}
13491349

1350-
func TestTrinoRoleHeader(t *testing.T) {
1350+
func TestRoleHeader(t *testing.T) {
13511351
tests := []struct {
13521352
name string
13531353
roles map[string]string
1354-
namedArg string
1354+
namedArgRoles map[string]string
13551355
expectedHeader string
13561356
}{
13571357
{
13581358
name: "Roles from config",
13591359
roles: map[string]string{"catalog1": "role1", "catalog2": "role2"},
1360-
namedArg: "",
1361-
expectedHeader: `catalog1=ROLE{"role1"},catalog2=ROLE{"role2"}`,
1360+
namedArgRoles: nil,
1361+
expectedHeader: `catalog1=ROLE{role1},catalog2=ROLE{role2}`,
13621362
},
13631363
{
1364-
name: "Override roles with named argument",
1364+
name: "Override dsn roles with named argument",
13651365
roles: map[string]string{"catalog1": "role1"},
1366-
namedArg: `catalog3:role3;catalog4:role4;catalog5:ALL`,
1367-
expectedHeader: `catalog3=ROLE{"role3"},catalog4=ROLE{"role4"},catalog5=ALL`,
1366+
namedArgRoles: map[string]string{"catalog3": "role3", "catalog4": "role4", "catalog5": "ALL"},
1367+
expectedHeader: `catalog3=ROLE{role3},catalog4=ROLE{role4},catalog5=ALL`,
13681368
},
13691369
}
13701370

@@ -1391,8 +1391,8 @@ func TestTrinoRoleHeader(t *testing.T) {
13911391
db, err := sql.Open("trino", dsn)
13921392
require.NoError(t, err)
13931393

1394-
if tt.namedArg != "" {
1395-
_, _ = db.Query("SELECT 1", sql.Named("X-Trino-Role", tt.namedArg))
1394+
if tt.namedArgRoles != nil {
1395+
_, _ = db.Query("SELECT 1", sql.Named("X-Trino-Role", tt.namedArgRoles))
13961396
} else {
13971397
_, _ = db.Query("SELECT 1")
13981398
}
@@ -2801,7 +2801,7 @@ func TestSetRoleHeader(t *testing.T) {
28012801
require.NoError(t, err)
28022802
require.NoError(t, rows.Close())
28032803

2804-
assert.Equal(t, `catalog=ROLE{"user"}`, firstRoleHeader, "initial role from DSN should be sent in first request")
2804+
assert.Equal(t, `catalog=ROLE{user}`, firstRoleHeader, "initial role from DSN should be sent in first request")
28052805
assert.Equal(t, "ROLE%7Badmin%7D", secondRoleHeader, "server-set role should be sent in subsequent requests")
28062806
assert.NotEqual(t, firstRoleHeader, secondRoleHeader, "role should have changed from DSN value to server-set value")
28072807
}

0 commit comments

Comments
 (0)