Skip to content

Commit b42935d

Browse files
authored
chore(spanner): integrate location aware routing with RPCs (googleapis#13877)
- Add `locationAwareSpannerClient` wrapper that intercepts RPCs and routes them to the server endpoint resolved by the existing `channelFinder`, falling back to the default gRPC channel when no endpoint is available - Add `endpointClientCache` that creates and caches per-address gRPC connections - Add transaction affinity tracking so Commit/Rollback route to the same server that handled the transaction, with read-only transactions routed independently per-request based on key ranges - Move request preparation (`prepareReadRequest`, etc.) and response observation (`observePartialResultSet`, etc.) from transaction/batch code into the client wrapper, keeping routing concerns in one place
1 parent 1d71802 commit b42935d

21 files changed

+1855
-129
lines changed

spanner/batch.go

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,6 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
311311
DataBoostEnabled: p.rreq.DataBoostEnabled,
312312
DirectedReadOptions: p.rreq.DirectedReadOptions,
313313
}
314-
if t.locationRouter != nil {
315-
t.locationRouter.prepareReadRequest(req)
316-
}
317314
client, err := client.StreamingRead(ctx, req, opts...)
318315
if err != nil {
319316
return client, err
@@ -344,9 +341,6 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
344341
DataBoostEnabled: p.qreq.DataBoostEnabled,
345342
DirectedReadOptions: p.qreq.DirectedReadOptions,
346343
}
347-
if t.locationRouter != nil {
348-
t.locationRouter.prepareExecuteSQLRequest(req)
349-
}
350344
client, err := client.ExecuteStreamingSql(ctx, req, opts...)
351345
if err != nil {
352346
return client, err
@@ -376,12 +370,7 @@ func (t *BatchReadOnlyTransaction) Execute(ctx context.Context, p *Partition) *R
376370
nil,
377371
t.setTimestamp,
378372
t.release,
379-
client.(*grpcSpannerClient),
380-
func(prs *sppb.PartialResultSet) {
381-
if t.locationRouter != nil {
382-
t.locationRouter.observePartialResultSet(prs)
383-
}
384-
},
373+
asGRPCSpannerClient(client),
385374
)
386375
}
387376

spanner/batch_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,10 @@ func TestBatchExecute_Query_PreparesRoutingHint(t *testing.T) {
383383
t.Fatal(err)
384384
}
385385
defer txn.Cleanup(ctx)
386-
if txn.locationRouter == nil {
386+
if client.locationRouter == nil {
387387
t.Fatal("expected location router to be enabled")
388388
}
389-
txn.locationRouter.observePartialResultSet(&sppb.PartialResultSet{
389+
client.locationRouter.observePartialResultSet(&sppb.PartialResultSet{
390390
CacheUpdate: &sppb.CacheUpdate{DatabaseId: 7},
391391
})
392392

@@ -441,10 +441,10 @@ func TestBatchExecute_Read_PreparesRoutingHint(t *testing.T) {
441441
t.Fatal(err)
442442
}
443443
defer txn.Cleanup(ctx)
444-
if txn.locationRouter == nil {
444+
if client.locationRouter == nil {
445445
t.Fatal("expected location router to be enabled")
446446
}
447-
txn.locationRouter.observePartialResultSet(&sppb.PartialResultSet{
447+
client.locationRouter.observePartialResultSet(&sppb.PartialResultSet{
448448
CacheUpdate: &sppb.CacheUpdate{DatabaseId: 9},
449449
})
450450

spanner/channel_finder.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
package spanner
1818

1919
import (
20+
"context"
2021
"sync"
2122
"sync/atomic"
2223

@@ -63,39 +64,39 @@ func (f *channelFinder) update(update *sppb.CacheUpdate) {
6364
f.rangeCache.addRanges(update)
6465
}
6566

66-
func (f *channelFinder) findServerRead(req *sppb.ReadRequest, preferLeader bool) channelEndpoint {
67+
func (f *channelFinder) findServerRead(ctx context.Context, req *sppb.ReadRequest, preferLeader bool) channelEndpoint {
6768
if req == nil {
6869
return nil
6970
}
7071
f.recipeCache.computeReadKeys(req)
7172
hint := ensureReadRoutingHint(req)
72-
return f.fillRoutingHint(preferLeader, rangeModeCoveringSplit, req.GetDirectedReadOptions(), hint)
73+
return f.fillRoutingHint(ctx, preferLeader, rangeModeCoveringSplit, req.GetDirectedReadOptions(), hint)
7374
}
7475

75-
func (f *channelFinder) findServerReadWithTransaction(req *sppb.ReadRequest) channelEndpoint {
76+
func (f *channelFinder) findServerReadWithTransaction(ctx context.Context, req *sppb.ReadRequest) channelEndpoint {
7677
if req == nil {
7778
return nil
7879
}
79-
return f.findServerRead(req, preferLeaderFromSelector(req.GetTransaction()))
80+
return f.findServerRead(ctx, req, preferLeaderFromSelector(req.GetTransaction()))
8081
}
8182

82-
func (f *channelFinder) findServerExecuteSQL(req *sppb.ExecuteSqlRequest, preferLeader bool) channelEndpoint {
83+
func (f *channelFinder) findServerExecuteSQL(ctx context.Context, req *sppb.ExecuteSqlRequest, preferLeader bool) channelEndpoint {
8384
if req == nil {
8485
return nil
8586
}
8687
f.recipeCache.computeQueryKeys(req)
8788
hint := ensureExecuteSQLRoutingHint(req)
88-
return f.fillRoutingHint(preferLeader, rangeModePickRandom, req.GetDirectedReadOptions(), hint)
89+
return f.fillRoutingHint(ctx, preferLeader, rangeModePickRandom, req.GetDirectedReadOptions(), hint)
8990
}
9091

91-
func (f *channelFinder) findServerExecuteSQLWithTransaction(req *sppb.ExecuteSqlRequest) channelEndpoint {
92+
func (f *channelFinder) findServerExecuteSQLWithTransaction(ctx context.Context, req *sppb.ExecuteSqlRequest) channelEndpoint {
9293
if req == nil {
9394
return nil
9495
}
95-
return f.findServerExecuteSQL(req, preferLeaderFromSelector(req.GetTransaction()))
96+
return f.findServerExecuteSQL(ctx, req, preferLeaderFromSelector(req.GetTransaction()))
9697
}
9798

98-
func (f *channelFinder) findServerBeginTransaction(req *sppb.BeginTransactionRequest) channelEndpoint {
99+
func (f *channelFinder) findServerBeginTransaction(ctx context.Context, req *sppb.BeginTransactionRequest) channelEndpoint {
99100
if req == nil || req.GetMutationKey() == nil {
100101
return nil
101102
}
@@ -107,10 +108,10 @@ func (f *channelFinder) findServerBeginTransaction(req *sppb.BeginTransactionReq
107108
if len(target.limit) > 0 {
108109
hint.LimitKey = append([]byte(nil), target.limit...)
109110
}
110-
return f.fillRoutingHint(preferLeaderFromTransactionOptions(req.GetOptions()), rangeModeCoveringSplit, &sppb.DirectedReadOptions{}, hint)
111+
return f.fillRoutingHint(ctx, preferLeaderFromTransactionOptions(req.GetOptions()), rangeModeCoveringSplit, &sppb.DirectedReadOptions{}, hint)
111112
}
112113

113-
func (f *channelFinder) fillRoutingHint(preferLeader bool, mode rangeMode, directedReadOptions *sppb.DirectedReadOptions, hint *sppb.RoutingHint) channelEndpoint {
114+
func (f *channelFinder) fillRoutingHint(ctx context.Context, preferLeader bool, mode rangeMode, directedReadOptions *sppb.DirectedReadOptions, hint *sppb.RoutingHint) channelEndpoint {
114115
if hint == nil {
115116
return nil
116117
}
@@ -119,7 +120,7 @@ func (f *channelFinder) fillRoutingHint(preferLeader bool, mode rangeMode, direc
119120
return nil
120121
}
121122
hint.DatabaseId = databaseID
122-
return f.rangeCache.fillRoutingHint(preferLeader, mode, directedReadOptions, hint)
123+
return f.rangeCache.fillRoutingHint(ctx, preferLeader, mode, directedReadOptions, hint)
123124
}
124125

125126
func preferLeaderFromSelector(selector *sppb.TransactionSelector) bool {

spanner/client.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,12 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
490490
}
491491

492492
var pool gtransport.ConnPool
493+
var endpointClientOpts []option.ClientOption
493494

494495
if gme != nil {
495496
// Use GCPMultiEndpoint if provided.
496497
pool = &gmeWrapper{gme}
498+
endpointClientOpts = append(endpointClientOpts, opts...)
497499
} else {
498500
// Create gtransport ConnPool as usual if MultiEndpoint is not used.
499501
// gRPC options.
@@ -506,6 +508,7 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
506508
)
507509

508510
allOpts := allClientOpts(config.NumChannels, config.Compression, opts...)
511+
endpointClientOpts = append(endpointClientOpts, allOpts...)
509512
pool, err = gtransport.DialPool(ctx, allOpts...)
510513
if err != nil {
511514
return nil, err
@@ -571,7 +574,9 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
571574

572575
var locationRouter *locationRouter
573576
if isExperimentalLocationAPIEnabled() {
574-
locationRouter = newLocationRouter()
577+
sc.baseClientOpts = endpointClientOpts
578+
epCache := newEndpointClientCache(sc.createEndpointClient)
579+
locationRouter = newLocationRouter(epCache)
575580
}
576581

577582
// Create a session manager.
@@ -581,6 +586,10 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
581586
return nil, err
582587
}
583588

589+
if locationRouter != nil {
590+
sp.locationRouter = locationRouter
591+
}
592+
584593
if enableLogClientOptions() {
585594
projectID, _, _, _ := parseDatabaseName(database)
586595
logf(config.Logger, `
@@ -795,6 +804,9 @@ func (c *Client) Close() {
795804
defer cancel()
796805
c.sm.close(ctx)
797806
}
807+
if c.locationRouter != nil {
808+
c.locationRouter.Close()
809+
}
798810
c.sc.close()
799811
}
800812

@@ -818,7 +830,7 @@ func (c *Client) Single() *ReadOnlyTransaction {
818830
t.txReadOnly.ro.DirectedReadOptions = c.dro
819831
t.txReadOnly.ro.LockHint = sppb.ReadRequest_LOCK_HINT_UNSPECIFIED
820832
t.txReadOnly.clientContext = c.clientContext
821-
t.txReadOnly.locationRouter = c.locationRouter
833+
822834
t.ct = c.ct
823835
t.otConfig = c.otConfig
824836
return t
@@ -847,7 +859,7 @@ func (c *Client) ReadOnlyTransaction() *ReadOnlyTransaction {
847859
t.txReadOnly.ro.DirectedReadOptions = c.dro
848860
t.txReadOnly.ro.LockHint = sppb.ReadRequest_LOCK_HINT_UNSPECIFIED
849861
t.txReadOnly.clientContext = c.clientContext
850-
t.txReadOnly.locationRouter = c.locationRouter
862+
851863
t.ct = c.ct
852864
t.otConfig = c.otConfig
853865
return t
@@ -920,7 +932,7 @@ func (c *Client) BatchReadOnlyTransaction(ctx context.Context, tb TimestampBound
920932
t.txReadOnly.ro.DirectedReadOptions = c.dro
921933
t.txReadOnly.ro.LockHint = sppb.ReadRequest_LOCK_HINT_UNSPECIFIED
922934
t.txReadOnly.clientContext = c.clientContext
923-
t.txReadOnly.locationRouter = c.locationRouter
935+
924936
t.ct = c.ct
925937
t.otConfig = c.otConfig
926938
return t, nil
@@ -958,7 +970,7 @@ func (c *Client) BatchReadOnlyTransactionFromID(tid BatchReadOnlyTransactionID)
958970
t.txReadOnly.ro.DirectedReadOptions = c.dro
959971
t.txReadOnly.ro.LockHint = sppb.ReadRequest_LOCK_HINT_UNSPECIFIED
960972
t.txReadOnly.clientContext = c.clientContext
961-
t.txReadOnly.locationRouter = c.locationRouter
973+
962974
t.ct = c.ct
963975
t.otConfig = c.otConfig
964976
return t
@@ -1046,7 +1058,7 @@ func (c *Client) rwTransaction(ctx context.Context, f func(context.Context, *Rea
10461058
t.wb = []*Mutation{}
10471059
t.txOpts = c.txo.merge(options)
10481060
t.txReadOnly.clientContext = mergeClientContext(c.clientContext, t.txOpts.ClientContext)
1049-
t.txReadOnly.locationRouter = c.locationRouter
1061+
10501062
t.ct = c.ct
10511063
t.otConfig = c.otConfig
10521064
}

spanner/endpoint_client_cache.go

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/*
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package spanner
18+
19+
import (
20+
"context"
21+
"sync"
22+
"sync/atomic"
23+
)
24+
25+
// grpcChannelEndpoint is a channelEndpoint backed by a real gRPC connection.
26+
type grpcChannelEndpoint struct {
27+
address string
28+
client spannerClient
29+
healthy atomic.Bool
30+
}
31+
32+
func (e *grpcChannelEndpoint) Address() string {
33+
return e.address
34+
}
35+
36+
func (e *grpcChannelEndpoint) IsHealthy() bool {
37+
return e.healthy.Load()
38+
}
39+
40+
// endpointClientCache implements channelEndpointCache with actual gRPC
41+
// connections to specific server addresses.
42+
type endpointClientCache struct {
43+
mu sync.RWMutex
44+
endpoints map[string]*grpcChannelEndpoint
45+
inflight map[string]*endpointClientCreation
46+
clientFactory func(ctx context.Context, address string) (spannerClient, error)
47+
closed bool
48+
}
49+
50+
type endpointClientCreation struct {
51+
done chan struct{}
52+
ep channelEndpoint
53+
}
54+
55+
func newEndpointClientCache(clientFactory func(ctx context.Context, address string) (spannerClient, error)) *endpointClientCache {
56+
return &endpointClientCache{
57+
endpoints: make(map[string]*grpcChannelEndpoint),
58+
inflight: make(map[string]*endpointClientCreation),
59+
clientFactory: clientFactory,
60+
}
61+
}
62+
63+
// Get returns a channelEndpoint for the given address, creating a new gRPC
64+
// connection if one does not already exist. Channel creation is coordinated
65+
// per address so slow dials do not block unrelated cache access.
66+
func (c *endpointClientCache) Get(ctx context.Context, address string) channelEndpoint {
67+
if ctx == nil {
68+
ctx = context.Background()
69+
}
70+
// Fast path: read lock.
71+
c.mu.RLock()
72+
if ep, ok := c.endpoints[address]; ok {
73+
c.mu.RUnlock()
74+
return ep
75+
}
76+
c.mu.RUnlock()
77+
78+
c.mu.Lock()
79+
if ep, ok := c.endpoints[address]; ok {
80+
c.mu.Unlock()
81+
return ep
82+
}
83+
if c.closed {
84+
c.mu.Unlock()
85+
return nil
86+
}
87+
if creation, ok := c.inflight[address]; ok {
88+
c.mu.Unlock()
89+
select {
90+
case <-creation.done:
91+
return creation.ep
92+
case <-ctx.Done():
93+
return nil
94+
}
95+
}
96+
creation := &endpointClientCreation{done: make(chan struct{})}
97+
c.inflight[address] = creation
98+
c.mu.Unlock()
99+
100+
client, err := c.clientFactory(ctx, address)
101+
102+
c.mu.Lock()
103+
delete(c.inflight, address)
104+
if err == nil && !c.closed {
105+
ep := &grpcChannelEndpoint{
106+
address: address,
107+
client: client,
108+
}
109+
ep.healthy.Store(true)
110+
c.endpoints[address] = ep
111+
creation.ep = ep
112+
}
113+
shouldCloseClient := c.closed && client != nil
114+
close(creation.done)
115+
c.mu.Unlock()
116+
117+
if shouldCloseClient {
118+
_ = client.Close()
119+
}
120+
return creation.ep
121+
}
122+
123+
// ClientFor resolves a channelEndpoint to the underlying spannerClient.
124+
func (c *endpointClientCache) ClientFor(ep channelEndpoint) spannerClient {
125+
if ep == nil {
126+
return nil
127+
}
128+
gep, ok := ep.(*grpcChannelEndpoint)
129+
if !ok {
130+
return nil
131+
}
132+
return gep.client
133+
}
134+
135+
// Close shuts down all cached gRPC connections.
136+
func (c *endpointClientCache) Close() error {
137+
c.mu.Lock()
138+
c.closed = true
139+
defer c.mu.Unlock()
140+
var firstErr error
141+
for addr, ep := range c.endpoints {
142+
if ep.client != nil {
143+
if err := ep.client.Close(); err != nil && firstErr == nil {
144+
firstErr = err
145+
}
146+
}
147+
delete(c.endpoints, addr)
148+
}
149+
return firstErr
150+
}

0 commit comments

Comments
 (0)