2 changes: 1 addition & 1 deletion pkg/receive/capnp_server.go
@@ -22,7 +22,7 @@ type CapNProtoServer struct {
logger log.Logger
}

func NewCapNProtoServer(listener net.Listener, handler *CapNProtoHandler, logger log.Logger) *CapNProtoServer {
func NewCapNProtoServer(listener net.Listener, handler writecapnp.Writer_Server, logger log.Logger) *CapNProtoServer {
return &CapNProtoServer{
listener: listener,
server: writecapnp.Writer_ServerToClient(handler),
76 changes: 76 additions & 0 deletions pkg/receive/capnproto_server_test.go
@@ -5,11 +5,14 @@ package receive

import (
"context"
"errors"
"os"
"sync"
"testing"

"github.com/go-kit/log"
"github.com/stretchr/testify/require"
"github.com/thanos-io/thanos/pkg/testutil/custom"
"google.golang.org/grpc/test/bufconn"

"github.com/thanos-io/thanos/pkg/receive/writecapnp"
@@ -78,3 +81,76 @@ func TestCapNProtoServer_MultipleConcurrentClients(t *testing.T) {

require.NoError(t, listener.Close())
}

func TestCapNProtoServer_MultipleSerialClientsWithReconnect(t *testing.T) {
custom.TolerantVerifyLeak(t)
var (
logger = log.NewNopLogger()
listener = bufconn.Listen(1024)
handler = newFaultyHandler(2)
)

srv := NewCapNProtoServer(listener, handler, logger)
go func() { _ = srv.ListenAndServe() }()
t.Cleanup(srv.Shutdown)

client := writecapnp.NewRemoteWriteClient(listener, logger)
const numRuns = 100
const numRequests = 10
for range numRuns {
handler.numFailures = 2
var wg sync.WaitGroup
for range numRequests {
wg.Go(func() {
_, err := client.RemoteWrite(context.Background(), &storepb.WriteRequest{
Tenant: "default",
})
require.NoError(t, err)
})
}
wg.Wait()
}
require.NoError(t, client.Close())
require.NoError(t, listener.Close())
}

type faultyHandler struct {
mu sync.Mutex
numFailures int
}

func newFaultyHandler(failEach int) *faultyHandler {
return &faultyHandler{numFailures: failEach}
}

func (f *faultyHandler) Write(ctx context.Context, call writecapnp.Writer_write) error {
call.Go()
if err := f.checkFailures(); err != nil {
return err
}

arg, err := call.Args().Wr()
if err != nil {
return err
}
if _, err := arg.Tenant(); err != nil {
return err
}
results, err := call.AllocResults()
if err != nil {
return err
}
results.SetError(writecapnp.WriteError_none)
return nil
}

func (f *faultyHandler) checkFailures() error {
f.mu.Lock()
defer f.mu.Unlock()

f.numFailures--
if f.numFailures > 0 {
return errors.New("handler failure")
}
return nil
}
183 changes: 123 additions & 60 deletions pkg/receive/writecapnp/client.go
@@ -6,6 +6,7 @@ package writecapnp
import (
"context"
"fmt"
"io"
"net"
"sync"

@@ -14,13 +15,31 @@ import (
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/pkg/errors"
"github.com/thanos-io/thanos/pkg/runutil"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/thanos-io/thanos/pkg/store/storepb"
)

type connState int

const (
connStateDisconnected = iota
connStateError
connStateConnected
)

type conn struct {
mu sync.RWMutex

state connState
generation uint64 // Incremented on each reconnect to track connection instances
closer io.Closer
writer Writer
}

type Dialer interface {
Dial() (net.Conn, error)
}
@@ -46,108 +65,152 @@ func (t TCPDialer) Dial() (net.Conn, error) {
}

type RemoteWriteClient struct {
mu sync.Mutex

dialer Dialer
conn *rpc.Conn
conn *conn

writer Writer
logger log.Logger
}

func NewRemoteWriteClient(dialer Dialer, logger log.Logger) *RemoteWriteClient {
return &RemoteWriteClient{
dialer: dialer,
logger: logger,
conn: &conn{},
}
}

func (r *RemoteWriteClient) RemoteWrite(ctx context.Context, in *storepb.WriteRequest, _ ...grpc.CallOption) (*storepb.WriteResponse, error) {
return r.writeWithReconnect(ctx, 2, in)
}

func (r *RemoteWriteClient) writeWithReconnect(ctx context.Context, numReconnects int, in *storepb.WriteRequest) (*storepb.WriteResponse, error) {
if err := r.connect(ctx); err != nil {
return nil, err
}

result, release := r.writer.Write(ctx, func(params Writer_write_Params) error {
wr, err := params.NewWr()
const numAttempts = 3
var (
resp Writer_write_Results
release func()

err error
generation uint64
)
for range numAttempts {
generation, err = r.conn.connect(ctx, r.logger, r.dialer)
if err != nil {
return err
return nil, err
}
return BuildInto(wr, in.Tenant, in.Timeseries)
})
defer release()

s, err := result.Struct()
if err != nil {
if numReconnects > 0 && capnp.IsDisconnected(err) {
level.Warn(r.logger).Log("msg", "rpc failed, reconnecting")
if err := r.Close(); err != nil {
return nil, err
}
numReconnects--
return r.writeWithReconnect(ctx, numReconnects, in)
if resp, release, err = r.write(ctx, in); err == nil {
break
}
return nil, errors.Wrap(err, "failed writing to peer")
r.conn.setStateErrorIfGeneration(generation)
level.Warn(r.logger).Log("msg", "rpc failed, reconnecting", "err", err.Error())
}
if err != nil {
return nil, err
}
switch s.Error() {
defer release()

switch resp.Error() {
case WriteError_unavailable:
return nil, status.Error(codes.Unavailable, "rpc failed")
case WriteError_alreadyExists:
return nil, status.Error(codes.AlreadyExists, "rpc failed")
case WriteError_invalidArgument:
return nil, status.Error(codes.InvalidArgument, "rpc failed")
case WriteError_internal:
extraContext, err := s.ExtraErrorContext()
if err != nil {
if numReconnects > 0 && capnp.IsDisconnected(err) {
level.Warn(r.logger).Log("msg", "rpc failed, reconnecting")
if err := r.Close(); err != nil {
return nil, err
}
numReconnects--
return r.writeWithReconnect(ctx, numReconnects, in)
}
return nil, errors.Wrap(err, "failed writing to peer")
}

if extraContext == "" {
extraContext, err := resp.ExtraErrorContext()
if err != nil || extraContext == "" {
extraContext = " (no additional context provided)"
} else {
extraContext = ": " + extraContext
}

return nil, status.Error(codes.Internal, fmt.Sprintf("rpc failed%s", extraContext))
default:
return &storepb.WriteResponse{}, nil
}
}

func (r *RemoteWriteClient) connect(ctx context.Context) error {
func (r *RemoteWriteClient) write(ctx context.Context, in *storepb.WriteRequest) (Writer_write_Results, func(), error) {
r.conn.mu.RLock()
defer r.conn.mu.RUnlock()

arena := capnp.SingleSegment(nil)
defer arena.Release()
Comment on lines +131 to +132
Copilot AI (Dec 4, 2025):
The arena is released via defer on line 121, but the function returns resp which was created from result.Struct() on line 142. If resp references memory from the arena's segment, this could lead to use-after-free since the arena is released before the caller can use the returned response.

Cap'n Proto messages typically need their backing segments to remain alive for as long as the structs are in use. Consider whether the arena should be released at all here, or if its lifecycle should be tied to the response's lifecycle (e.g., via the release function).

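A minimal sketch of that suggestion, assuming the response may reference arena memory: release the arena from the returned release closure rather than eagerly via defer. The writeSketch name and the combined closure are illustrative only and rely on the file's existing imports; the body otherwise mirrors the PR's write method.

```go
// Illustrative sketch only: tie the arena's lifetime to the returned release
// function so the caller can read resp before any arena memory is freed.
func (r *RemoteWriteClient) writeSketch(ctx context.Context, in *storepb.WriteRequest) (Writer_write_Results, func(), error) {
	r.conn.mu.RLock()
	defer r.conn.mu.RUnlock()

	arena := capnp.SingleSegment(nil)

	result, release := r.conn.writer.Write(ctx, func(params Writer_write_Params) error {
		_, seg, err := capnp.NewMessage(arena)
		if err != nil {
			return err
		}
		wr, err := NewRootWriteRequest(seg)
		if err != nil {
			return err
		}
		if err := params.SetWr(wr); err != nil {
			return err
		}
		wr, err = params.Wr()
		if err != nil {
			return err
		}
		return BuildInto(wr, in.Tenant, in.Timeseries)
	})

	resp, err := result.Struct()
	// Release the RPC answer and the request arena together, once the caller
	// is done with resp.
	return resp, func() {
		release()
		arena.Release()
	}, err
}
```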

result, release := r.conn.writer.Write(ctx, func(params Writer_write_Params) error {
_, seg, err := capnp.NewMessage(arena)
if err != nil {
return err
}
wr, err := NewRootWriteRequest(seg)
if err != nil {
return err
}
if err := params.SetWr(wr); err != nil {
return err
}
wr, err = params.Wr()
if err != nil {
return err
}
return BuildInto(wr, in.Tenant, in.Timeseries)
})

resp, err := result.Struct()
return resp, release, err
}

func (r *conn) setStateErrorIfGeneration(generation uint64) {
r.mu.Lock()
defer r.mu.Unlock()
if r.conn != nil {
return nil

// Only set error state if we're still on the same connection generation. This
// prevents marking a newly established connection as errored due to a failure
// from a previous connection.
if r.generation == generation {
r.state = connStateError
}
}

conn, err := r.dialer.Dial()
if err != nil {
return errors.Wrap(err, "failed to dial peer")
func (r *conn) connect(ctx context.Context, logger log.Logger, dialer Dialer) (uint64, error) {
r.mu.Lock()
defer r.mu.Unlock()

switch r.state {
case connStateConnected:
return r.generation, nil
case connStateError:
r.close(logger)
fallthrough
case connStateDisconnected:
cc, err := dialer.Dial()
if err != nil {
return 0, errors.Wrap(err, "failed to dial peer")
}
codec := rpc.NewPackedStreamTransport(cc)
r.closer = codec

rpcConn := rpc.NewConn(codec, nil)
r.writer = Writer(rpcConn.Bootstrap(ctx))
Comment on lines +187 to +188
Copilot AI (Dec 4, 2025):
The RPC connection (rpcConn) is created but never explicitly closed or stored for cleanup. When the connection state changes to connStateError or when reconnecting, the previous rpcConn instance is lost without being closed, potentially leading to resource leaks.

Consider storing the rpcConn in the conn struct (similar to how closer is stored) and ensuring it's properly closed in the close() method. Alternatively, ensure that closing the codec also properly closes the RPC connection.
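A rough sketch of that suggestion, assuming rpc.Conn.Close also tears down the underlying transport; the rpcConn field and the adjusted close below are illustrative, not part of this PR.

```go
// Sketch: assumes conn gains an `rpcConn *rpc.Conn` field, set in connect()
// next to `r.closer = codec`, so the RPC connection can be shut down here.
func (r *conn) close(logger log.Logger) {
	if r.rpcConn != nil {
		rc := r.rpcConn
		r.rpcConn = nil
		// Closing the rpc.Conn should also shut down its transport, covering
		// what the separate codec closer handles today.
		go func() {
			runutil.CloseWithLogOnErr(logger, rc, "capnp rpc conn")
		}()
	}
	r.state = connStateDisconnected
}
```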

Member:
Do we need to call Resolve() here like in https://github.com/thanos-io/thanos/pull/8491/files?
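A hedged sketch of what resolving the bootstrap capability could look like, assuming capnp.Client exposes Resolve(ctx); the bootstrapWriter helper is hypothetical and only illustrates the open question above.

```go
// bootstrapWriter is an illustrative helper, not part of the PR: it creates
// the RPC connection and waits for the bootstrap capability to settle before
// returning the Writer client.
func bootstrapWriter(ctx context.Context, codec rpc.Transport) (*rpc.Conn, Writer, error) {
	rpcConn := rpc.NewConn(codec, nil)
	client := rpcConn.Bootstrap(ctx)
	// Whether resolving is needed here mirrors the question above.
	if err := client.Resolve(ctx); err != nil {
		return nil, Writer{}, errors.Wrap(err, "failed to resolve bootstrap capability")
	}
	return rpcConn, Writer(client), nil
}
```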

r.state = connStateConnected
r.generation++
}
r.conn = rpc.NewConn(rpc.NewPackedStreamTransport(conn), nil)
r.writer = Writer(r.conn.Bootstrap(ctx))
return nil
return r.generation, nil
}

func (r *RemoteWriteClient) Close() error {
r.conn.closeLocked(r.logger)
return nil
}

func (r *conn) closeLocked(logger log.Logger) {
r.mu.Lock()
if r.conn != nil {
conn := r.conn
r.conn = nil
go conn.Close()
defer r.mu.Unlock()

r.close(logger)
}

func (r *conn) close(logger log.Logger) {
if r.closer != nil {
codec := r.closer
r.closer = nil
go func() {
runutil.CloseWithLogOnErr(logger, codec, "capnp closer")
}()
}
r.mu.Unlock()
return nil
r.state = connStateDisconnected
Copilot AI (Dec 4, 2025):
The writer field is not reset when closing the connection. After close() is called, r.writer still holds a reference to the old RPC client object, but r.state is set to connStateDisconnected. This could lead to issues if there's any code path that tries to use the writer based on the state check.

Consider setting r.writer to a zero value after closing to ensure no stale references remain:

r.writer = Writer{}
Suggested change:
-	r.state = connStateDisconnected
+	r.state = connStateDisconnected
+	r.writer = Writer{}

}