@@ -1008,82 +1008,118 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
1008
1008
}
1009
1009
}
1010
1010
1011
- func TestIntegrationQueryContextCancellation (t * testing.T ) {
1012
- err := RegisterCustomClient ("uncompressed" , & http.Client {Transport : & http.Transport {DisableCompression : true }})
1013
- if err != nil {
1011
+ func TestIntegrationQueryContext (t * testing.T ) {
1012
+ tests := []struct {
1013
+ name string
1014
+ timeout time.Duration
1015
+ expectedErrMsg string
1016
+ }{
1017
+ {
1018
+ name : "Context Cancellation" ,
1019
+ timeout : 0 ,
1020
+ expectedErrMsg : "canceled" ,
1021
+ },
1022
+ {
1023
+ name : "Context Deadline Exceeded" ,
1024
+ timeout : 3 * time .Second ,
1025
+ expectedErrMsg : "context deadline exceeded" ,
1026
+ },
1027
+ }
1028
+
1029
+ if err := RegisterCustomClient ("uncompressed" , & http.Client {Transport : & http.Transport {DisableCompression : true }}); err != nil {
1014
1030
t .Fatal (err )
1015
1031
}
1016
- dsn := * integrationServerFlag
1017
- dsn += "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
1032
+
1033
+ dsn := * integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
1018
1034
db := integrationOpen (t , dsn )
1019
1035
defer db .Close ()
1020
1036
1021
- ctx , cancel := context .WithCancel (context .Background ())
1022
- errCh := make (chan error , 3 )
1023
- done := make (chan struct {})
1024
- longQuery := "SELECT COUNT(*) FROM lineitem"
1025
- go func () {
1026
- // query will complete in ~7s unless cancelled
1027
- rows , err := db .QueryContext (ctx , longQuery )
1028
- if err != nil {
1029
- errCh <- err
1030
- return
1031
- }
1032
- rows .Next ()
1033
- if err = rows .Err (); err != nil {
1034
- errCh <- err
1035
- return
1036
- }
1037
- close (done )
1038
- }()
1039
-
1040
- // poll system.runtime.queries and wait for query to start working
1041
- var queryID string
1042
- pollCtx , pollCancel := context .WithTimeout (context .Background (), 1 * time .Second )
1043
- defer pollCancel ()
1044
- for {
1045
- row := db .QueryRowContext (pollCtx , "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?" , longQuery )
1046
- err := row .Scan (& queryID )
1047
- if err == nil {
1048
- break
1049
- }
1050
- if err != sql .ErrNoRows {
1051
- t .Fatal ("failed to read query id" , err )
1052
- }
1053
- if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1054
- t .Fatal ("query did not start in 1 second" )
1055
- }
1056
- }
1037
+ for _ , tt := range tests {
1038
+ t .Run (tt .name , func (t * testing.T ) {
1039
+ var ctx context.Context
1040
+ var cancel context.CancelFunc
1057
1041
1058
- cancel ()
1042
+ if tt .timeout == 0 {
1043
+ ctx , cancel = context .WithCancel (context .Background ())
1044
+ } else {
1045
+ ctx , cancel = context .WithTimeout (context .Background (), tt .timeout )
1046
+ }
1047
+ defer cancel ()
1059
1048
1060
- select {
1061
- case <- done :
1062
- t .Fatal ("unexpected query with cancelled context succeeded" )
1063
- break
1064
- case err = <- errCh :
1065
- if ! strings .Contains (err .Error (), "canceled" ) {
1066
- t .Fatal ("expected err to be canceled but got:" , err )
1067
- }
1068
- }
1049
+ errCh := make (chan error , 1 )
1050
+ done := make (chan struct {})
1051
+ longQuery := "SELECT COUNT(*) FROM lineitem"
1069
1052
1070
- // poll system.runtime.queries and wait for query to be cancelled
1071
- pollCtx , pollCancel = context .WithTimeout (context .Background (), 1 * time .Second )
1072
- defer pollCancel ()
1073
- for {
1074
- row := db .QueryRowContext (pollCtx , "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?" , queryID )
1075
- var state string
1076
- var code * string
1077
- err := row .Scan (& state , & code )
1078
- if err != nil {
1079
- t .Fatal ("failed to read query id" , err )
1080
- }
1081
- if state == "FAILED" && code != nil && * code == "USER_CANCELED" {
1082
- break
1083
- }
1084
- if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1085
- t .Fatal ("query was not cancelled in 1 second; state, code, err are:" , state , code , err )
1086
- }
1053
+ go func () {
1054
+ // query will complete in ~7s unless cancelled
1055
+ rows , err := db .QueryContext (ctx , longQuery )
1056
+ if err != nil {
1057
+ errCh <- err
1058
+ return
1059
+ }
1060
+ defer rows .Close ()
1061
+
1062
+ rows .Next ()
1063
+ if err = rows .Err (); err != nil {
1064
+ errCh <- err
1065
+ return
1066
+ }
1067
+ close (done )
1068
+ }()
1069
+
1070
+ // Poll system.runtime.queries to get the query ID
1071
+ var queryID string
1072
+ pollCtx , pollCancel := context .WithTimeout (context .Background (), 1 * time .Second )
1073
+ defer pollCancel ()
1074
+
1075
+ for {
1076
+ row := db .QueryRowContext (pollCtx , "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?" , longQuery )
1077
+ err := row .Scan (& queryID )
1078
+ if err == nil {
1079
+ break
1080
+ }
1081
+ if err != sql .ErrNoRows {
1082
+ t .Fatal ("failed to read query ID:" , err )
1083
+ }
1084
+ if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1085
+ t .Fatal ("query did not start in 1 second" )
1086
+ }
1087
+ }
1088
+
1089
+ if tt .timeout == 0 {
1090
+ cancel ()
1091
+ }
1092
+
1093
+ // Wait for the query to be canceled or completed
1094
+ select {
1095
+ case <- done :
1096
+ t .Fatal ("unexpected query succeeded despite cancellation or deadline" )
1097
+ case err := <- errCh :
1098
+ if ! strings .Contains (err .Error (), tt .expectedErrMsg ) {
1099
+ t .Fatalf ("expected error containing %q, but got: %v" , tt .expectedErrMsg , err )
1100
+ }
1101
+ }
1102
+
1103
+ // Poll system.runtime.queries to verify the query was canceled
1104
+ pollCtx , pollCancel = context .WithTimeout (context .Background (), 2 * time .Second )
1105
+ defer pollCancel ()
1106
+
1107
+ for {
1108
+ row := db .QueryRowContext (pollCtx , "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?" , queryID )
1109
+ var state string
1110
+ var code * string
1111
+ err := row .Scan (& state , & code )
1112
+ if err != nil {
1113
+ t .Fatal ("failed to read query state:" , err )
1114
+ }
1115
+ if state == "FAILED" && code != nil && * code == "USER_CANCELED" {
1116
+ return
1117
+ }
1118
+ if err = contextSleep (pollCtx , 100 * time .Millisecond ); err != nil {
1119
+ t .Fatalf ("query was not canceled in 2 seconds; state: %s, code: %v, err: %v" , state , code , err )
1120
+ }
1121
+ }
1122
+ })
1087
1123
}
1088
1124
}
1089
1125
0 commit comments