Skip to content

Commit 3c2f549

Browse files
Flgado authored and nineinchnick committed
make sure when query is executed with a deadline, underlying trino query is canceled
1 parent 6acb774 commit 3c2f549

File tree

2 files changed

+105
-69
lines changed

2 files changed

+105
-69
lines changed

trino/integration_test.go

Lines changed: 104 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -1008,82 +1008,118 @@ func TestIntegrationUnsupportedHeader(t *testing.T) {
10081008
}
10091009
}
10101010

1011-
func TestIntegrationQueryContextCancellation(t *testing.T) {
1012-
err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}})
1013-
if err != nil {
1011+
func TestIntegrationQueryContext(t *testing.T) {
1012+
tests := []struct {
1013+
name string
1014+
timeout time.Duration
1015+
expectedErrMsg string
1016+
}{
1017+
{
1018+
name: "Context Cancellation",
1019+
timeout: 0,
1020+
expectedErrMsg: "canceled",
1021+
},
1022+
{
1023+
name: "Context Deadline Exceeded",
1024+
timeout: 3 * time.Second,
1025+
expectedErrMsg: "context deadline exceeded",
1026+
},
1027+
}
1028+
1029+
if err := RegisterCustomClient("uncompressed", &http.Client{Transport: &http.Transport{DisableCompression: true}}); err != nil {
10141030
t.Fatal(err)
10151031
}
1016-
dsn := *integrationServerFlag
1017-
dsn += "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
1032+
1033+
dsn := *integrationServerFlag + "?catalog=tpch&schema=sf100&source=cancel-test&custom_client=uncompressed"
10181034
db := integrationOpen(t, dsn)
10191035
defer db.Close()
10201036

1021-
ctx, cancel := context.WithCancel(context.Background())
1022-
errCh := make(chan error, 3)
1023-
done := make(chan struct{})
1024-
longQuery := "SELECT COUNT(*) FROM lineitem"
1025-
go func() {
1026-
// query will complete in ~7s unless cancelled
1027-
rows, err := db.QueryContext(ctx, longQuery)
1028-
if err != nil {
1029-
errCh <- err
1030-
return
1031-
}
1032-
rows.Next()
1033-
if err = rows.Err(); err != nil {
1034-
errCh <- err
1035-
return
1036-
}
1037-
close(done)
1038-
}()
1039-
1040-
// poll system.runtime.queries and wait for query to start working
1041-
var queryID string
1042-
pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second)
1043-
defer pollCancel()
1044-
for {
1045-
row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery)
1046-
err := row.Scan(&queryID)
1047-
if err == nil {
1048-
break
1049-
}
1050-
if err != sql.ErrNoRows {
1051-
t.Fatal("failed to read query id", err)
1052-
}
1053-
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
1054-
t.Fatal("query did not start in 1 second")
1055-
}
1056-
}
1037+
for _, tt := range tests {
1038+
t.Run(tt.name, func(t *testing.T) {
1039+
var ctx context.Context
1040+
var cancel context.CancelFunc
10571041

1058-
cancel()
1042+
if tt.timeout == 0 {
1043+
ctx, cancel = context.WithCancel(context.Background())
1044+
} else {
1045+
ctx, cancel = context.WithTimeout(context.Background(), tt.timeout)
1046+
}
1047+
defer cancel()
10591048

1060-
select {
1061-
case <-done:
1062-
t.Fatal("unexpected query with cancelled context succeeded")
1063-
break
1064-
case err = <-errCh:
1065-
if !strings.Contains(err.Error(), "canceled") {
1066-
t.Fatal("expected err to be canceled but got:", err)
1067-
}
1068-
}
1049+
errCh := make(chan error, 1)
1050+
done := make(chan struct{})
1051+
longQuery := "SELECT COUNT(*) FROM lineitem"
10691052

1070-
// poll system.runtime.queries and wait for query to be cancelled
1071-
pollCtx, pollCancel = context.WithTimeout(context.Background(), 1*time.Second)
1072-
defer pollCancel()
1073-
for {
1074-
row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID)
1075-
var state string
1076-
var code *string
1077-
err := row.Scan(&state, &code)
1078-
if err != nil {
1079-
t.Fatal("failed to read query id", err)
1080-
}
1081-
if state == "FAILED" && code != nil && *code == "USER_CANCELED" {
1082-
break
1083-
}
1084-
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
1085-
t.Fatal("query was not cancelled in 1 second; state, code, err are:", state, code, err)
1086-
}
1053+
go func() {
1054+
// query will complete in ~7s unless cancelled
1055+
rows, err := db.QueryContext(ctx, longQuery)
1056+
if err != nil {
1057+
errCh <- err
1058+
return
1059+
}
1060+
defer rows.Close()
1061+
1062+
rows.Next()
1063+
if err = rows.Err(); err != nil {
1064+
errCh <- err
1065+
return
1066+
}
1067+
close(done)
1068+
}()
1069+
1070+
// Poll system.runtime.queries to get the query ID
1071+
var queryID string
1072+
pollCtx, pollCancel := context.WithTimeout(context.Background(), 1*time.Second)
1073+
defer pollCancel()
1074+
1075+
for {
1076+
row := db.QueryRowContext(pollCtx, "SELECT query_id FROM system.runtime.queries WHERE state = 'RUNNING' AND source = 'cancel-test' AND query = ?", longQuery)
1077+
err := row.Scan(&queryID)
1078+
if err == nil {
1079+
break
1080+
}
1081+
if err != sql.ErrNoRows {
1082+
t.Fatal("failed to read query ID:", err)
1083+
}
1084+
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
1085+
t.Fatal("query did not start in 1 second")
1086+
}
1087+
}
1088+
1089+
if tt.timeout == 0 {
1090+
cancel()
1091+
}
1092+
1093+
// Wait for the query to be canceled or completed
1094+
select {
1095+
case <-done:
1096+
t.Fatal("unexpected query succeeded despite cancellation or deadline")
1097+
case err := <-errCh:
1098+
if !strings.Contains(err.Error(), tt.expectedErrMsg) {
1099+
t.Fatalf("expected error containing %q, but got: %v", tt.expectedErrMsg, err)
1100+
}
1101+
}
1102+
1103+
// Poll system.runtime.queries to verify the query was canceled
1104+
pollCtx, pollCancel = context.WithTimeout(context.Background(), 2*time.Second)
1105+
defer pollCancel()
1106+
1107+
for {
1108+
row := db.QueryRowContext(pollCtx, "SELECT state, error_code FROM system.runtime.queries WHERE query_id = ?", queryID)
1109+
var state string
1110+
var code *string
1111+
err := row.Scan(&state, &code)
1112+
if err != nil {
1113+
t.Fatal("failed to read query state:", err)
1114+
}
1115+
if state == "FAILED" && code != nil && *code == "USER_CANCELED" {
1116+
return
1117+
}
1118+
if err = contextSleep(pollCtx, 100*time.Millisecond); err != nil {
1119+
t.Fatalf("query was not canceled in 2 seconds; state: %s, code: %v, err: %v", state, code, err)
1120+
}
1121+
}
1122+
})
10871123
}
10881124
}
10891125

trino/trino.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ func (qr *driverRows) fetch() error {
13611361
// Channel was closed, which means the statement
13621362
// or rows were closed.
13631363
err = io.EOF
1364-
} else if err == context.Canceled {
1364+
} else if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
13651365
qr.Close()
13661366
}
13671367
qr.err = err

0 commit comments

Comments
 (0)