@@ -50,6 +50,7 @@ import (
5050 "github.com/golang-jwt/jwt/v5"
5151 dt "github.com/ory/dockertest/v3"
5252 docker "github.com/ory/dockertest/v3/docker"
53+ "github.com/stretchr/testify/require"
5354)
5455
5556const (
@@ -136,11 +137,16 @@ func TestMain(m *testing.M) {
136137 }
137138
138139 mounts := []string {
139- wd + "/etc/catalog:/etc/trino/catalog" ,
140140 wd + "/etc/secrets:/etc/trino/secrets" ,
141141 wd + "/etc/jvm.config:/etc/trino/jvm.config" ,
142142 wd + "/etc/node.properties:/etc/trino/node.properties" ,
143143 wd + "/etc/password-authenticator.properties:/etc/trino/password-authenticator.properties" ,
144+ wd + "/etc/catalog/memory.properties:/etc/trino/catalog/memory.properties" ,
145+ wd + "/etc/catalog/tpch.properties:/etc/trino/catalog/tpch.properties" ,
146+ }
147+ version , err := strconv .Atoi (* trinoImageTagFlag )
148+ if (err != nil && * trinoImageTagFlag == "latest" ) || (err == nil && version >= 458 ) {
149+ mounts = append (mounts , wd + "/etc/catalog/hive.properties:/etc/trino/catalog/hive.properties" )
144150 }
145151
146152 if spoolingProtocolSupported {
@@ -182,6 +188,11 @@ func TestMain(m *testing.M) {
182188
183189 waitForContainerHealth (trinoResource .Container .ID , "trino" )
184190
191+ err = grantAdminRoleToTestUser ()
192+ if err != nil {
193+ log .Fatalf ("Warning: Failed to grant admin role to test user: %s" , err )
194+ }
195+
185196 * integrationServerFlag = "http://test@localhost:" + trinoResource .GetPort ("8080/tcp" )
186197 tlsServer = "https://admin:admin@localhost:" + trinoResource .GetPort ("8443/tcp" )
187198
@@ -217,6 +228,35 @@ func TestMain(m *testing.M) {
217228 os .Exit (code )
218229}
219230
231+ func grantAdminRoleToTestUser () error {
232+ grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
233+
234+ execCmd := []string {
235+ "trino" ,
236+ "--user" , "admin" ,
237+ "--execute" , grantSQL ,
238+ }
239+ exec , err := pool .Client .CreateExec (docker.CreateExecOptions {
240+ Container : trinoResource .Container .ID ,
241+ Cmd : execCmd ,
242+ })
243+ if err != nil {
244+ log .Printf ("Warning: Failed to create exec for GRANT: %s" , err )
245+ } else {
246+ var stdout , stderr bytes.Buffer
247+ err = pool .Client .StartExec (exec .ID , docker.StartExecOptions {
248+ Detach : false ,
249+ OutputStream : & stdout ,
250+ ErrorStream : & stderr ,
251+ })
252+ if err != nil {
253+ log .Printf ("Warning: Failed to execute GRANT: %s" , err )
254+ }
255+ }
256+
257+ return err
258+ }
259+
220260func getOrCreateLocalStack (pool * dt.Pool , networkID string ) * dt.Resource {
221261 resource , ok := pool .ContainerByName (DockerLocalStackName )
222262 if ok {
@@ -1030,6 +1070,159 @@ func TestIntegrationNoResults(t *testing.T) {
10301070 t .Fatal (err )
10311071 }
10321072}
1073+ func TestRoleHeaderSupport (t * testing.T ) {
1074+ version , err := strconv .Atoi (* trinoImageTagFlag )
1075+ if (err != nil && * trinoImageTagFlag != "latest" ) || (err == nil && version < 458 ) {
1076+ t .Skip ("Skipping test when not using Trino 458 or later." )
1077+ }
1078+ tests := []struct {
1079+ name string
1080+ config Config
1081+ rawDSN string
1082+ query string
1083+ expectError bool
1084+ errorSubstr string
1085+ validateRows func (t * testing.T , rows * sql.Rows )
1086+ }{
1087+ {
1088+ name : "Valid hive admin role via Config" ,
1089+ config : Config {
1090+ ServerURI : * integrationServerFlag ,
1091+ Roles : map [string ]string {"hive" : "admin" },
1092+ },
1093+ query : "SHOW ROLES FROM hive" ,
1094+ expectError : false ,
1095+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1096+ foundAdmin := false
1097+ for rows .Next () {
1098+ var roleName string
1099+ err := rows .Scan (& roleName )
1100+ require .NoError (t , err )
1101+ if roleName == "admin" {
1102+ foundAdmin = true
1103+ }
1104+ }
1105+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1106+ },
1107+ },
1108+ {
1109+ config : Config {
1110+ ServerURI : * integrationServerFlag ,
1111+ Roles : map [string ]string {"tpch" : "NONE" , "memory" : "ALL" },
1112+ },
1113+ query : "SELECT 1" ,
1114+ expectError : false ,
1115+ },
1116+ {
1117+ name : "Valid special roles via Config" ,
1118+ config : Config {
1119+ ServerURI : * integrationServerFlag ,
1120+ Roles : map [string ]string {"tpch" : "NONE" , "memory" : "ALL" },
1121+ },
1122+ query : "SELECT 1" ,
1123+ expectError : false ,
1124+ },
1125+ {
1126+ name : "Valid hive admin role via DSN, not encoded url" ,
1127+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1128+ query : "SHOW ROLES FROM hive" ,
1129+ expectError : false ,
1130+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1131+ foundAdmin := false
1132+ for rows .Next () {
1133+ var roleName string
1134+ err := rows .Scan (& roleName )
1135+ require .NoError (t , err )
1136+ if roleName == "admin" {
1137+ foundAdmin = true
1138+ }
1139+ }
1140+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1141+ },
1142+ },
1143+ {
1144+ name : "Valid roles via DSN, url encoded" ,
1145+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1146+ query : "SHOW ROLES FROM hive" ,
1147+ expectError : false ,
1148+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1149+ foundAdmin := false
1150+ for rows .Next () {
1151+ var roleName string
1152+ err := rows .Scan (& roleName )
1153+ require .NoError (t , err )
1154+ if roleName == "admin" {
1155+ foundAdmin = true
1156+ }
1157+ }
1158+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1159+ },
1160+ },
1161+ {
1162+ name : "No role - should fail to show roles" ,
1163+ config : Config {
1164+ ServerURI : * integrationServerFlag ,
1165+ },
1166+ query : "SHOW ROLES FROM hive" ,
1167+ expectError : true ,
1168+ errorSubstr : "Access Denied" ,
1169+ },
1170+ {
1171+ name : "Wrong role - should fail to show roles" ,
1172+ config : Config {
1173+ ServerURI : * integrationServerFlag ,
1174+ Roles : map [string ]string {"hive" : "ALL" },
1175+ },
1176+ query : "SHOW ROLES FROM hive" ,
1177+ expectError : true ,
1178+ errorSubstr : "Access Denied" ,
1179+ },
1180+ {
1181+ name : "Non-existent catalog role" ,
1182+ config : Config {
1183+ ServerURI : * integrationServerFlag ,
1184+ Roles : map [string ]string {"not-exist-catalog" : "role1" },
1185+ },
1186+ query : "SELECT 1" ,
1187+ expectError : true ,
1188+ errorSubstr : "USER_ERROR: Catalog" ,
1189+ },
1190+ }
1191+
1192+ for _ , tt := range tests {
1193+ t .Run (tt .name , func (t * testing.T ) {
1194+ var dns string
1195+ var err error
1196+
1197+ if tt .rawDSN != "" {
1198+ dns = tt .rawDSN
1199+ } else {
1200+ dns , err = tt .config .FormatDSN ()
1201+ if err != nil {
1202+ t .Fatal (err )
1203+ }
1204+ }
1205+
1206+ db := integrationOpen (t , dns )
1207+ defer db .Close ()
1208+
1209+ rows , err := db .Query (tt .query )
1210+
1211+ if tt .expectError {
1212+ require .Error (t , err )
1213+ if tt .errorSubstr != "" {
1214+ require .Contains (t , err .Error (), tt .errorSubstr )
1215+ }
1216+ } else {
1217+ require .NoError (t , err )
1218+ if tt .validateRows != nil && rows != nil {
1219+ defer rows .Close ()
1220+ tt .validateRows (t , rows )
1221+ }
1222+ }
1223+ })
1224+ }
1225+ }
10331226
10341227func TestIntegrationQueryParametersSelect (t * testing.T ) {
10351228 scenarios := []struct {
0 commit comments