Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions grpcbp/client_middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ package grpcbp
import (
"context"
"errors"
"fmt"
"os"
"strings"
"time"

"github.com/opentracing/opentracing-go"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/reddit/baseplate.go/ecinterface"
Expand Down Expand Up @@ -192,3 +196,46 @@ func PrometheusStreamClientInterceptor(serverSlug string) grpc.StreamClientInter
return nil, errors.New("PrometheusStreamClientInterceptor: not implemented")
}
}

// WithDefaultBlock returns a DialOption which makes callers of Dial block until the
// underlying connection is up. Without this, Dial returns immediately and
// connecting the server happens in background.
//
// If the REDDIT_RPC_CONNECTION_MODE=non-blocking env var is set the connection will
// not block (keeping the default behavior). In non-blocking mode additional interceptors are
// added to improve connection error messages and to set [grpc.WaitForReady] on all RPC calls.
func WithDefaultBlock() grpc.DialOption {
if strings.EqualFold(os.Getenv("REDDIT_RPC_CONNECTION_MODE"), "non-blocking") {
return grpc.WithChainUnaryInterceptor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about stream connections? would grpc.WithDefaultCallOptions(grpc.WaitForReady(true)) provide what you're looking for?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how to handle streaming connections. We'll have to revisit streaming connections later if we find some that are relevant.

WithDefaultCallOptions does seem like a better way to implement this!

connectionErrorInterceptor(),
waitForReadyInterceptor(),
)
}
return grpc.WithBlock()
}

// connectionErrorInterceptor adds additional context to connection errors.
func connectionErrorInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
err := invoker(ctx, method, req, reply, cc, opts...)
if err == nil { // if no error
return nil
}

state := cc.GetState()
target := cc.Target()
switch status.Code(err) {
case codes.DeadlineExceeded, codes.Unavailable:
return fmt.Errorf("%w: Dial(%q) connection_state = %v", err, target, state)
}
return err
}
}

// waitForReadyInterceptor sets grpc.WaitForReady(true) on all calls.
func waitForReadyInterceptor() grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
opts = append(opts, grpc.WaitForReady(true))
return invoker(ctx, method, req, reply, cc, opts...)
}
}
33 changes: 33 additions & 0 deletions grpcbp/client_middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -170,3 +171,35 @@ func drainRecorder(t *testing.T, recorder *mqsend.MockMessageQueue) []byte {
}
return msg
}

func TestDial_WithDefaultBlock_NonBlocking(t *testing.T) {
ctx := t.Context()
timeout := 100 * time.Millisecond
ctx, cancel := context.WithTimeout(ctx, timeout)
t.Cleanup(cancel)

t.Setenv("REDDIT_RPC_CONNECTION_MODE", "non-blocking")

start := time.Now()
target := "target-foo:9091"
conn, err := grpc.DialContext(ctx, target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
WithDefaultBlock(),
)
if err != nil {
t.Fatalf("DialReddit: unexpected error %s", err)
}
defer conn.Close()

client := pb.NewTestServiceClient(conn)
_, err = client.Ping(ctx, &pb.PingRequest{Value: "hello"})
if err == nil {
t.Fatalf("Request should have failed")
}
if !strings.Contains(err.Error(), target) {
t.Errorf("Request error: got=%v, wanted error to contain %v", err, target)
}
if elapsed := time.Since(start); elapsed < timeout {
t.Errorf("Request did not wait for connection to be ready, took %v with timeout %v", elapsed, timeout)
}
}