@@ -280,6 +280,10 @@ type sub struct {
280280 id SubscriptionIdentifier
281281 completed chan struct {}
282282 lastWrite time.Time
283+ // executor is an optional argument that allows us to "schedule" the execution of an update on another thread
284+ // e.g. if we're using SSE/Multipart Fetch, we can run the execution on the goroutine of the http request
285+ // this ensures that ctx cancellation works properly when a client disconnects
286+ executor chan func ()
283287}
284288
285289func (r * Resolver ) executeSubscriptionUpdate (ctx * Context , sub * sub , sharedInput []byte ) {
@@ -495,6 +499,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
495499 id : add .id ,
496500 completed : add .completed ,
497501 lastWrite : time .Now (),
502+ executor : add .executor ,
498503 }
499504 if add .ctx .ExecutionOptions .SendHeartbeat {
500505 r .heartbeatSubscriptions [add .ctx ] = s
@@ -687,6 +692,9 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
687692 trig .inFlight = wg
688693 for c , s := range trig .subscriptions {
689694 c , s := c , s
695+ if err := c .ctx .Err (); err != nil {
696+ continue // no need to schedule an event update when the client already disconnected
697+ }
690698 skip , err := s .resolve .Filter .SkipEvent (c , data , r .triggerUpdateBuf )
691699 if err != nil {
692700 r .asyncErrorWriter .WriteError (c , err , s .resolve .Response , s .writer )
@@ -695,12 +703,22 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
695703 if skip {
696704 continue
697705 }
698-
699706 wg .Add (1 )
700- go func () {
701- defer wg .Done ()
707+ fn := func () {
702708 r .executeSubscriptionUpdate (c , s , data )
703- }()
709+ }
710+ go func (fn func ()) {
711+ defer wg .Done ()
712+ if s .executor != nil {
713+ select {
714+ case <- r .ctx .Done ():
715+ case <- c .ctx .Done ():
716+ case s .executor <- fn :
717+ }
718+ } else {
719+ fn ()
720+ }
721+ }(fn )
704722 }
705723}
706724
@@ -825,6 +843,7 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
825843 fmt .Printf ("resolver:trigger:subscribe:sync:%d:%d\n " , uniqueID , id .SubscriptionID )
826844 }
827845 completed := make (chan struct {})
846+ executor := make (chan func ())
828847 select {
829848 case <- r .ctx .Done ():
830849 return r .ctx .Err ()
@@ -838,25 +857,32 @@ func (r *Resolver) ResolveGraphQLSubscription(ctx *Context, subscription *GraphQ
838857 writer : writer ,
839858 id : id ,
840859 completed : completed ,
860+ executor : executor ,
841861 },
842862 }:
843863 }
844- select {
845- case <- r .ctx .Done ():
846- // the resolver ctx was canceled
847- // this will trigger the shutdown of the trigger (on another goroutine)
848- // as such, we need to wait for the trigger to be shutdown
849- // otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
850- // as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
864+ Loop: // execute fn on the main thread of the incoming request until ctx is done
865+ for {
851866 select {
852- case <- completed :
853- return r .ctx .Err ()
854- case <- time .After (time .Second * 5 ):
855- return r .ctx .Err ()
867+ case <- r .ctx .Done ():
868+ // the resolver ctx was canceled
869+ // this will trigger the shutdown of the trigger (on another goroutine)
870+ // as such, we need to wait for the trigger to be shutdown
871+ // otherwise we might experience a data race between trigger shutdown write (Complete) and reading bytes written to the writer
872+ // as the shutdown happens asynchronously, we want to wait here for at most 5 seconds or until the client ctx is done
873+ select {
874+ case <- completed :
875+ return r .ctx .Err ()
876+ case <- time .After (time .Second * 5 ):
877+ return r .ctx .Err ()
878+ case <- ctx .Context ().Done ():
879+ return ctx .Context ().Err ()
880+ }
856881 case <- ctx .Context ().Done ():
857- return ctx .Context ().Err ()
882+ break Loop
883+ case fn := <- executor :
884+ fn ()
858885 }
859- case <- ctx .Context ().Done ():
860886 }
861887 if r .options .Debug {
862888 fmt .Printf ("resolver:trigger:unsubscribe:sync:%d:%d\n " , uniqueID , id .SubscriptionID )
@@ -1008,6 +1034,7 @@ type addSubscription struct {
10081034 writer SubscriptionResponseWriter
10091035 id SubscriptionIdentifier
10101036 completed chan struct {}
1037+ executor chan func ()
10111038}
10121039
10131040type subscriptionEventKind int
0 commit comments