Skip to content

Commit 063303c

Browse files
author
Achille
authored
fix writer async close (#805)
1 parent 44a678e commit 063303c

File tree

1 file changed

+42
-36
lines changed

1 file changed

+42
-36
lines changed

writer.go

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,10 @@ type Writer struct {
185185
// If nil, DefaultTransport is used.
186186
Transport RoundTripper
187187

188-
// Atomic flag indicating whether the writer has been closed.
189-
closed uint32
190-
group sync.WaitGroup
191-
192188
// Manages the current set of partition-topic writers.
189+
group sync.WaitGroup
193190
mutex sync.Mutex
191+
closed bool
194192
writers map[topicPartition]*partitionWriter
195193

196194
// writer stats are all made of atomic values, no need for synchronization.
@@ -505,13 +503,47 @@ func NewWriter(config WriterConfig) *Writer {
505503
return w
506504
}
507505

506+
// enter is called by WriteMessages to indicate that a new inflight operation
507+
// has started, which helps synchronize with Close and ensure that the method
508+
// does not return until all inflight operations were completed.
509+
func (w *Writer) enter() bool {
510+
w.mutex.Lock()
511+
defer w.mutex.Unlock()
512+
if w.closed {
513+
return false
514+
}
515+
w.group.Add(1)
516+
return true
517+
}
518+
519+
// leave is called by WriteMessages to indicate that the inflight operation has
520+
// completed.
521+
func (w *Writer) leave() { w.group.Done() }
522+
523+
// spawn starts an new asynchronous operation on the writer. This method is used
524+
// instead of starting goroutines inline to help manage the state of the
525+
// writer's wait group. The wait group is used to block Close calls until all
526+
// inflight operations have completed, therefore automatically including those
527+
// started with calls to spawn.
528+
func (w *Writer) spawn(f func()) {
529+
w.group.Add(1)
530+
go func() {
531+
defer w.group.Done()
532+
f()
533+
}()
534+
}
535+
508536
// Close flushes pending writes, and waits for all writes to complete before
509537
// returning. Calling Close also prevents new writes from being submitted to
510538
// the writer, further calls to WriteMessages and the like will fail with
511539
// io.ErrClosedPipe.
512540
func (w *Writer) Close() error {
513-
w.markClosed()
514541
w.mutex.Lock()
542+
// Marking the writer as closed here causes future calls to WriteMessages to
543+
// fail with io.ErrClosedPipe. Mutation of this field is synchronized on the
544+
// writer's mutex to ensure that no more increments of the wait group are
545+
// performed afterwards (which could otherwise race with the Wait below).
546+
w.closed = true
515547

516548
// close all writers to trigger any pending batches
517549
for _, writer := range w.writers {
@@ -561,12 +593,10 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
561593
return errors.New("kafka.(*Writer).WriteMessages: cannot create a kafka writer with a nil address")
562594
}
563595

564-
w.group.Add(1)
565-
defer w.group.Done()
566-
567-
if w.isClosed() {
596+
if !w.enter() {
568597
return io.ErrClosedPipe
569598
}
599+
defer w.leave()
570600

571601
if len(msgs) == 0 {
572602
return nil
@@ -720,14 +750,6 @@ func (w *Writer) partitions(ctx context.Context, topic string) (int, error) {
720750
return 0, UnknownTopicOrPartition
721751
}
722752

723-
func (w *Writer) markClosed() {
724-
atomic.StoreUint32(&w.closed, 1)
725-
}
726-
727-
func (w *Writer) isClosed() bool {
728-
return atomic.LoadUint32(&w.closed) != 0
729-
}
730-
731753
func (w *Writer) client(timeout time.Duration) *Client {
732754
return &Client{
733755
Addr: w.Addr,
@@ -936,8 +958,6 @@ type partitionWriter struct {
936958
mutex sync.Mutex
937959
currBatch *writeBatch
938960

939-
group sync.WaitGroup
940-
941961
// reference to the writer that owns this batch. Used for the produce logic
942962
// as well as stat tracking
943963
w *Writer
@@ -949,12 +969,7 @@ func newPartitionWriter(w *Writer, key topicPartition) *partitionWriter {
949969
queue: newBatchQueue(10),
950970
w: w,
951971
}
952-
go func() {
953-
writer.group.Add(1)
954-
defer writer.group.Done()
955-
writer.writeBatches()
956-
}()
957-
972+
w.spawn(writer.writeBatches)
958973
return writer
959974
}
960975

@@ -970,14 +985,10 @@ func (ptw *partitionWriter) writeBatches() {
970985
}
971986

972987
ptw.writeBatch(batch)
973-
974988
}
975989
}
976990

977991
func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*writeBatch][]int32 {
978-
ptw.group.Add(1)
979-
defer ptw.group.Done()
980-
981992
ptw.mutex.Lock()
982993
defer ptw.mutex.Unlock()
983994

@@ -1019,11 +1030,7 @@ func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*
10191030
// ptw.w can be accessed here because this is called with the lock ptw.mutex already held.
10201031
func (ptw *partitionWriter) newWriteBatch() *writeBatch {
10211032
batch := newWriteBatch(time.Now(), ptw.w.batchTimeout())
1022-
ptw.group.Add(1)
1023-
go func() {
1024-
defer ptw.group.Done()
1025-
ptw.awaitBatch(batch)
1026-
}()
1033+
ptw.w.spawn(func() { ptw.awaitBatch(batch) })
10271034
return batch
10281035
}
10291036

@@ -1150,9 +1157,8 @@ func (ptw *partitionWriter) close() {
11501157
ptw.currBatch = nil
11511158
batch.trigger()
11521159
}
1153-
ptw.queue.Close()
11541160

1155-
ptw.group.Wait()
1161+
ptw.queue.Close()
11561162
}
11571163

11581164
type writeBatch struct {

0 commit comments

Comments
 (0)