@@ -183,6 +183,11 @@ func TestMain(m *testing.M) {
183183
184184 waitForContainerHealth (trinoResource .Container .ID , "trino" )
185185
186+ err = grantAdminRoleToTestUser ()
187+ if err != nil {
188+ log .Printf ("Warning: Failed to grant admin role to test user: %s" , err )
189+ }
190+
186191 * integrationServerFlag = "http://test@localhost:" + trinoResource .GetPort ("8080/tcp" )
187192 tlsServer = "https://admin:admin@localhost:" + trinoResource .GetPort ("8443/tcp" )
188193
@@ -218,6 +223,35 @@ func TestMain(m *testing.M) {
218223 os .Exit (code )
219224}
220225
226+ func grantAdminRoleToTestUser () error {
227+ grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
228+
229+ execCmd := []string {
230+ "trino" ,
231+ "--user" , "admin" ,
232+ "--execute" , grantSQL ,
233+ }
234+ exec , err := pool .Client .CreateExec (docker.CreateExecOptions {
235+ Container : trinoResource .Container .ID ,
236+ Cmd : execCmd ,
237+ })
238+ if err != nil {
239+ log .Printf ("Warning: Failed to create exec for GRANT: %s" , err )
240+ } else {
241+ var stdout , stderr bytes.Buffer
242+ err = pool .Client .StartExec (exec .ID , docker.StartExecOptions {
243+ Detach : false ,
244+ OutputStream : & stdout ,
245+ ErrorStream : & stderr ,
246+ })
247+ if err != nil {
248+ log .Printf ("Warning: Failed to execute GRANT: %s" , err )
249+ }
250+ }
251+
252+ return err
253+ }
254+
221255func getOrCreateLocalStack (pool * dt.Pool , networkID string ) * dt.Resource {
222256 resource , ok := pool .ContainerByName (DockerLocalStackName )
223257 if ok {
@@ -1033,18 +1067,41 @@ func TestIntegrationNoResults(t *testing.T) {
10331067}
10341068func TestRoleHeaderSupport (t * testing.T ) {
10351069 tests := []struct {
1036- name string
1037- config Config
1038- rawDSN string
1039- expectError bool
1040- errorSubstr string
1070+ name string
1071+ config Config
1072+ rawDSN string
1073+ query string
1074+ expectError bool
1075+ errorSubstr string
1076+ validateRows func (t * testing.T , rows * sql.Rows )
10411077 }{
10421078 {
1043- name : "Valid roles via Config" ,
1079+ name : "Valid hive admin role via Config" ,
1080+ config : Config {
1081+ ServerURI : * integrationServerFlag ,
1082+ Roles : map [string ]string {"hive" : "admin" },
1083+ },
1084+ query : "SHOW ROLES FROM hive" ,
1085+ expectError : false ,
1086+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1087+ foundAdmin := false
1088+ for rows .Next () {
1089+ var roleName string
1090+ err := rows .Scan (& roleName )
1091+ require .NoError (t , err )
1092+ if roleName == "admin" {
1093+ foundAdmin = true
1094+ }
1095+ }
1096+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1097+ },
1098+ },
1099+ {
10441100 config : Config {
10451101 ServerURI : * integrationServerFlag ,
1046- Roles : map [string ]string {"tpch" : "role1 " , "memory" : "role2 " },
1102+ Roles : map [string ]string {"tpch" : "NONE " , "memory" : "ALL " },
10471103 },
1104+ query : "SELECT 1" ,
10481105 expectError : false ,
10491106 },
10501107 {
@@ -1053,24 +1110,71 @@ func TestRoleHeaderSupport(t *testing.T) {
10531110 ServerURI : * integrationServerFlag ,
10541111 Roles : map [string ]string {"tpch" : "NONE" , "memory" : "ALL" },
10551112 },
1113+ query : "SELECT 1" ,
10561114 expectError : false ,
10571115 },
10581116 {
1059- name : "Valid roles via DSN, not encoded url" ,
1060- rawDSN : * integrationServerFlag + "?roles=tpch:role1;memory:role2" ,
1117+ name : "Valid hive admin role via DSN, not encoded url" ,
1118+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1119+ query : "SHOW ROLES FROM hive" ,
10611120 expectError : false ,
1121+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1122+ foundAdmin := false
1123+ for rows .Next () {
1124+ var roleName string
1125+ err := rows .Scan (& roleName )
1126+ require .NoError (t , err )
1127+ if roleName == "admin" {
1128+ foundAdmin = true
1129+ }
1130+ }
1131+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1132+ },
10621133 },
10631134 {
10641135 name : "Valid roles via DSN, url encoded" ,
1065- rawDSN : * integrationServerFlag + "?roles%3Dtpch%3Arole1%3Bmemory%3Arole2" ,
1136+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1137+ query : "SHOW ROLES FROM hive" ,
10661138 expectError : false ,
1139+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1140+ foundAdmin := false
1141+ for rows .Next () {
1142+ var roleName string
1143+ err := rows .Scan (& roleName )
1144+ require .NoError (t , err )
1145+ if roleName == "admin" {
1146+ foundAdmin = true
1147+ }
1148+ }
1149+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1150+ },
1151+ },
1152+ {
1153+ name : "No role - should fail to show roles" ,
1154+ config : Config {
1155+ ServerURI : * integrationServerFlag ,
1156+ },
1157+ query : "SHOW ROLES FROM hive" ,
1158+ expectError : true ,
1159+ errorSubstr : "Access Denied" ,
1160+ },
1161+ {
1162+ name : "Wrong role - should fail to show roles" ,
1163+ config : Config {
1164+ ServerURI : * integrationServerFlag ,
1165+ Roles : map [string ]string {"hive" : "ALL" },
1166+ },
1167+ query : "SHOW ROLES FROM hive" ,
1168+ expectError : true ,
1169+ errorSubstr : "Access Denied" ,
10671170 },
10681171 {
10691172 name : "Non-existent catalog role" ,
10701173 config : Config {
10711174 ServerURI : * integrationServerFlag ,
10721175 Roles : map [string ]string {"not-exist-catalog" : "role1" },
10731176 },
1177+ query : "SELECT 1" ,
10741178 expectError : true ,
10751179 errorSubstr : "USER_ERROR: Catalog" ,
10761180 },
@@ -1091,7 +1195,9 @@ func TestRoleHeaderSupport(t *testing.T) {
10911195 }
10921196
10931197 db := integrationOpen (t , dns )
1094- _ , err = db .Query ("SELECT 1" )
1198+ defer db .Close ()
1199+
1200+ rows , err := db .Query (tt .query )
10951201
10961202 if tt .expectError {
10971203 require .Error (t , err )
@@ -1100,6 +1206,10 @@ func TestRoleHeaderSupport(t *testing.T) {
11001206 }
11011207 } else {
11021208 require .NoError (t , err )
1209+ if tt .validateRows != nil && rows != nil {
1210+ defer rows .Close ()
1211+ tt .validateRows (t , rows )
1212+ }
11031213 }
11041214 })
11051215 }
0 commit comments