66 "fmt"
77 "sort"
88 "sync"
9- "sync/atomic"
109 "time"
1110
1211 "google.golang.org/grpc"
@@ -70,7 +69,7 @@ func newClient(
7069 balancer .OnUpdate (c .updateNodes )
7170 }
7271 if idleThreshold := config .IdleThreshold (); idleThreshold > 0 {
73- c .spawnedGoroutines .Add (1 )
72+ c .wg .Add (1 )
7473 go c .internalPoolGC (ctx , idleThreshold )
7574 }
7675 onDone (c .limit )
@@ -95,8 +94,7 @@ type Client struct {
9594 waitq * list.List // list<*chan *session>
9695 waitChPool sync.Pool
9796 testHookGetWaitCh func () // nil except some tests.
98- spawnedGoroutines sync.WaitGroup
99- closed uint32
97+ wg sync.WaitGroup
10098 done chan struct {}
10199}
102100
@@ -128,6 +126,9 @@ func (c *Client) updateNodes(ctx context.Context, endpoints []endpoint.Info) {
128126 return nodeIDs [i ] < nodeIDs [j ]
129127 })
130128 c .mu .WithLock (func () {
129+ if c .isClosed () {
130+ return
131+ }
131132 for nodeID := range c .nodes {
132133 if sort .Search (len (nodeIDs ), func (i int ) bool {
133134 return nodeIDs [i ] >= nodeID
@@ -145,10 +146,6 @@ func (c *Client) updateNodes(ctx context.Context, endpoints []endpoint.Info) {
145146}
146147
147148func (c * Client ) createSession (ctx context.Context , opts ... createSessionOption ) (s * session , err error ) {
148- if c .isClosed () {
149- return nil , errClosedClient
150- }
151-
152149 options := createSessionOptions {}
153150 for _ , o := range opts {
154151 o (& options )
@@ -171,10 +168,6 @@ func (c *Client) createSession(ctx context.Context, opts ...createSessionOption)
171168
172169 ch := make (chan result )
173170
174- if c .isClosed () {
175- return nil , xerrors .WithStackTrace (errClosedClient )
176- }
177-
178171 select {
179172 case <- c .done :
180173 return nil , xerrors .WithStackTrace (errClosedClient )
@@ -183,54 +176,59 @@ func (c *Client) createSession(ctx context.Context, opts ...createSessionOption)
183176 return nil , xerrors .WithStackTrace (ctx .Err ())
184177
185178 default :
186- c .spawnedGoroutines .Add (1 )
187- go func () {
188- defer c .spawnedGoroutines .Done ()
189-
190- var (
191- s * session
192- err error
193- )
194-
195- createSessionCtx := xcontext .WithoutDeadline (ctx )
196-
197- if timeout := c .config .CreateSessionTimeout (); timeout > 0 {
198- var cancel context.CancelFunc
199- createSessionCtx , cancel = context .WithTimeout (createSessionCtx , timeout )
200- defer cancel ()
179+ c .mu .WithLock (func () {
180+ if c .isClosed () {
181+ return
201182 }
183+ c .wg .Add (1 )
184+ go func () {
185+ defer c .wg .Done ()
202186
203- s , err = c .build (createSessionCtx )
187+ var (
188+ s * session
189+ err error
190+ )
204191
205- closeSession := func (s * session ) {
206- if s == nil {
207- return
208- }
192+ createSessionCtx := xcontext .WithoutDeadline (ctx )
209193
210- closeSessionCtx := xcontext .WithoutDeadline (ctx )
211-
212- if timeout := c .config .DeleteTimeout (); timeout > 0 {
194+ if timeout := c .config .CreateSessionTimeout (); timeout > 0 {
213195 var cancel context.CancelFunc
214- createSessionCtx , cancel = context .WithTimeout (closeSessionCtx , timeout )
196+ createSessionCtx , cancel = context .WithTimeout (createSessionCtx , timeout )
215197 defer cancel ()
216198 }
217199
218- _ = s .Close (ctx )
219- }
200+ s , err = c .build (createSessionCtx )
220201
221- select {
222- case ch <- result {
223- s : s ,
224- err : err ,
225- }: // nop
202+ closeSession := func (s * session ) {
203+ if s == nil {
204+ return
205+ }
226206
227- case <- c .done :
228- closeSession (s )
207+ closeSessionCtx := xcontext .WithoutDeadline (ctx )
229208
230- case <- ctx .Done ():
231- closeSession (s )
232- }
233- }()
209+ if timeout := c .config .DeleteTimeout (); timeout > 0 {
210+ var cancel context.CancelFunc
211+ createSessionCtx , cancel = context .WithTimeout (closeSessionCtx , timeout )
212+ defer cancel ()
213+ }
214+
215+ _ = s .Close (ctx )
216+ }
217+
218+ select {
219+ case ch <- result {
220+ s : s ,
221+ err : err ,
222+ }: // nop
223+
224+ case <- c .done :
225+ closeSession (s )
226+
227+ case <- ctx .Done ():
228+ closeSession (s )
229+ }
230+ }()
231+ })
234232 }
235233
236234 select {
@@ -280,6 +278,9 @@ func (c *Client) CreateSession(ctx context.Context, opts ...table.Option) (_ tab
280278 if c == nil {
281279 return nil , xerrors .WithStackTrace (errNilClient )
282280 }
281+ if c .isClosed () {
282+ return nil , xerrors .WithStackTrace (errClosedClient )
283+ }
283284 var s * session
284285 createSession := func (ctx context.Context ) (* session , error ) {
285286 s , err = c .createSession (ctx ,
@@ -331,7 +332,12 @@ func (c *Client) CreateSession(ctx context.Context, opts ...table.Option) (_ tab
331332}
332333
333334func (c * Client ) isClosed () bool {
334- return atomic .LoadUint32 (& c .closed ) != 0
335+ select {
336+ case <- c .done :
337+ return true
338+ default :
339+ return false
340+ }
335341}
336342
337343// c.mu must NOT be held.
@@ -611,16 +617,19 @@ func (c *Client) Close(ctx context.Context) (err error) {
611617 return xerrors .WithStackTrace (errNilClient )
612618 }
613619
614- onDone := trace .TableOnClose (c .config .Trace (), & ctx )
615- defer func () {
616- onDone (err )
617- }()
620+ c .mu .WithLock (func () {
621+ select {
622+ case <- c .done :
623+ return
624+
625+ default :
626+ close (c .done )
618627
619- if atomic .CompareAndSwapUint32 (& c .closed , 0 , 1 ) {
620- close (c .done )
628+ onDone := trace .TableOnClose (c .config .Trace (), & ctx )
629+ defer func () {
630+ onDone (err )
631+ }()
621632
622- wg := sync.WaitGroup {}
623- c .mu .WithLock (func () {
624633 c .limit = 0
625634
626635 for el := c .waitq .Front (); el != nil ; el = el .Next () {
@@ -629,19 +638,18 @@ func (c *Client) Close(ctx context.Context) (err error) {
629638 }
630639
631640 for e := c .idle .Front (); e != nil ; e = e .Next () {
632- wg .Add (1 )
633641 s := e .Value .(* session )
634642 s .SetStatus (table .SessionClosing )
643+ c .wg .Add (1 )
635644 go func () {
636- defer wg .Done ()
645+ defer c . wg .Done ()
637646 c .internalPoolSyncCloseSession (ctx , s )
638647 }()
639648 }
640- })
641- wg . Wait ( )
649+ }
650+ } )
642651
643- c .spawnedGoroutines .Wait ()
644- }
652+ c .wg .Wait ()
645653
646654 return nil
647655}
@@ -689,7 +697,7 @@ func (c *Client) DoTx(ctx context.Context, op table.TxOperation, opts ...table.O
689697}
690698
691699func (c * Client ) internalPoolGC (ctx context.Context , idleThreshold time.Duration ) {
692- defer c .spawnedGoroutines .Done ()
700+ defer c .wg .Done ()
693701 timer := timeutil .NewTimer (idleThreshold )
694702
695703 for {
@@ -699,6 +707,9 @@ func (c *Client) internalPoolGC(ctx context.Context, idleThreshold time.Duration
699707
700708 case <- timer .C ():
701709 c .mu .WithLock (func () {
710+ if c .isClosed () {
711+ return
712+ }
702713 for e := c .idle .Front (); e != nil ; e = e .Next () {
703714 s := e .Value .(* session )
704715 info , has := c .index [s ]
@@ -712,8 +723,8 @@ func (c *Client) internalPoolGC(ctx context.Context, idleThreshold time.Duration
712723 c .internalPoolAsyncCloseSession (ctx , s )
713724 }
714725 }
726+ timer .Reset (idleThreshold / 2 )
715727 })
716- timer .Reset (idleThreshold / 2 )
717728 }
718729 }
719730}
@@ -806,11 +817,16 @@ func (c *Client) internalPoolNotify(s *session) (notified bool) {
806817
807818func (c * Client ) internalPoolAsyncCloseSession (ctx context.Context , s * session ) {
808819 s .SetStatus (table .SessionClosing )
809- c .spawnedGoroutines .Add (1 )
810- go func () {
811- defer c .spawnedGoroutines .Done ()
812- c .internalPoolSyncCloseSession (ctx , s )
813- }()
820+ c .mu .WithLock (func () {
821+ if c .isClosed () {
822+ return
823+ }
824+ c .wg .Add (1 )
825+ go func () {
826+ defer c .wg .Done ()
827+ c .internalPoolSyncCloseSession (ctx , s )
828+ }()
829+ })
814830}
815831
816832func (c * Client ) internalPoolSyncCloseSession (ctx context.Context , s * session ) {
0 commit comments