Skip to content

Commit acdaf47

Browse files
committed
feat: execute subscription writes on main goroutine in synchronous resolve subscriptions
1 parent f7a31e8 commit acdaf47

File tree

2 files changed

+49
-17
lines changed

2 files changed

+49
-17
lines changed

v2/pkg/engine/resolve/resolve.go

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,10 @@ type sub struct {
280280
id SubscriptionIdentifier
281281
completed chan struct{}
282282
lastWrite time.Time
283+
// executor is an optional argument that allows us to "schedule" the execution of an update on another thread
284+
// e.g. if we're using SSE/Multipart Fetch, we can run the execution on the goroutine of the http request
285+
// this ensures that ctx cancellation works properly when a client disconnects
286+
executor chan func()
283287
}
284288

285289
func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput []byte) {
@@ -495,6 +499,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
495499
id: add.id,
496500
completed: add.completed,
497501
lastWrite: time.Now(),
502+
executor: add.executor,
498503
}
499504
if add.ctx.ExecutionOptions.SendHeartbeat {
500505
r.heartbeatSubscriptions[add.ctx] = s
@@ -687,6 +692,9 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
687692
trig.inFlight = wg
688693
for c, s := range trig.subscriptions {
689694
c, s := c, s
695+
if err := c.ctx.Err(); err != nil {
696+
continue // no need to schedule an event update when the client already disconnected
697+
}
690698
skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf)
691699
if err != nil {
692700
r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer)
@@ -695,12 +703,22 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
695703
if skip {
696704
continue
697705
}
698-
699706
wg.Add(1)
700-
go func() {
701-
defer wg.Done()
707+
fn := func() {
702708
r.executeSubscriptionUpdate(c, s, data)
703-
}()
709+
}
710+
go func(fn func()) {
711+
defer wg.Done()
712+
if s.executor != nil {
713+
select {
714+
case <-r.ctx.Done():
715+
case <-c.ctx.Done():
716+
case s.executor <- fn:
717+
}
718+
} else {
719+
fn()
720+
}
721+
}(fn)
704722
}
705723
}
706724

@@ -825,6 +843,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
825843
fmt.Printf("resolver:trigger:subscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
826844
}
827845
completed := make(chan struct{})
846+
executor := make(chan func())
828847
select {
829848
case <-r.ctx.Done():
830849
return r.ctx.Err()
@@ -838,25 +857,32 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
838857
writer: writer,
839858
id: id,
840859
completed: completed,
860+
executor: executor,
841861
},
842862
}:
843863
}
844-
select {
845-
case <-r.ctx.Done():
846-
// the resolver ctx was canceled
847-
// this will trigger the shutdown of the trigger (on another goroutine)
848-
// as such, we need to wait for the trigger to be shutdown
849-
// otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
850-
// as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
864+
Loop: // execute fn on the main thread of the incoming request until ctx is done
865+
for {
851866
select {
852-
case <-completed:
853-
return r.ctx.Err()
854-
case <-time.After(time.Second * 5):
855-
return r.ctx.Err()
867+
case <-r.ctx.Done():
868+
// the resolver ctx was canceled
869+
// this will trigger the shutdown of the trigger (on another goroutine)
870+
// as such, we need to wait for the trigger to be shutdown
871+
// otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
872+
// as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
873+
select {
874+
case <-completed:
875+
return r.ctx.Err()
876+
case <-time.After(time.Second * 5):
877+
return r.ctx.Err()
878+
case <-ctx.Context().Done():
879+
return ctx.Context().Err()
880+
}
856881
case <-ctx.Context().Done():
857-
return ctx.Context().Err()
882+
break Loop
883+
case fn := <-executor:
884+
fn()
858885
}
859-
case <-ctx.Context().Done():
860886
}
861887
if r.options.Debug {
862888
fmt.Printf("resolver:trigger:unsubscribe:sync:%d:%d\n", uniqueID, id.SubscriptionID)
@@ -1008,6 +1034,7 @@ type addSubscription struct {
10081034
writer SubscriptionResponseWriter
10091035
id SubscriptionIdentifier
10101036
completed chan struct{}
1037+
executor chan func()
10111038
}
10121039

10131040
type subscriptionEventKind int

v2/pkg/engine/resolve/resolve_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5201,6 +5201,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
52015201
resolver := newResolver(c)
52025202

52035203
ctx := &Context{
5204+
ctx: context.Background(),
52045205
Variables: astjson.MustParseBytes([]byte(`{"id":1}`)),
52055206
}
52065207

@@ -5296,6 +5297,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
52965297
resolver := newResolver(c)
52975298

52985299
ctx := &Context{
5300+
ctx: context.Background(),
52995301
Variables: astjson.MustParseBytes([]byte(`{"id":2}`)),
53005302
}
53015303

@@ -5389,6 +5391,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
53895391
resolver := newResolver(c)
53905392

53915393
ctx := &Context{
5394+
ctx: context.Background(),
53925395
Variables: astjson.MustParseBytes([]byte(`{"ids":[1,2]}`)),
53935396
}
53945397

@@ -5487,6 +5490,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
54875490
resolver := newResolver(c)
54885491

54895492
ctx := &Context{
5493+
ctx: context.Background(),
54905494
Variables: astjson.MustParseBytes([]byte(`{"ids":["2","3"]}`)),
54915495
}
54925496

@@ -5595,6 +5599,7 @@ func Test_ResolveGraphQLSubscriptionWithFilter(t *testing.T) {
55955599
resolver := newResolver(c)
55965600

55975601
ctx := &Context{
5602+
ctx: context.Background(),
55985603
Variables: astjson.MustParseBytes([]byte(`{"a":[1,2],"b":[3,4]}`)),
55995604
}
56005605

0 commit comments

Comments
 (0)