Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 12091f6

Browse files
authored
Merge pull request #565 from erizocosmico/fix/cancel-queries
*: fix query cancellation
2 parents cabd9bc + 07f1a72 commit 12091f6

File tree

9 files changed

+120
-19
lines changed

9 files changed

+120
-19
lines changed

server/handler.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) {
177177

178178
if s[1] == "query" {
179179
logrus.Infof("kill query: id %v", id)
180-
h.e.Catalog.Kill(id)
180+
h.e.Catalog.KillConnection(uint32(id))
181181
} else {
182182
logrus.Infof("kill connection: id %v, pid: %v", conn.ConnectionID, id)
183183
h.mu.Lock()
@@ -189,7 +189,7 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) {
189189
return false, errConnectionNotFound.New(conn.ConnectionID)
190190
}
191191

192-
h.e.Catalog.KillConnection(id)
192+
h.e.Catalog.KillConnection(uint32(id))
193193
h.sm.CloseConn(c)
194194
c.Close()
195195
}

server/handler_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ func TestHandlerKill(t *testing.T) {
195195
require.Equal(conn1, handler.c[1])
196196
require.Equal(conn2, handler.c[2])
197197

198+
assertNoConnProcesses(t, e, conn2.ConnectionID)
199+
198200
err = handler.ComQuery(conn2, "KILL 1", func(res *sqltypes.Result) error {
199201
return nil
200202
})
@@ -203,4 +205,15 @@ func TestHandlerKill(t *testing.T) {
203205
require.Len(handler.sm.sessions, 0)
204206
require.Len(handler.c, 1)
205207
require.Equal(conn1, handler.c[1])
208+
assertNoConnProcesses(t, e, conn2.ConnectionID)
209+
}
210+
211+
func assertNoConnProcesses(t *testing.T, e *sqle.Engine, conn uint32) {
212+
t.Helper()
213+
214+
for _, p := range e.Catalog.Processes() {
215+
if p.Connection == conn {
216+
t.Errorf("expecting no processes with connection id %d", conn)
217+
}
218+
}
206219
}

sql/parse/lock.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package parse
22

33
import (
44
"bufio"
5-
"fmt"
65
"io"
76
"strings"
87

@@ -11,7 +10,6 @@ import (
1110
)
1211

1312
func parseLockTables(ctx *sql.Context, query string) (sql.Node, error) {
14-
fmt.Println(query)
1513
var r = bufio.NewReader(strings.NewReader(query))
1614
var tables []*plan.TableLock
1715
err := parseFuncs{

sql/plan/exchange.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"sync"
@@ -103,7 +104,7 @@ func newExchangeRowIter(
103104
ctx: ctx,
104105
parallelism: parallelism,
105106
rows: make(chan sql.Row, parallelism),
106-
err: make(chan error),
107+
err: make(chan error, 1),
107108
started: false,
108109
tree: tree,
109110
partitions: iter,
@@ -149,6 +150,7 @@ func (it *exchangeRowIter) start() {
149150
for {
150151
select {
151152
case <-it.ctx.Done():
153+
it.err <- context.Canceled
152154
it.closeTokens()
153155
return
154156
case <-it.quit:
@@ -186,6 +188,7 @@ func (it *exchangeRowIter) iterPartitions(ch chan<- sql.Partition) {
186188
for {
187189
select {
188190
case <-it.ctx.Done():
191+
it.err <- context.Canceled
189192
return
190193
case <-it.quit:
191194
return
@@ -232,6 +235,7 @@ func (it *exchangeRowIter) iterPartition(p sql.Partition) {
232235
for {
233236
select {
234237
case <-it.ctx.Done():
238+
it.err <- context.Canceled
235239
return
236240
case <-it.quit:
237241
return
@@ -259,14 +263,14 @@ func (it *exchangeRowIter) Next() (sql.Row, error) {
259263
}
260264

261265
select {
266+
case err := <-it.err:
267+
_ = it.Close()
268+
return nil, err
262269
case row, ok := <-it.rows:
263270
if !ok {
264271
return nil, io.EOF
265272
}
266273
return row, nil
267-
case err := <-it.err:
268-
_ = it.Close()
269-
return nil, err
270274
}
271275
}
272276

sql/plan/exchange_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"testing"
@@ -56,6 +57,39 @@ func TestExchange(t *testing.T) {
5657
}
5758
}
5859

60+
func TestExchangeCancelled(t *testing.T) {
61+
children := NewProject(
62+
[]sql.Expression{
63+
expression.NewGetField(0, sql.Text, "partition", false),
64+
expression.NewArithmetic(
65+
expression.NewGetField(1, sql.Int64, "val", false),
66+
expression.NewLiteral(int64(1), sql.Int64),
67+
"+",
68+
),
69+
},
70+
NewFilter(
71+
expression.NewLessThan(
72+
expression.NewGetField(1, sql.Int64, "val", false),
73+
expression.NewLiteral(int64(4), sql.Int64),
74+
),
75+
&partitionable{nil, 3, 6},
76+
),
77+
)
78+
79+
exchange := NewExchange(3, children)
80+
require := require.New(t)
81+
82+
c, cancel := context.WithCancel(context.Background())
83+
ctx := sql.NewContext(c)
84+
cancel()
85+
86+
iter, err := exchange.RowIter(ctx)
87+
require.NoError(err)
88+
89+
_, err = iter.Next()
90+
require.Equal(context.Canceled, err)
91+
}
92+
5993
type partitionable struct {
6094
sql.Node
6195
partitions int

sql/plan/resolved_table.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"context"
45
"io"
56

67
"gopkg.in/src-d/go-mysql-server.v0/sql"
@@ -62,6 +63,12 @@ type tableIter struct {
6263
}
6364

6465
func (i *tableIter) Next() (sql.Row, error) {
66+
select {
67+
case <-i.ctx.Done():
68+
return nil, context.Canceled
69+
default:
70+
}
71+
6572
if i.partition == nil {
6673
partition, err := i.partitions.Next()
6774
if err != nil {

sql/plan/resolved_table_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package plan
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"testing"
@@ -31,6 +32,22 @@ func TestResolvedTable(t *testing.T) {
3132
}
3233
}
3334

35+
func TestResolvedTableCancelled(t *testing.T) {
36+
var require = require.New(t)
37+
38+
table := NewResolvedTable(newTableTest("test"))
39+
require.NotNil(table)
40+
41+
ctx, cancel := context.WithCancel(context.Background())
42+
cancel()
43+
44+
iter, err := table.RowIter(sql.NewContext(ctx))
45+
require.NoError(err)
46+
47+
_, err = iter.Next()
48+
require.Equal(context.Canceled, err)
49+
}
50+
3451
func newTableTest(source string) sql.Table {
3552
schema := []*sql.Column{
3653
{Name: "col1", Type: sql.Int32, Source: source, Default: int32(0), Nullable: false},

sql/processlist.go

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,14 @@ func (pl *ProcessList) Kill(pid uint64) {
160160
pl.Done(pid)
161161
}
162162

163-
// KillConnection kills all processes that have the same connection as the one
164-
// of the process with the given process id. If the process does not exist, it
165-
// will do nothing.
166-
func (pl *ProcessList) KillConnection(pid uint64) {
163+
// KillConnection kills all processes from the given connection.
164+
func (pl *ProcessList) KillConnection(conn uint32) {
167165
pl.mu.Lock()
168166
defer pl.mu.Unlock()
169167

170-
proc, ok := pl.procs[pid]
171-
if !ok {
172-
return
173-
}
174-
175-
conn := proc.Connection
176168
for pid, proc := range pl.procs {
177169
if proc.Connection == conn {
178-
proc.Kill()
170+
proc.Done()
179171
delete(pl.procs, pid)
180172
}
181173
}

sql/processlist_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,39 @@ func sortByPid(slice []Process) {
8585
return slice[i].Pid < slice[j].Pid
8686
})
8787
}
88+
89+
func TestKillConnection(t *testing.T) {
90+
pl := NewProcessList()
91+
92+
s1 := NewSession("", "", "", 1)
93+
s2 := NewSession("", "", "", 2)
94+
95+
var killed = make(map[uint64]bool)
96+
for i := uint64(1); i <= 3; i++ {
97+
// Odds get s1, evens get s2
98+
s := s1
99+
if i%2 == 0 {
100+
s = s2
101+
}
102+
103+
_, err := pl.AddProcess(
104+
NewContext(context.Background(), WithPid(i), WithSession(s)),
105+
QueryProcess,
106+
"foo",
107+
)
108+
require.NoError(t, err)
109+
110+
i := i
111+
pl.procs[i].Kill = func() {
112+
killed[i] = true
113+
}
114+
}
115+
116+
pl.KillConnection(1)
117+
require.Len(t, pl.procs, 1)
118+
119+
// Odds should have been killed
120+
require.True(t, killed[1])
121+
require.False(t, killed[2])
122+
require.True(t, killed[3])
123+
}

0 commit comments

Comments
 (0)