@@ -185,12 +185,10 @@ type Writer struct {
185
185
// If nil, DefaultTransport is used.
186
186
Transport RoundTripper
187
187
188
- // Atomic flag indicating whether the writer has been closed.
189
- closed uint32
190
- group sync.WaitGroup
191
-
192
188
// Manages the current set of partition-topic writers.
189
+ group sync.WaitGroup
193
190
mutex sync.Mutex
191
+ closed bool
194
192
writers map [topicPartition ]* partitionWriter
195
193
196
194
// writer stats are all made of atomic values, no need for synchronization.
@@ -505,13 +503,47 @@ func NewWriter(config WriterConfig) *Writer {
505
503
return w
506
504
}
507
505
506
+ // enter is called by WriteMessages to indicate that a new inflight operation
507
+ // has started, which helps synchronize with Close and ensure that the method
508
+ // does not return until all inflight operations were completed.
509
+ func (w * Writer ) enter () bool {
510
+ w .mutex .Lock ()
511
+ defer w .mutex .Unlock ()
512
+ if w .closed {
513
+ return false
514
+ }
515
+ w .group .Add (1 )
516
+ return true
517
+ }
518
+
519
+ // leave is called by WriteMessages to indicate that the inflight operation has
520
+ // completed.
521
+ func (w * Writer ) leave () { w .group .Done () }
522
+
523
+ // spawn starts an new asynchronous operation on the writer. This method is used
524
+ // instead of starting goroutines inline to help manage the state of the
525
+ // writer's wait group. The wait group is used to block Close calls until all
526
+ // inflight operations have completed, therefore automatically including those
527
+ // started with calls to spawn.
528
+ func (w * Writer ) spawn (f func ()) {
529
+ w .group .Add (1 )
530
+ go func () {
531
+ defer w .group .Done ()
532
+ f ()
533
+ }()
534
+ }
535
+
508
536
// Close flushes pending writes, and waits for all writes to complete before
509
537
// returning. Calling Close also prevents new writes from being submitted to
510
538
// the writer, further calls to WriteMessages and the like will fail with
511
539
// io.ErrClosedPipe.
512
540
func (w * Writer ) Close () error {
513
- w .markClosed ()
514
541
w .mutex .Lock ()
542
+ // Marking the writer as closed here causes future calls to WriteMessages to
543
+ // fail with io.ErrClosedPipe. Mutation of this field is synchronized on the
544
+ // writer's mutex to ensure that no more increments of the wait group are
545
+ // performed afterwards (which could otherwise race with the Wait below).
546
+ w .closed = true
515
547
516
548
// close all writers to trigger any pending batches
517
549
for _ , writer := range w .writers {
@@ -561,12 +593,10 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
561
593
return errors .New ("kafka.(*Writer).WriteMessages: cannot create a kafka writer with a nil address" )
562
594
}
563
595
564
- w .group .Add (1 )
565
- defer w .group .Done ()
566
-
567
- if w .isClosed () {
596
+ if ! w .enter () {
568
597
return io .ErrClosedPipe
569
598
}
599
+ defer w .leave ()
570
600
571
601
if len (msgs ) == 0 {
572
602
return nil
@@ -720,14 +750,6 @@ func (w *Writer) partitions(ctx context.Context, topic string) (int, error) {
720
750
return 0 , UnknownTopicOrPartition
721
751
}
722
752
723
- func (w * Writer ) markClosed () {
724
- atomic .StoreUint32 (& w .closed , 1 )
725
- }
726
-
727
- func (w * Writer ) isClosed () bool {
728
- return atomic .LoadUint32 (& w .closed ) != 0
729
- }
730
-
731
753
func (w * Writer ) client (timeout time.Duration ) * Client {
732
754
return & Client {
733
755
Addr : w .Addr ,
@@ -936,8 +958,6 @@ type partitionWriter struct {
936
958
mutex sync.Mutex
937
959
currBatch * writeBatch
938
960
939
- group sync.WaitGroup
940
-
941
961
// reference to the writer that owns this batch. Used for the produce logic
942
962
// as well as stat tracking
943
963
w * Writer
@@ -949,12 +969,7 @@ func newPartitionWriter(w *Writer, key topicPartition) *partitionWriter {
949
969
queue : newBatchQueue (10 ),
950
970
w : w ,
951
971
}
952
- go func () {
953
- writer .group .Add (1 )
954
- defer writer .group .Done ()
955
- writer .writeBatches ()
956
- }()
957
-
972
+ w .spawn (writer .writeBatches )
958
973
return writer
959
974
}
960
975
@@ -970,14 +985,10 @@ func (ptw *partitionWriter) writeBatches() {
970
985
}
971
986
972
987
ptw .writeBatch (batch )
973
-
974
988
}
975
989
}
976
990
977
991
func (ptw * partitionWriter ) writeMessages (msgs []Message , indexes []int32 ) map [* writeBatch ][]int32 {
978
- ptw .group .Add (1 )
979
- defer ptw .group .Done ()
980
-
981
992
ptw .mutex .Lock ()
982
993
defer ptw .mutex .Unlock ()
983
994
@@ -1019,11 +1030,7 @@ func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*
1019
1030
// ptw.w can be accessed here because this is called with the lock ptw.mutex already held.
1020
1031
func (ptw * partitionWriter ) newWriteBatch () * writeBatch {
1021
1032
batch := newWriteBatch (time .Now (), ptw .w .batchTimeout ())
1022
- ptw .group .Add (1 )
1023
- go func () {
1024
- defer ptw .group .Done ()
1025
- ptw .awaitBatch (batch )
1026
- }()
1033
+ ptw .w .spawn (func () { ptw .awaitBatch (batch ) })
1027
1034
return batch
1028
1035
}
1029
1036
@@ -1150,9 +1157,8 @@ func (ptw *partitionWriter) close() {
1150
1157
ptw .currBatch = nil
1151
1158
batch .trigger ()
1152
1159
}
1153
- ptw .queue .Close ()
1154
1160
1155
- ptw .group . Wait ()
1161
+ ptw .queue . Close ()
1156
1162
}
1157
1163
1158
1164
type writeBatch struct {
0 commit comments