@@ -50,6 +50,7 @@ import (
5050 "github.com/golang-jwt/jwt/v5"
5151 dt "github.com/ory/dockertest/v3"
5252 docker "github.com/ory/dockertest/v3/docker"
53+ "github.com/stretchr/testify/require"
5354)
5455
5556const (
@@ -143,8 +144,12 @@ func TestMain(m *testing.M) {
143144 wd + "/etc/password-authenticator.properties:/etc/trino/password-authenticator.properties" ,
144145 }
145146
147+ version , err := strconv .Atoi (* trinoImageTagFlag )
148+ if (err != nil && * trinoImageTagFlag != "latest" ) || (err == nil && version < 458 ) {
149+ mounts = append (mounts , wd + "/etc/catalog/hive-pre-458version.properties:/etc/trino/catalog/hive.properties" )
150+ }
151+
146152 if spoolingProtocolSupported {
147- version , err := strconv .Atoi (* trinoImageTagFlag )
148153 if (err != nil && * trinoImageTagFlag != "latest" ) || (err == nil && version < 477 ) {
149154 mounts = append (mounts , wd + "/etc/config-pre-477version.properties:/etc/trino/config.properties" )
150155 } else {
@@ -182,6 +187,11 @@ func TestMain(m *testing.M) {
182187
183188 waitForContainerHealth (trinoResource .Container .ID , "trino" )
184189
190+ err = grantAdminRoleToTestUser ()
191+ if err != nil {
192+ log .Fatalf ("Warning: Failed to grant admin role to test user: %s" , err )
193+ }
194+
185195 * integrationServerFlag = "http://test@localhost:" + trinoResource .GetPort ("8080/tcp" )
186196 tlsServer = "https://admin:admin@localhost:" + trinoResource .GetPort ("8443/tcp" )
187197
@@ -217,6 +227,35 @@ func TestMain(m *testing.M) {
217227 os .Exit (code )
218228}
219229
230+ func grantAdminRoleToTestUser () error {
231+ grantSQL := "SET ROLE admin IN hive; GRANT admin TO USER test IN hive;"
232+
233+ execCmd := []string {
234+ "trino" ,
235+ "--user" , "admin" ,
236+ "--execute" , grantSQL ,
237+ }
238+ exec , err := pool .Client .CreateExec (docker.CreateExecOptions {
239+ Container : trinoResource .Container .ID ,
240+ Cmd : execCmd ,
241+ })
242+ if err != nil {
243+ log .Printf ("Warning: Failed to create exec for GRANT: %s" , err )
244+ } else {
245+ var stdout , stderr bytes.Buffer
246+ err = pool .Client .StartExec (exec .ID , docker.StartExecOptions {
247+ Detach : false ,
248+ OutputStream : & stdout ,
249+ ErrorStream : & stderr ,
250+ })
251+ if err != nil {
252+ log .Printf ("Warning: Failed to execute GRANT: %s" , err )
253+ }
254+ }
255+
256+ return err
257+ }
258+
220259func getOrCreateLocalStack (pool * dt.Pool , networkID string ) * dt.Resource {
221260 resource , ok := pool .ContainerByName (DockerLocalStackName )
222261 if ok {
@@ -1030,6 +1069,155 @@ func TestIntegrationNoResults(t *testing.T) {
10301069 t .Fatal (err )
10311070 }
10321071}
1072+ func TestRoleHeaderSupport (t * testing.T ) {
1073+ tests := []struct {
1074+ name string
1075+ config Config
1076+ rawDSN string
1077+ query string
1078+ expectError bool
1079+ errorSubstr string
1080+ validateRows func (t * testing.T , rows * sql.Rows )
1081+ }{
1082+ {
1083+ name : "Valid hive admin role via Config" ,
1084+ config : Config {
1085+ ServerURI : * integrationServerFlag ,
1086+ Roles : map [string ]string {"hive" : "admin" },
1087+ },
1088+ query : "SHOW ROLES FROM hive" ,
1089+ expectError : false ,
1090+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1091+ foundAdmin := false
1092+ for rows .Next () {
1093+ var roleName string
1094+ err := rows .Scan (& roleName )
1095+ require .NoError (t , err )
1096+ if roleName == "admin" {
1097+ foundAdmin = true
1098+ }
1099+ }
1100+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1101+ },
1102+ },
1103+ {
1104+ config : Config {
1105+ ServerURI : * integrationServerFlag ,
1106+ Roles : map [string ]string {"tpch" : "NONE" , "memory" : "ALL" },
1107+ },
1108+ query : "SELECT 1" ,
1109+ expectError : false ,
1110+ },
1111+ {
1112+ name : "Valid special roles via Config" ,
1113+ config : Config {
1114+ ServerURI : * integrationServerFlag ,
1115+ Roles : map [string ]string {"tpch" : "NONE" , "memory" : "ALL" },
1116+ },
1117+ query : "SELECT 1" ,
1118+ expectError : false ,
1119+ },
1120+ {
1121+ name : "Valid hive admin role via DSN, not encoded url" ,
1122+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1123+ query : "SHOW ROLES FROM hive" ,
1124+ expectError : false ,
1125+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1126+ foundAdmin := false
1127+ for rows .Next () {
1128+ var roleName string
1129+ err := rows .Scan (& roleName )
1130+ require .NoError (t , err )
1131+ if roleName == "admin" {
1132+ foundAdmin = true
1133+ }
1134+ }
1135+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1136+ },
1137+ },
1138+ {
1139+ name : "Valid roles via DSN, url encoded" ,
1140+ rawDSN : * integrationServerFlag + "?roles=hive:admin" ,
1141+ query : "SHOW ROLES FROM hive" ,
1142+ expectError : false ,
1143+ validateRows : func (t * testing.T , rows * sql.Rows ) {
1144+ foundAdmin := false
1145+ for rows .Next () {
1146+ var roleName string
1147+ err := rows .Scan (& roleName )
1148+ require .NoError (t , err )
1149+ if roleName == "admin" {
1150+ foundAdmin = true
1151+ }
1152+ }
1153+ require .True (t , foundAdmin , "Expected to find 'admin' role in SHOW ROLES output" )
1154+ },
1155+ },
1156+ {
1157+ name : "No role - should fail to show roles" ,
1158+ config : Config {
1159+ ServerURI : * integrationServerFlag ,
1160+ },
1161+ query : "SHOW ROLES FROM hive" ,
1162+ expectError : true ,
1163+ errorSubstr : "Access Denied" ,
1164+ },
1165+ {
1166+ name : "Wrong role - should fail to show roles" ,
1167+ config : Config {
1168+ ServerURI : * integrationServerFlag ,
1169+ Roles : map [string ]string {"hive" : "ALL" },
1170+ },
1171+ query : "SHOW ROLES FROM hive" ,
1172+ expectError : true ,
1173+ errorSubstr : "Access Denied" ,
1174+ },
1175+ {
1176+ name : "Non-existent catalog role" ,
1177+ config : Config {
1178+ ServerURI : * integrationServerFlag ,
1179+ Roles : map [string ]string {"not-exist-catalog" : "role1" },
1180+ },
1181+ query : "SELECT 1" ,
1182+ expectError : true ,
1183+ errorSubstr : "USER_ERROR: Catalog" ,
1184+ },
1185+ }
1186+
1187+ for _ , tt := range tests {
1188+ t .Run (tt .name , func (t * testing.T ) {
1189+ var dns string
1190+ var err error
1191+
1192+ if tt .rawDSN != "" {
1193+ dns = tt .rawDSN
1194+ } else {
1195+ dns , err = tt .config .FormatDSN ()
1196+ if err != nil {
1197+ t .Fatal (err )
1198+ }
1199+ }
1200+
1201+ db := integrationOpen (t , dns )
1202+ defer db .Close ()
1203+
1204+ rows , err := db .Query (tt .query )
1205+
1206+ if tt .expectError {
1207+ require .Error (t , err )
1208+ if tt .errorSubstr != "" {
1209+ require .Contains (t , err .Error (), tt .errorSubstr )
1210+ }
1211+ } else {
1212+ require .NoError (t , err )
1213+ if tt .validateRows != nil && rows != nil {
1214+ defer rows .Close ()
1215+ tt .validateRows (t , rows )
1216+ }
1217+ }
1218+ })
1219+ }
1220+ }
10331221
10341222func TestIntegrationQueryParametersSelect (t * testing.T ) {
10351223 scenarios := []struct {
0 commit comments