From 1732eb7361b3c80b656458539b0303d4627b9cf8 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 08:15:56 -0400 Subject: [PATCH 01/67] Update `TabletBalancer.Pick` signature to accept options Add `PickOpts` which allow balancers to accept options specific to their implementation. This allows the `Pick` signature not to get overly long as implementations and their options are added. Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 14 +++-- go/vt/vtgate/balancer/balancer_test.go | 9 ++-- go/vt/vtgate/tabletgateway.go | 72 ++++++++++++++------------ 3 files changed, 53 insertions(+), 42 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index bfe85194c05..c55f205115e 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -90,12 +90,19 @@ converge on the desired balanced query load. type TabletBalancer interface { // Pick is the main entry point to the balancer. Returns the best tablet out of the list // for a given query to maintain the desired balanced allocation over multiple executions. - Pick(target *querypb.Target, tablets []*discovery.TabletHealth) *discovery.TabletHealth + Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth // DebugHandler provides a summary of tablet balancer state DebugHandler(w http.ResponseWriter, r *http.Request) } +// PickOpts are balancer options that are passed into Pick. This exists so that as more balancer +// implementations are added the Pick signature does not get overly long. +type PickOpts struct { + // sessionHash is the hash of the current session UUID. + sessionHash uint64 +} + func NewTabletBalancer(localCell string, vtGateCells []string) TabletBalancer { return &tabletBalancer{ localCell: localCell, @@ -167,8 +174,7 @@ func (b *tabletBalancer) DebugHandler(w http.ResponseWriter, _ *http.Request) { // Given the total allocation for the set of tablets, choose the best target // by a weighted random sample so that over time the system will achieve the // desired balanced allocation. -func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth) *discovery.TabletHealth { - +func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ *PickOpts) *discovery.TabletHealth { numTablets := len(tablets) if numTablets == 0 { return nil @@ -306,7 +312,7 @@ func (b *tabletBalancer) allocateFlows(allTablets []*discovery.TabletHealth) *ta // to avoid truncating the integer values. shiftFlow := overAllocatedFlow * currentFlow * underAllocatedFlow / a.Inflows[overAllocatedCell] / unbalancedFlow - //fmt.Printf("shift %d %s %s -> %s (over %d current %d in %d under %d unbalanced %d) \n", shiftFlow, vtgateCell, overAllocatedCell, underAllocatedCell, + // fmt.Printf("shift %d %s %s -> %s (over %d current %d in %d under %d unbalanced %d) \n", shiftFlow, vtgateCell, overAllocatedCell, underAllocatedCell, // overAllocatedFlow, currentFlow, a.Inflows[overAllocatedCell], underAllocatedFlow, unbalancedFlow) a.Outflows[vtgateCell][overAllocatedCell] -= shiftFlow diff --git a/go/vt/vtgate/balancer/balancer_test.go b/go/vt/vtgate/balancer/balancer_test.go index 1c6a72421fc..78b6e708d50 100644 --- a/go/vt/vtgate/balancer/balancer_test.go +++ b/go/vt/vtgate/balancer/balancer_test.go @@ -298,7 +298,7 @@ func TestBalancedPick(t *testing.T) { b := NewTabletBalancer(localCell, vtGateCells).(*tabletBalancer) for i := 0; i < N/len(vtGateCells); i++ { - th := b.Pick(target, tablets) + th := b.Pick(target, tablets, nil) if i == 0 { t.Logf("Target Flows %v, Balancer: %s\n", expectedPerCell, b.print()) } @@ -336,7 +336,7 @@ func TestTopologyChanged(t *testing.T) { tablets = tablets[0:2] for i := 0; i < N; i++ { - th := b.Pick(target, tablets) + th := b.Pick(target, tablets, nil) allocation, totalAllocation := b.getAllocation(target, tablets) assert.Equalf(t, ALLOCATION/2, totalAllocation, "totalAllocation mismatch %s", b.print()) @@ -346,7 +346,7 @@ func TestTopologyChanged(t *testing.T) { // Run again with the full topology. Now traffic should go to cell b for i := 0; i < N; i++ { - th := b.Pick(target, allTablets) + th := b.Pick(target, allTablets, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -359,7 +359,7 @@ func TestTopologyChanged(t *testing.T) { newTablet := createTestTablet("b") allTablets[2] = newTablet for i := 0; i < N; i++ { - th := b.Pick(target, allTablets) + th := b.Pick(target, allTablets, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -367,5 +367,4 @@ func TestTopologyChanged(t *testing.T) { assert.Equalf(t, ALLOCATION/4, allocation[th.Tablet.Alias.Uid], "allocation mismatch %s, cell %s", b.print(), allTablets[0].Tablet.Alias.Cell) assert.Equalf(t, "b", th.Tablet.Alias.Cell, "shuffle promoted wrong tablet from cell %s", allTablets[0].Tablet.Alias.Cell) } - } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index c36c6981fa2..82c2849af9f 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -279,8 +279,8 @@ func (gw *TabletGateway) DebugBalancerHandler(w http.ResponseWriter, r *http.Req // withRetry also adds shard information to errors returned from the inner QueryService, so // withShardError should not be combined with withRetry. func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, _ queryservice.QueryService, - _ string, inTransaction bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error)) error { - + _ string, inTransaction bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), +) error { // for transactions, we connect to a specific tablet instead of letting gateway choose one if inTransaction && target.TabletType != topodatapb.TabletType_PRIMARY { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "tabletGateway's query service can only be used for non-transactional queries on replicas") @@ -359,35 +359,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, break } - var th *discovery.TabletHealth - - useBalancer := balancerEnabled - if balancerEnabled && len(balancerKeyspaces) > 0 { - useBalancer = slices.Contains(balancerKeyspaces, target.Keyspace) - } - if useBalancer { - // filter out the tablets that we've tried before (if any), then pick the best one - if len(invalidTablets) > 0 { - tablets = slices.DeleteFunc(tablets, func(t *discovery.TabletHealth) bool { - _, isInvalid := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)] - return isInvalid - }) - } - - th = gw.balancer.Pick(target, tablets) - - } else { - gw.shuffleTablets(gw.localCell, tablets) - - // skip tablets we tried before - for _, t := range tablets { - if _, ok := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)]; !ok { - th = t - break - } - } - } - + th := gw.getBalancerTablet(target, invalidTablets, tablets) if th == nil { // do not override error from last attempt. if err == nil { @@ -419,9 +391,44 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, return NewShardError(err, target) } +// getBalancerTablet selects a tablet for the given query target, using the configured balancer if enabled. Otherwise, it will +// select a random tablet, with preference to the local cell. +func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablets map[string]bool, tablets []*discovery.TabletHealth) *discovery.TabletHealth { + useBalancer := balancerEnabled + if balancerEnabled && len(balancerKeyspaces) > 0 { + useBalancer = slices.Contains(balancerKeyspaces, target.Keyspace) + } + + // Get the tablet from the balancer if enabled + if useBalancer { + // filter out the tablets that we've tried before (if any), then pick the best one + if len(invalidTablets) > 0 { + tablets = slices.DeleteFunc(tablets, func(t *discovery.TabletHealth) bool { + _, isInvalid := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)] + return isInvalid + }) + } + + return gw.balancer.Pick(target, tablets, nil) + } + + // Otherwise, randomly select a tablet, with preference to the local cell + gw.shuffleTablets(gw.localCell, tablets) + + // skip tablets we tried before + for _, t := range tablets { + if _, ok := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)]; !ok { + return t + } + } + + return nil +} + // withShardError adds shard information to errors returned from the inner QueryService. func (gw *TabletGateway) withShardError(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, - _ string, _ bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error)) error { + _ string, _ bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), +) error { _, err := inner(ctx, target, conn) return NewShardError(err, target) } @@ -449,7 +456,6 @@ func (gw *TabletGateway) getStatsAggregator(target *querypb.Target) *TabletStatu } func (gw *TabletGateway) shuffleTablets(cell string, tablets []*discovery.TabletHealth) { - // Randomly shuffle the list of tablets, putting the same-cell hosts at the front // of the list and the other-cell hosts at the back // From f37cdc88a707bf43d74f352bed76163f81ab90ea Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 11:17:24 -0400 Subject: [PATCH 02/67] Initial implementation of the hash ring Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/hash.go | 43 ++++ go/vt/vtgate/balancer/hashring.go | 182 ++++++++++++++ go/vt/vtgate/balancer/hashring_test.go | 330 +++++++++++++++++++++++++ 3 files changed, 555 insertions(+) create mode 100644 go/vt/vtgate/balancer/hash.go create mode 100644 go/vt/vtgate/balancer/hashring.go create mode 100644 go/vt/vtgate/balancer/hashring_test.go diff --git a/go/vt/vtgate/balancer/hash.go b/go/vt/vtgate/balancer/hash.go new file mode 100644 index 00000000000..ba047488f71 --- /dev/null +++ b/go/vt/vtgate/balancer/hash.go @@ -0,0 +1,43 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "net/http" + + "vitess.io/vitess/go/vt/discovery" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// ConsistentHashBalancer implements the [TabletBalancer] interface. For a given +// session, it will return the same tablet for its duration. The tablet is initially +// selected randomly, with preference to tablets in the local cell. +type ConsistentHashBalancer struct { + // localCell is the cell the gateway is currently running in. + localCell string +} + +// Pick is the main entry point to the balancer. +// +// For a given session, it will return the same tablet for its duration. The tablet is +// initially selected randomly, with preference to tablets in the local cell. +func Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { + return nil +} + +// DebugHandler provides a summary of the consistent hash balancer state. +func DebugHandler(w http.ResponseWriter, r *http.Request) {} diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go new file mode 100644 index 00000000000..bb27d995609 --- /dev/null +++ b/go/vt/vtgate/balancer/hashring.go @@ -0,0 +1,182 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "slices" + "sort" + "strconv" + "sync" + + "github.com/cespare/xxhash/v2" + "vitess.io/vitess/go/vt/discovery" + "vitess.io/vitess/go/vt/topo/topoproto" +) + +// defaultVirtualNodes is the default number of virtual nodes to use in +// the hash ring. +const defaultVirtualNodes = 16 + +// hashRing represents a hash ring of tablets. +type hashRing struct { + mu sync.RWMutex + + // nodes is the sorted list of virtual nodes. + nodes []uint64 + + // numVirtualNodes is the number of virtual nodes each member of + // the hash ring has. + numVirtualNodes int + + // nodeMap is a map from a tablet's hash to the tablet. + nodeMap map[uint64]*discovery.TabletHealth + + // tablets is a "set" of all the tablets currently in the hash ring (by alias). + tablets map[string]struct{} +} + +// newHashRing returns a new hash ring with the default number of virtual nodes. +func newHashRing() *hashRing { + return &hashRing{ + numVirtualNodes: defaultVirtualNodes, + nodeMap: make(map[uint64]*discovery.TabletHealth), + tablets: make(map[string]struct{}), + } +} + +// add adds a tablet to the hash ring. It does not sort the nodes after adding. +func (r *hashRing) add(tablet *discovery.TabletHealth) { + if r.contains(tablet) { + return + } + + // Build the tablet's hashes before locking + hashes := make([]uint64, 0, r.numVirtualNodes) + for i := range r.numVirtualNodes { + hash := buildHash(tablet, i) + hashes = append(hashes, hash) + } + + r.mu.Lock() + defer r.mu.Unlock() + + for _, hash := range hashes { + r.nodes = append(r.nodes, hash) + r.nodeMap[hash] = tablet + } + + r.tablets[tabletAlias(tablet)] = struct{}{} +} + +// remove removes a tablet from the hash ring. +func (r *hashRing) remove(tablet *discovery.TabletHealth) { + if !r.contains(tablet) { + return + } + + // Build the tablet's hashes before locking + hashes := make(map[uint64]struct{}, r.numVirtualNodes) + for i := range r.numVirtualNodes { + hash := buildHash(tablet, i) + hashes[hash] = struct{}{} + } + + r.mu.Lock() + defer r.mu.Unlock() + + for hash := range hashes { + delete(r.nodeMap, hash) + } + + r.nodes = removeNodes(r.nodes, hashes) + delete(r.tablets, tabletAlias(tablet)) +} + +// get returns the tablet for the given key. +func (r *hashRing) get(key string) *discovery.TabletHealth { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.nodes) == 0 { + return nil + } + + // Find the first node greater than or equal to this hash + hash := xxhash.Sum64String(key) + i := sort.Search(len(r.nodes), func(i int) bool { + return r.nodes[i] >= hash + }) + + // Wrap around if needed + if i == len(r.nodes) { + i = 0 + } + + // Return the associated tablet + node := r.nodes[i] + tablet := r.nodeMap[node] + return tablet +} + +// contains checks if a tablet exists in the hash ring. +func (r *hashRing) contains(tablet *discovery.TabletHealth) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, exists := r.tablets[tabletAlias(tablet)] + return exists +} + +// sort sorts the list of nodes. +func (r *hashRing) sort() { + r.mu.Lock() + defer r.mu.Unlock() + + slices.Sort(r.nodes) +} + +// buildHash builds a virtual node hash. +func buildHash(tablet *discovery.TabletHealth, node int) uint64 { + key := tabletAlias(tablet) + "#" + strconv.Itoa(node) + hash := xxhash.Sum64String(key) + + return hash +} + +// tabletAlias returns the tablet's alias as a string. +func tabletAlias(tablet *discovery.TabletHealth) string { + return topoproto.TabletAliasString(tablet.Tablet.Alias) +} + +// removeNodes removes the nodes in the set of hashes from the given list of nodes. +func removeNodes(nodes []uint64, hashes map[uint64]struct{}) []uint64 { + // Update the node list in-place + + writeIdx := 0 + for _, node := range nodes { + // Check if this node belongs to the tablet being removed + _, isTabletNode := hashes[node] + if isTabletNode { + continue + } + + nodes[writeIdx] = node + writeIdx++ + } + + return nodes[:writeIdx] +} diff --git a/go/vt/vtgate/balancer/hashring_test.go b/go/vt/vtgate/balancer/hashring_test.go new file mode 100644 index 00000000000..06ab740dd12 --- /dev/null +++ b/go/vt/vtgate/balancer/hashring_test.go @@ -0,0 +1,330 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "fmt" + "slices" + "strconv" + "sync" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/discovery" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/topo" +) + +func createTestTabletForHashRing(cell string, uid uint32) *discovery.TabletHealth { + tablet := topo.NewTablet(uid, cell, strconv.FormatUint(uint64(uid), 10)) + tablet.PortMap["vt"] = 1 + tablet.PortMap["grpc"] = 2 + tablet.Keyspace = "test_keyspace" + tablet.Shard = "0" + + return &discovery.TabletHealth{ + Tablet: tablet, + Target: &querypb.Target{Keyspace: "test_keyspace", Shard: "0", TabletType: topodatapb.TabletType_REPLICA}, + Serving: true, + Stats: nil, + PrimaryTermStartTime: 0, + } +} + +func TestNewHashRing(t *testing.T) { + ring := newHashRing() + + require.NotNil(t, ring) + require.Equal(t, defaultVirtualNodes, ring.numVirtualNodes) + require.NotNil(t, ring.nodeMap) + require.Empty(t, ring.nodes) + require.Empty(t, ring.nodeMap) +} + +func TestHashRingAdd(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + ring.add(tablet1) + + require.Len(t, ring.nodes, defaultVirtualNodes) + require.Len(t, ring.nodeMap, defaultVirtualNodes) + + for _, hash := range ring.nodes { + require.Equal(t, tablet1, ring.nodeMap[hash]) + } + + ring.add(tablet2) + + require.Len(t, ring.nodes, 2*defaultVirtualNodes) + require.Len(t, ring.nodeMap, 2*defaultVirtualNodes) + + tablet1Count := 0 + tablet2Count := 0 + for _, tablet := range ring.nodeMap { + switch tablet { + case tablet1: + tablet1Count++ + case tablet2: + tablet2Count++ + } + } + require.Equal(t, defaultVirtualNodes, tablet1Count) + require.Equal(t, defaultVirtualNodes, tablet2Count) +} + +func TestHashRingAddDuplicate(t *testing.T) { + ring := newHashRing() + tablet := createTestTabletForHashRing("cell1", 100) + + ring.add(tablet) + originalLen := len(ring.nodes) + + ring.add(tablet) + + require.Len(t, ring.nodes, originalLen) + require.Len(t, ring.nodeMap, originalLen) +} + +func TestHashRingRemove(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + ring.add(tablet1) + ring.add(tablet2) + + require.Len(t, ring.nodes, 2*defaultVirtualNodes) + + ring.remove(tablet1) + + require.Len(t, ring.nodes, defaultVirtualNodes) + require.Len(t, ring.nodeMap, defaultVirtualNodes) + + for _, tablet := range ring.nodeMap { + require.Equal(t, tablet2, tablet) + } +} + +func TestHashRingRemoveNonExistent(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + ring.add(tablet1) + originalLen := len(ring.nodes) + + ring.remove(tablet2) + + require.Len(t, ring.nodes, originalLen) + require.Len(t, ring.nodeMap, originalLen) +} + +func TestHashRingGet(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + result := ring.get("test_key") + require.Nil(t, result) + + ring.add(tablet1) + ring.sort() + + result = ring.get("test_key") + require.NotNil(t, result) + require.Equal(t, tablet1, result) + + ring.add(tablet2) + ring.sort() + + result = ring.get("test_key") + require.NotNil(t, result) + + // Empirically know that "test_key" hashes closest to tablet2 + require.Equal(t, tablet2, result) +} + +func TestHashRingSort(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + ring.add(tablet1) + ring.add(tablet2) + + originalNodes := make([]uint64, len(ring.nodes)) + copy(originalNodes, ring.nodes) + + ring.sort() + + require.True(t, slices.IsSorted(ring.nodes)) + require.Equal(t, len(originalNodes), len(ring.nodes)) + + for _, node := range originalNodes { + require.Contains(t, ring.nodes, node) + } +} + +func TestBuildHash(t *testing.T) { + tablet := createTestTabletForHashRing("cell1", 100) + + hash1 := buildHash(tablet, 0) + hash2 := buildHash(tablet, 1) + hash3 := buildHash(tablet, 0) + + require.NotEqual(t, hash1, hash2) + require.Equal(t, hash1, hash3) + + tablet2 := createTestTabletForHashRing("cell2", 200) + hash4 := buildHash(tablet2, 0) + + require.NotEqual(t, hash1, hash4) +} + +func TestHashRingAddRemoveSequence(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + tablet3 := createTestTabletForHashRing("cell3", 300) + + ring.add(tablet1) + ring.add(tablet2) + ring.add(tablet3) + ring.sort() + + key := "test_sequence" + + initialTablet := ring.get(key) + require.NotNil(t, initialTablet) + + ring.remove(tablet2) + ring.sort() + + afterRemovalTablet := ring.get(key) + require.NotNil(t, afterRemovalTablet) + + ring.add(tablet2) + ring.sort() + + afterReaddTablet := ring.get(key) + require.NotNil(t, afterReaddTablet) + + require.Equal(t, initialTablet, afterReaddTablet) +} + +func TestHashRingWrapAround(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + + // Create a synthetic scenario where we know the hash will be larger + ring.nodes = []uint64{1000, 2000, 3000} // Small values + ring.nodeMap = make(map[uint64]*discovery.TabletHealth) + ring.nodeMap[1000] = tablet1 + ring.nodeMap[2000] = tablet2 + ring.nodeMap[3000] = tablet1 + + // Any large hash should wrap around to the first node + result := ring.get("this_should_wrap_around_with_large_hash") + require.NotNil(t, result) + require.Contains(t, []*discovery.TabletHealth{tablet1, tablet2}, result) +} + +func TestHashRingRemoveAllTablets(t *testing.T) { + ring := newHashRing() + tablets := make([]*discovery.TabletHealth, 3) + + for i := range 3 { + tablets[i] = createTestTabletForHashRing(fmt.Sprintf("cell%d", i), uint32(100+i)) + ring.add(tablets[i]) + } + + ring.sort() + + for _, tablet := range tablets { + ring.remove(tablet) + } + + require.Empty(t, ring.nodes) + require.Empty(t, ring.nodeMap) + require.Nil(t, ring.get("any_key")) +} + +func TestHashRingMultipleAddSameTablet(t *testing.T) { + ring := newHashRing() + tablet := createTestTabletForHashRing("cell1", 100) + + // Add the same tablet multiple times + for range 5 { + ring.add(tablet) + } + + // Should still only have defaultVirtualNodes entries + require.Len(t, ring.nodes, defaultVirtualNodes) + require.Len(t, ring.nodeMap, defaultVirtualNodes) +} + +func TestHashRingGetAfterRemove(t *testing.T) { + ring := newHashRing() + tablet1 := createTestTabletForHashRing("cell1", 100) + tablet2 := createTestTabletForHashRing("cell2", 200) + tablet3 := createTestTabletForHashRing("cell3", 300) + + ring.add(tablet1) + ring.add(tablet2) + ring.add(tablet3) + ring.sort() + + // Empirically know that this hashes closest to tablet3 + got := ring.get("key") + require.Equal(t, tablet3, got) + + // Remove tablet3 + ring.remove(tablet3) + + got = ring.get("key") + require.NotEqual(t, tablet3, got) +} + +func TestHashRingConcurrentGetOperations(t *testing.T) { + ring := newHashRing() + tablets := make([]*discovery.TabletHealth, 5) + + for i := range 5 { + tablets[i] = createTestTabletForHashRing(fmt.Sprintf("cell%d", i), uint32(100+i)) + ring.add(tablets[i]) + } + ring.sort() + + var wg sync.WaitGroup + numGoroutines := 1000 + wg.Add(numGoroutines) + for i := range numGoroutines { + go func(i int) { + defer wg.Done() + key := fmt.Sprintf("concurrent_key_%d", i) + tablet := ring.get(key) + require.NotNil(t, tablet) + }(i) + } + + wg.Wait() +} From 114feceeebdd55d2ee251d98455fbfc1bcbaea47 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 12:07:43 -0400 Subject: [PATCH 03/67] Set up initial health check subscription Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/hash.go | 41 +++++++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/balancer/hash.go b/go/vt/vtgate/balancer/hash.go index ba047488f71..9dbd352a74d 100644 --- a/go/vt/vtgate/balancer/hash.go +++ b/go/vt/vtgate/balancer/hash.go @@ -17,6 +17,7 @@ limitations under the License. package balancer import ( + "context" "net/http" "vitess.io/vitess/go/vt/discovery" @@ -29,15 +30,51 @@ import ( type ConsistentHashBalancer struct { // localCell is the cell the gateway is currently running in. localCell string + + // hc is the tablet health check. + hc discovery.HealthCheck + + // // hcChan is the channel to receive tablet health events on. + // hcChan chan *discovery.TabletHealth +} + +// NewConsistentHashBalancer creates a new consistent hash balancer. +func NewConsistentHashBalancer(ctx context.Context, localCell string, hc discovery.HealthCheck) TabletBalancer { + b := &ConsistentHashBalancer{ + localCell: localCell, + hc: hc, + } + + // Set up health check subscription + hcChan := b.hc.Subscribe("ConsistentHashBalancer") + go b.watchHealthCheck(ctx, hcChan) + + return b } // Pick is the main entry point to the balancer. // // For a given session, it will return the same tablet for its duration. The tablet is // initially selected randomly, with preference to tablets in the local cell. -func Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { +func (b *ConsistentHashBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { return nil } // DebugHandler provides a summary of the consistent hash balancer state. -func DebugHandler(w http.ResponseWriter, r *http.Request) {} +func (b *ConsistentHashBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) {} + +// watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. +func (b *ConsistentHashBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { + for { + select { + case <-ctx.Done(): + b.hc.Unsubscribe(hcChan) + close(hcChan) + return + case th := <-hcChan: + if th == nil { + return + } + } + } +} From 4742e8d891e97239fc5893220513b3d45b92d2ff Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 19:57:38 -0400 Subject: [PATCH 04/67] Rename to `SessionBalancer` and create initial tests Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/hash.go | 80 -------- go/vt/vtgate/balancer/hashring.go | 3 +- go/vt/vtgate/balancer/session.go | 143 +++++++++++++ go/vt/vtgate/balancer/session_test.go | 276 ++++++++++++++++++++++++++ 4 files changed, 421 insertions(+), 81 deletions(-) delete mode 100644 go/vt/vtgate/balancer/hash.go create mode 100644 go/vt/vtgate/balancer/session.go create mode 100644 go/vt/vtgate/balancer/session_test.go diff --git a/go/vt/vtgate/balancer/hash.go b/go/vt/vtgate/balancer/hash.go deleted file mode 100644 index 9dbd352a74d..00000000000 --- a/go/vt/vtgate/balancer/hash.go +++ /dev/null @@ -1,80 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package balancer - -import ( - "context" - "net/http" - - "vitess.io/vitess/go/vt/discovery" - querypb "vitess.io/vitess/go/vt/proto/query" -) - -// ConsistentHashBalancer implements the [TabletBalancer] interface. For a given -// session, it will return the same tablet for its duration. The tablet is initially -// selected randomly, with preference to tablets in the local cell. -type ConsistentHashBalancer struct { - // localCell is the cell the gateway is currently running in. - localCell string - - // hc is the tablet health check. - hc discovery.HealthCheck - - // // hcChan is the channel to receive tablet health events on. - // hcChan chan *discovery.TabletHealth -} - -// NewConsistentHashBalancer creates a new consistent hash balancer. -func NewConsistentHashBalancer(ctx context.Context, localCell string, hc discovery.HealthCheck) TabletBalancer { - b := &ConsistentHashBalancer{ - localCell: localCell, - hc: hc, - } - - // Set up health check subscription - hcChan := b.hc.Subscribe("ConsistentHashBalancer") - go b.watchHealthCheck(ctx, hcChan) - - return b -} - -// Pick is the main entry point to the balancer. -// -// For a given session, it will return the same tablet for its duration. The tablet is -// initially selected randomly, with preference to tablets in the local cell. -func (b *ConsistentHashBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { - return nil -} - -// DebugHandler provides a summary of the consistent hash balancer state. -func (b *ConsistentHashBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) {} - -// watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. -func (b *ConsistentHashBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { - for { - select { - case <-ctx.Done(): - b.hc.Unsubscribe(hcChan) - close(hcChan) - return - case th := <-hcChan: - if th == nil { - return - } - } - } -} diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go index bb27d995609..848b1139b55 100644 --- a/go/vt/vtgate/balancer/hashring.go +++ b/go/vt/vtgate/balancer/hashring.go @@ -58,7 +58,7 @@ func newHashRing() *hashRing { } } -// add adds a tablet to the hash ring. It does not sort the nodes after adding. +// add adds a tablet to the hash ring. func (r *hashRing) add(tablet *discovery.TabletHealth) { if r.contains(tablet) { return @@ -79,6 +79,7 @@ func (r *hashRing) add(tablet *discovery.TabletHealth) { r.nodeMap[hash] = tablet } + slices.Sort(r.nodes) r.tablets[tabletAlias(tablet)] = struct{}{} } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go new file mode 100644 index 00000000000..c4224a64917 --- /dev/null +++ b/go/vt/vtgate/balancer/session.go @@ -0,0 +1,143 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "context" + "fmt" + "net/http" + "sync" + + "vitess.io/vitess/go/vt/discovery" + querypb "vitess.io/vitess/go/vt/proto/query" +) + +// SessionBalancer implements the [TabletBalancer] interface. For a given +// session, it will return the same tablet for its duration. The tablet is initially +// selected randomly, with preference to tablets in the local cell. +type SessionBalancer struct { + // localCell is the cell the gateway is currently running in. + localCell string + + // hc is the tablet health check. + hc discovery.HealthCheck + + mu sync.RWMutex + + // localRings are the hash rings created for each target. It contains only tablets + // local to [localCell]. + localRings map[discovery.KeyspaceShardTabletType]*hashRing + + // externalRings are the hash rings created for each target. It contains only tablets + // external to [localCell]. + externalRings map[discovery.KeyspaceShardTabletType]*hashRing +} + +// NewSessionBalancer creates a new session balancer. +func NewSessionBalancer(ctx context.Context, localCell string, hc discovery.HealthCheck) TabletBalancer { + b := &SessionBalancer{ + localCell: localCell, + hc: hc, + localRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), + externalRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), + } + + // Set up health check subscription + hcChan := b.hc.Subscribe("SessionBalancer") + go b.watchHealthCheck(ctx, hcChan) + + return b +} + +// Pick is the main entry point to the balancer. +// +// For a given session, it will return the same tablet for its duration. The tablet is +// initially selected randomly, with preference to tablets in the local cell. +func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { + return nil +} + +// DebugHandler provides a summary of the session balancer state. +func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(w, "Session Balancer\n") + fmt.Fprintf(w, "================\n") + fmt.Fprintf(w, "Local Cell: %s\n\n", b.localCell) + + b.mu.RLock() + defer b.mu.RUnlock() + + fmt.Fprintf(w, "Local Rings (%d):\n", len(b.localRings)) + for key := range b.localRings { + fmt.Fprintf(w, " - %s\n", key) + } + + fmt.Fprintf(w, "\nExternal Rings (%d):\n", len(b.externalRings)) + for key := range b.externalRings { + fmt.Fprintf(w, " - %s\n", key) + } +} + +// watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. +func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { + for { + select { + case <-ctx.Done(): + b.hc.Unsubscribe(hcChan) + return + case tablet := <-hcChan: + if tablet == nil { + return + } + + b.onTabletHealthChange(tablet) + } + } +} + +// onTabletHealthChange is the handler for tablet health events. +func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { + b.mu.Lock() + defer b.mu.Unlock() + + var ring *hashRing + if tablet.Target.Cell == b.localCell { + ring = getRing(b.localRings, tablet) + } else { + ring = getRing(b.externalRings, tablet) + } + + if tablet.Serving { + ring.add(tablet) + ring.sort() + } else { + ring.remove(tablet) + } +} + +// getRing gets or creates a new ring for the given tablet. +func getRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tablet *discovery.TabletHealth) *hashRing { + key := discovery.KeyFromTarget(tablet.Target) + + ring, exists := rings[key] + if !exists { + ring = newHashRing() + rings[key] = ring + } + + return ring +} diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go new file mode 100644 index 00000000000..7858613db8c --- /dev/null +++ b/go/vt/vtgate/balancer/session_test.go @@ -0,0 +1,276 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" + + "vitess.io/vitess/go/vt/discovery" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" +) + +func newSessionBalancer(t *testing.T) (*SessionBalancer, chan *discovery.TabletHealth) { + ctx := t.Context() + + ch := make(chan *discovery.TabletHealth, 10) + hc := discovery.NewFakeHealthCheck(ch) + b := NewSessionBalancer(ctx, "local", hc) + sb := b.(*SessionBalancer) + + return sb, ch +} + +func TestNewSessionBalancer(t *testing.T) { + b, _ := newSessionBalancer(t) + + require.Equal(t, "local", b.localCell) + require.NotNil(t, b.hc) + require.NotNil(t, b.localRings) + require.Len(t, b.localRings, 0) + require.NotNil(t, b.externalRings) + require.Len(t, b.externalRings, 0) +} + +func TestPickNoTablets(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + result := b.Pick(target, nil, nil) + require.Nil(t, result) +} + +func TestPickLocalOnly(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + localTablet1 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + localTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet1 + hcChan <- localTablet2 + + // Pick for a specific session hash + opts := &PickOpts{sessionHash: 12345} + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + + // Pick again with same session hash, should return same tablet + picked2 := b.Pick(target, nil, opts) + require.Equal(t, picked1, picked2) + + // Pick with different session hash, empirically know that it should return different tablet + opts = &PickOpts{sessionHash: 67890} + picked3 := b.Pick(target, nil, opts) + require.NotNil(t, picked3) + require.NotEqual(t, picked2, picked3) +} + +func TestPickPreferLocal(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + localTablet1 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + localTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + externalTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + hcChan <- localTablet1 + hcChan <- localTablet2 + hcChan <- externalTablet + + // Pick should prefer local cell + opts := &PickOpts{sessionHash: 12345} + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + require.Equal(t, "local", picked1.Target.Cell) +} + +func TestPickNoLocal(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + externalTablet1 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + externalTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 201, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + hcChan <- externalTablet1 + hcChan <- externalTablet2 + + // Pick should return external cell since there are no local cells + opts := &PickOpts{sessionHash: 12345} + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + require.Equal(t, "external", picked1.Target.Cell) +} + +func TestDebugHandler(t *testing.T) { + ctx := t.Context() + + ch := make(chan *discovery.TabletHealth, 10) + hc := discovery.NewFakeHealthCheck(ch) + b := NewSessionBalancer(ctx, "local", hc) + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/debug", nil) + + b.DebugHandler(w, r) + require.Equal(t, http.StatusOK, w.Code) +} From 5be7ac790128e15905418ffb4adabc0eac77f2a8 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 21:39:55 -0400 Subject: [PATCH 05/67] Improvements and add more tests Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 2 +- go/vt/vtgate/balancer/hashring.go | 17 +- go/vt/vtgate/balancer/session.go | 99 ++++++++--- go/vt/vtgate/balancer/session_test.go | 226 ++++++++++++++++++++++++-- go/vt/vtgate/tabletgateway.go | 8 +- 5 files changed, 320 insertions(+), 32 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index c55f205115e..764f5b02ac1 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -100,7 +100,7 @@ type TabletBalancer interface { // implementations are added the Pick signature does not get overly long. type PickOpts struct { // sessionHash is the hash of the current session UUID. - sessionHash uint64 + sessionHash *uint64 } func NewTabletBalancer(localCell string, vtGateCells []string) TabletBalancer { diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go index 848b1139b55..5267482b6f5 100644 --- a/go/vt/vtgate/balancer/hashring.go +++ b/go/vt/vtgate/balancer/hashring.go @@ -23,6 +23,7 @@ import ( "sync" "github.com/cespare/xxhash/v2" + "vitess.io/vitess/go/vt/discovery" "vitess.io/vitess/go/vt/topo/topoproto" ) @@ -116,8 +117,21 @@ func (r *hashRing) get(key string) *discovery.TabletHealth { return nil } - // Find the first node greater than or equal to this hash hash := xxhash.Sum64String(key) + tablet := r.getHashed(hash) + return tablet +} + +// getHashed returns the tablet for the given hash. +func (r *hashRing) getHashed(hash uint64) *discovery.TabletHealth { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.nodes) == 0 { + return nil + } + + // Find the first node greater than or equal to this hash i := sort.Search(len(r.nodes), func(i int) bool { return r.nodes[i] >= hash }) @@ -130,6 +144,7 @@ func (r *hashRing) get(key string) *discovery.TabletHealth { // Return the associated tablet node := r.nodes[i] tablet := r.nodeMap[node] + return tablet } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index c4224a64917..e5923af7070 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -19,7 +19,10 @@ package balancer import ( "context" "fmt" + "maps" "net/http" + "slices" + "strings" "sync" "vitess.io/vitess/go/vt/discovery" @@ -65,31 +68,46 @@ func NewSessionBalancer(ctx context.Context, localCell string, hc discovery.Heal // Pick is the main entry point to the balancer. // -// For a given session, it will return the same tablet for its duration. The tablet is -// initially selected randomly, with preference to tablets in the local cell. +// For a given session, it will return the same tablet for its duration, with preference to tablets +// in the local cell. +// +// NOTE: this currently won't consider any invalid tablets. This means we'll keep returning the same +// invalid tablet on subsequent tries. We can improve this by maybe returning a random tablet (local +// cell preferred) when the session hash falls on an invalid tablet. func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { - return nil + if opts == nil || opts.sessionHash == nil { + // No session hash. Returning nil here will allow the gateway to select a random + // tablet instead. + return nil + } + + sessionHash := *opts.sessionHash + + b.mu.RLock() + defer b.mu.RUnlock() + + // Try to find a tablet in the local cell first + tablet := getFromRing(b.localRings, target, sessionHash) + if tablet != nil { + return tablet + } + + // If we didn't find a tablet in the local cell, try external cells + tablet = getFromRing(b.externalRings, target, sessionHash) + return tablet } // DebugHandler provides a summary of the session balancer state. func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(w, "Session Balancer\n") + fmt.Fprintf(w, "Session balancer\n") fmt.Fprintf(w, "================\n") - fmt.Fprintf(w, "Local Cell: %s\n\n", b.localCell) + fmt.Fprintf(w, "Local cell: %s\n\n", b.localCell) b.mu.RLock() defer b.mu.RUnlock() - fmt.Fprintf(w, "Local Rings (%d):\n", len(b.localRings)) - for key := range b.localRings { - fmt.Fprintf(w, " - %s\n", key) - } - - fmt.Fprintf(w, "\nExternal Rings (%d):\n", len(b.externalRings)) - for key := range b.externalRings { - fmt.Fprintf(w, " - %s\n", key) - } + fmt.Fprint(w, b.print()) } // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. @@ -116,21 +134,20 @@ func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { var ring *hashRing if tablet.Target.Cell == b.localCell { - ring = getRing(b.localRings, tablet) + ring = getOrCreateRing(b.localRings, tablet) } else { - ring = getRing(b.externalRings, tablet) + ring = getOrCreateRing(b.externalRings, tablet) } if tablet.Serving { ring.add(tablet) - ring.sort() } else { ring.remove(tablet) } } -// getRing gets or creates a new ring for the given tablet. -func getRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tablet *discovery.TabletHealth) *hashRing { +// getOrCreateRing gets or creates a new ring for the given tablet. +func getOrCreateRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tablet *discovery.TabletHealth) *hashRing { key := discovery.KeyFromTarget(tablet.Target) ring, exists := rings[key] @@ -141,3 +158,47 @@ func getRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tablet *disc return ring } + +// print returns a string representation of the session balancer state for debugging. +func (b *SessionBalancer) print() string { + b.mu.RLock() + defer b.mu.RUnlock() + + sb := strings.Builder{} + + sb.WriteString("Local rings:\n") + if len(b.localRings) == 0 { + sb.WriteString("\tNo local rings\n") + } + + for target, ring := range b.localRings { + sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) + sb.WriteString(fmt.Sprintf("\t\tNode count: %d\n", len(ring.nodes))) + sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(ring.tablets)))) + } + + sb.WriteString("External rings:\n") + if len(b.externalRings) == 0 { + sb.WriteString("\tNo external rings\n") + } + + for target, ring := range b.externalRings { + sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) + sb.WriteString(fmt.Sprintf("\t\tNode count: %d\n", len(ring.nodes))) + sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(ring.tablets)))) + } + + return sb.String() +} + +// getFromRing gets a tablet from the respective ring for the given target and session hash. +func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, sessionHash uint64) *discovery.TabletHealth { + key := discovery.KeyFromTarget(target) + + ring, exists := rings[key] + if !exists { + return nil + } + + return ring.getHashed(sessionHash) +} diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 7858613db8c..9eced92b24a 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -17,12 +17,13 @@ limitations under the License. package balancer import ( + "fmt" "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" - "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/ext" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" @@ -52,7 +53,7 @@ func TestNewSessionBalancer(t *testing.T) { } func TestPickNoTablets(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b, _ := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -61,7 +62,8 @@ func TestPickNoTablets(t *testing.T) { Cell: "local", } - result := b.Pick(target, nil, nil) + opts := sessionHash(12345) + result := b.Pick(target, nil, opts) require.Nil(t, result) } @@ -114,20 +116,23 @@ func TestPickLocalOnly(t *testing.T) { hcChan <- localTablet1 hcChan <- localTablet2 + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + // Pick for a specific session hash - opts := &PickOpts{sessionHash: 12345} + opts := sessionHash(12345) picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) // Pick again with same session hash, should return same tablet picked2 := b.Pick(target, nil, opts) - require.Equal(t, picked1, picked2) + require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) - // Pick with different session hash, empirically know that it should return different tablet - opts = &PickOpts{sessionHash: 67890} + // Pick with different session hash, empirically know that it should return tablet2 + opts = sessionHash(5018141287610575993) picked3 := b.Pick(target, nil, opts) require.NotNil(t, picked3) - require.NotEqual(t, picked2, picked3) + require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked3))) } func TestPickPreferLocal(t *testing.T) { @@ -198,8 +203,11 @@ func TestPickPreferLocal(t *testing.T) { hcChan <- localTablet2 hcChan <- externalTablet + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + // Pick should prefer local cell - opts := &PickOpts{sessionHash: 12345} + opts := sessionHash(5018141287610575993) picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) @@ -254,13 +262,207 @@ func TestPickNoLocal(t *testing.T) { hcChan <- externalTablet1 hcChan <- externalTablet2 + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + // Pick should return external cell since there are no local cells - opts := &PickOpts{sessionHash: 12345} + opts := sessionHash(12345) picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) } +func TestTabletNotServing(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + } + + localTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + externalTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + hcChan <- localTablet + hcChan <- externalTablet + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + opts := sessionHash(5018141287610575993) + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + + // Local tablet goes out of serving + localTablet.Serving = false + hcChan <- localTablet + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + // Should not pick the local tablet again + picked2 := b.Pick(target, nil, opts) + require.NotEqual(t, picked1, picked2) +} + +func TestNewLocalTablet(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + } + + localTablet1 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet1 + + time.Sleep(10 * time.Millisecond) + + opts := sessionHash(5018141287610575993) + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + + localTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet2 + + time.Sleep(10 * time.Millisecond) + + picked2 := b.Pick(target, nil, opts) + require.NotNil(t, picked2) + require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) +} + +func TestNewExternalTablet(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + } + + externalTablet1 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- externalTablet1 + + time.Sleep(10 * time.Millisecond) + + opts := sessionHash(5018141287610575993) + picked1 := b.Pick(target, nil, opts) + require.NotNil(t, picked1) + + externalTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- externalTablet2 + + time.Sleep(10 * time.Millisecond) + + picked2 := b.Pick(target, nil, opts) + require.NotNil(t, picked2) + require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) +} + func TestDebugHandler(t *testing.T) { ctx := t.Context() @@ -274,3 +476,7 @@ func TestDebugHandler(t *testing.T) { b.DebugHandler(w, r) require.Equal(t, http.StatusOK, w.Code) } + +func sessionHash(i uint64) *PickOpts { + return &PickOpts{sessionHash: &i} +} diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 82c2849af9f..3e19d6dae9c 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -394,6 +394,8 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, // getBalancerTablet selects a tablet for the given query target, using the configured balancer if enabled. Otherwise, it will // select a random tablet, with preference to the local cell. func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablets map[string]bool, tablets []*discovery.TabletHealth) *discovery.TabletHealth { + var tablet *discovery.TabletHealth + useBalancer := balancerEnabled if balancerEnabled && len(balancerKeyspaces) > 0 { useBalancer = slices.Contains(balancerKeyspaces, target.Keyspace) @@ -409,7 +411,11 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablet }) } - return gw.balancer.Pick(target, tablets, nil) + tablet = gw.balancer.Pick(target, tablets, nil) + } + + if tablet != nil { + return tablet } // Otherwise, randomly select a tablet, with preference to the local cell From 80eb92ff67a59ddfaccf0e8b41073a3ebc37f181 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 21:48:31 -0400 Subject: [PATCH 06/67] Clarify comment Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index e5923af7070..c2ad1cc0591 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -29,9 +29,9 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -// SessionBalancer implements the [TabletBalancer] interface. For a given -// session, it will return the same tablet for its duration. The tablet is initially -// selected randomly, with preference to tablets in the local cell. +// SessionBalancer implements the [TabletBalancer] interface. For a given session, +// it will return the same tablet for its duration, with preference to tablets in +// the local cell. type SessionBalancer struct { // localCell is the cell the gateway is currently running in. localCell string From c421c3d12e0913bceb16cb7e290591e0003bd146 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 21:49:45 -0400 Subject: [PATCH 07/67] Remove brackets Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index c2ad1cc0591..01331cb6308 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -29,7 +29,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -// SessionBalancer implements the [TabletBalancer] interface. For a given session, +// SessionBalancer implements the TabletBalancer interface. For a given session, // it will return the same tablet for its duration, with preference to tablets in // the local cell. type SessionBalancer struct { @@ -42,11 +42,11 @@ type SessionBalancer struct { mu sync.RWMutex // localRings are the hash rings created for each target. It contains only tablets - // local to [localCell]. + // local to localCell. localRings map[discovery.KeyspaceShardTabletType]*hashRing // externalRings are the hash rings created for each target. It contains only tablets - // external to [localCell]. + // external to localCell. externalRings map[discovery.KeyspaceShardTabletType]*hashRing } From 34ca9e97fe6d3a4c52bf1234706dc6fc28850c33 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 16 Aug 2025 21:53:27 -0400 Subject: [PATCH 08/67] Clarify comment Signed-off-by: Mohamed Hamza --- go/vt/vtgate/tabletgateway.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 3e19d6dae9c..3a3fdcfbfc4 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -418,7 +418,8 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablet return tablet } - // Otherwise, randomly select a tablet, with preference to the local cell + // If the balancer isn't enabled, or it didn't return a tablet, randomly select a + // tablet, with preference to the local cell gw.shuffleTablets(gw.localCell, tablets) // skip tablets we tried before From 92abecb3227348fc93bd2ded5f89ccf7b15bfe56 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sun, 17 Aug 2025 09:32:30 -0400 Subject: [PATCH 09/67] Add session hash to session proto Signed-off-by: Mohamed Hamza --- go/vt/proto/vtgate/vtgate.pb.go | 19 ++++++++++++---- go/vt/proto/vtgate/vtgate_vtproto.pb.go | 30 +++++++++++++++++++++++++ go/vt/vtgate/plugin_mysql_server.go | 6 ++++- proto/vtgate.proto | 4 ++++ 4 files changed, 54 insertions(+), 5 deletions(-) diff --git a/go/vt/proto/vtgate/vtgate.pb.go b/go/vt/proto/vtgate/vtgate.pb.go index cfbd3e73fd2..2fca80cda61 100644 --- a/go/vt/proto/vtgate/vtgate.pb.go +++ b/go/vt/proto/vtgate/vtgate.pb.go @@ -226,8 +226,11 @@ type Session struct { // MigrationContext MigrationContext string `protobuf:"bytes,27,opt,name=migration_context,json=migrationContext,proto3" json:"migration_context,omitempty"` ErrorUntilRollback bool `protobuf:"varint,28,opt,name=error_until_rollback,json=errorUntilRollback,proto3" json:"error_until_rollback,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same + // tablet. + SessionHash uint64 `protobuf:"varint,29,opt,name=SessionHash,proto3" json:"SessionHash,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Session) Reset() { @@ -449,6 +452,13 @@ func (x *Session) GetErrorUntilRollback() bool { return false } +func (x *Session) GetSessionHash() uint64 { + if x != nil { + return x.SessionHash + } + return 0 +} + // PrepareData keeps the prepared statement and other information related for execution of it. type PrepareData struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1906,7 +1916,7 @@ var File_vtgate_proto protoreflect.FileDescriptor const file_vtgate_proto_rawDesc = "" + "\n" + - "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\xce\x0f\n" + + "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\xf0\x0f\n" + "\aSession\x12%\n" + "\x0ein_transaction\x18\x01 \x01(\bR\rinTransaction\x12C\n" + "\x0eshard_sessions\x18\x02 \x03(\v2\x1c.vtgate.Session.ShardSessionR\rshardSessions\x12\x1e\n" + @@ -1940,7 +1950,8 @@ const file_vtgate_proto_rawDesc = "" + "\rquery_timeout\x18\x19 \x01(\x03R\fqueryTimeout\x12R\n" + "\x11prepare_statement\x18\x1a \x03(\v2%.vtgate.Session.PrepareStatementEntryR\x10prepareStatement\x12+\n" + "\x11migration_context\x18\x1b \x01(\tR\x10migrationContext\x120\n" + - "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x1a\xf9\x01\n" + + "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x12 \n" + + "\vSessionHash\x18\x1d \x01(\x04R\vSessionHash\x1a\xf9\x01\n" + "\fShardSession\x12%\n" + "\x06target\x18\x01 \x01(\v2\r.query.TargetR\x06target\x12%\n" + "\x0etransaction_id\x18\x02 \x01(\x03R\rtransactionId\x128\n" + diff --git a/go/vt/proto/vtgate/vtgate_vtproto.pb.go b/go/vt/proto/vtgate/vtgate_vtproto.pb.go index 7e1a5679d7e..d5c5ba27bab 100644 --- a/go/vt/proto/vtgate/vtgate_vtproto.pb.go +++ b/go/vt/proto/vtgate/vtgate_vtproto.pb.go @@ -70,6 +70,7 @@ func (m *Session) CloneVT() *Session { r.QueryTimeout = m.QueryTimeout r.MigrationContext = m.MigrationContext r.ErrorUntilRollback = m.ErrorUntilRollback + r.SessionHash = m.SessionHash if rhs := m.ShardSessions; rhs != nil { tmpContainer := make([]*Session_ShardSession, len(rhs)) for k, v := range rhs { @@ -689,6 +690,13 @@ func (m *Session) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if m.SessionHash != 0 { + i = protohelpers.EncodeVarint(dAtA, i, uint64(m.SessionHash)) + i-- + dAtA[i] = 0x1 + i-- + dAtA[i] = 0xe8 + } if m.ErrorUntilRollback { i-- if m.ErrorUntilRollback { @@ -2456,6 +2464,9 @@ func (m *Session) SizeVT() (n int) { if m.ErrorUntilRollback { n += 3 } + if m.SessionHash != 0 { + n += 2 + protohelpers.SizeOfVarint(uint64(m.SessionHash)) + } n += len(m.unknownFields) return n } @@ -4270,6 +4281,25 @@ func (m *Session) UnmarshalVT(dAtA []byte) error { } } m.ErrorUntilRollback = bool(v != 0) + case 29: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field SessionHash", wireType) + } + m.SessionHash = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.SessionHash |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 5e6c2c689c8..339ac2d08ea 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -29,6 +29,7 @@ import ( "syscall" "time" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/spf13/pflag" @@ -181,7 +182,8 @@ var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions func startSpanTestable(ctx context.Context, query, label string, newSpan func(context.Context, string) (trace.Span, context.Context), - newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) { + newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error), +) (trace.Span, context.Context, error) { _, comments := sqlparser.SplitMarginComments(query) match := r.FindStringSubmatch(comments.Leading) span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) @@ -517,6 +519,7 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { session, _ := c.ClientData.(*vtgatepb.Session) if session == nil { u, _ := uuid.NewUUID() + sessionHash := xxhash.Sum64String(u.String()) session = &vtgatepb.Session{ Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, @@ -528,6 +531,7 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { DDLStrategy: defaultDDLStrategy, MigrationContext: "", SessionUUID: u.String(), + SessionHash: sessionHash, EnableSystemSettings: sysVarSetEnabled, } if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { diff --git a/proto/vtgate.proto b/proto/vtgate.proto index 06edb7feb62..b61d229a736 100644 --- a/proto/vtgate.proto +++ b/proto/vtgate.proto @@ -162,6 +162,10 @@ message Session { string migration_context = 27; bool error_until_rollback = 28; + + // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same + // tablet. + uint64 SessionHash = 29; } // PrepareData keeps the prepared statement and other information related for execution of it. From 41141d7c7506cf57bc2d027dba64c0d7ff871be9 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sun, 17 Aug 2025 09:34:21 -0400 Subject: [PATCH 10/67] Clarify comment Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 764f5b02ac1..7fbabfa644c 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -96,8 +96,7 @@ type TabletBalancer interface { DebugHandler(w http.ResponseWriter, r *http.Request) } -// PickOpts are balancer options that are passed into Pick. This exists so that as more balancer -// implementations are added the Pick signature does not get overly long. +// PickOpts are balancer options that are passed into Pick. type PickOpts struct { // sessionHash is the hash of the current session UUID. sessionHash *uint64 From 70ed65f503865ec8effa759c5f01ae70dd0c40a3 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 18 Aug 2025 11:18:27 -0400 Subject: [PATCH 11/67] Update `SessionHash` to be optional Signed-off-by: Mohamed Hamza --- go/vt/proto/vtgate/vtgate.pb.go | 16 +++++---- go/vt/proto/vtgate/vtgate_vtproto.pb.go | 18 +++++++---- go/vt/vtgate/balancer/session_test.go | 43 +++++++++++++++++++++++++ go/vt/vtgate/plugin_mysql_server.go | 2 +- proto/vtgate.proto | 2 +- 5 files changed, 65 insertions(+), 16 deletions(-) diff --git a/go/vt/proto/vtgate/vtgate.pb.go b/go/vt/proto/vtgate/vtgate.pb.go index 2fca80cda61..087950da386 100644 --- a/go/vt/proto/vtgate/vtgate.pb.go +++ b/go/vt/proto/vtgate/vtgate.pb.go @@ -228,7 +228,7 @@ type Session struct { ErrorUntilRollback bool `protobuf:"varint,28,opt,name=error_until_rollback,json=errorUntilRollback,proto3" json:"error_until_rollback,omitempty"` // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same // tablet. - SessionHash uint64 `protobuf:"varint,29,opt,name=SessionHash,proto3" json:"SessionHash,omitempty"` + SessionHash *uint64 `protobuf:"varint,29,opt,name=SessionHash,proto3,oneof" json:"SessionHash,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -453,8 +453,8 @@ func (x *Session) GetErrorUntilRollback() bool { } func (x *Session) GetSessionHash() uint64 { - if x != nil { - return x.SessionHash + if x != nil && x.SessionHash != nil { + return *x.SessionHash } return 0 } @@ -1916,7 +1916,7 @@ var File_vtgate_proto protoreflect.FileDescriptor const file_vtgate_proto_rawDesc = "" + "\n" + - "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\xf0\x0f\n" + + "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\x85\x10\n" + "\aSession\x12%\n" + "\x0ein_transaction\x18\x01 \x01(\bR\rinTransaction\x12C\n" + "\x0eshard_sessions\x18\x02 \x03(\v2\x1c.vtgate.Session.ShardSessionR\rshardSessions\x12\x1e\n" + @@ -1950,8 +1950,8 @@ const file_vtgate_proto_rawDesc = "" + "\rquery_timeout\x18\x19 \x01(\x03R\fqueryTimeout\x12R\n" + "\x11prepare_statement\x18\x1a \x03(\v2%.vtgate.Session.PrepareStatementEntryR\x10prepareStatement\x12+\n" + "\x11migration_context\x18\x1b \x01(\tR\x10migrationContext\x120\n" + - "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x12 \n" + - "\vSessionHash\x18\x1d \x01(\x04R\vSessionHash\x1a\xf9\x01\n" + + "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x12%\n" + + "\vSessionHash\x18\x1d \x01(\x04H\x00R\vSessionHash\x88\x01\x01\x1a\xf9\x01\n" + "\fShardSession\x12%\n" + "\x06target\x18\x01 \x01(\v2\r.query.TargetR\x06target\x12%\n" + "\x0etransaction_id\x18\x02 \x01(\x03R\rtransactionId\x128\n" + @@ -1971,7 +1971,8 @@ const file_vtgate_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\x03R\x05value:\x028\x01\x1aX\n" + "\x15PrepareStatementEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.vtgate.PrepareDataR\x05value:\x028\x01J\x04\b\x03\x10\x04\"]\n" + + "\x05value\x18\x02 \x01(\v2\x13.vtgate.PrepareDataR\x05value:\x028\x01B\x0e\n" + + "\f_SessionHashJ\x04\b\x03\x10\x04\"]\n" + "\vPrepareData\x12+\n" + "\x11prepare_statement\x18\x01 \x01(\tR\x10prepareStatement\x12!\n" + "\fparams_count\x18\x02 \x01(\x05R\vparamsCount\"\xac\x01\n" + @@ -2206,6 +2207,7 @@ func file_vtgate_proto_init() { if File_vtgate_proto != nil { return } + file_vtgate_proto_msgTypes[0].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ diff --git a/go/vt/proto/vtgate/vtgate_vtproto.pb.go b/go/vt/proto/vtgate/vtgate_vtproto.pb.go index d5c5ba27bab..c2f5c6c8260 100644 --- a/go/vt/proto/vtgate/vtgate_vtproto.pb.go +++ b/go/vt/proto/vtgate/vtgate_vtproto.pb.go @@ -70,7 +70,6 @@ func (m *Session) CloneVT() *Session { r.QueryTimeout = m.QueryTimeout r.MigrationContext = m.MigrationContext r.ErrorUntilRollback = m.ErrorUntilRollback - r.SessionHash = m.SessionHash if rhs := m.ShardSessions; rhs != nil { tmpContainer := make([]*Session_ShardSession, len(rhs)) for k, v := range rhs { @@ -132,6 +131,10 @@ func (m *Session) CloneVT() *Session { } r.PrepareStatement = tmpContainer } + if rhs := m.SessionHash; rhs != nil { + tmpVal := *rhs + r.SessionHash = &tmpVal + } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -690,8 +693,8 @@ func (m *Session) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } - if m.SessionHash != 0 { - i = protohelpers.EncodeVarint(dAtA, i, uint64(m.SessionHash)) + if m.SessionHash != nil { + i = protohelpers.EncodeVarint(dAtA, i, uint64(*m.SessionHash)) i-- dAtA[i] = 0x1 i-- @@ -2464,8 +2467,8 @@ func (m *Session) SizeVT() (n int) { if m.ErrorUntilRollback { n += 3 } - if m.SessionHash != 0 { - n += 2 + protohelpers.SizeOfVarint(uint64(m.SessionHash)) + if m.SessionHash != nil { + n += 2 + protohelpers.SizeOfVarint(uint64(*m.SessionHash)) } n += len(m.unknownFields) return n @@ -4285,7 +4288,7 @@ func (m *Session) UnmarshalVT(dAtA []byte) error { if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field SessionHash", wireType) } - m.SessionHash = 0 + var v uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return protohelpers.ErrIntOverflow @@ -4295,11 +4298,12 @@ func (m *Session) UnmarshalVT(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.SessionHash |= uint64(b&0x7F) << shift + v |= uint64(b&0x7F) << shift if b < 0x80 { break } } + m.SessionHash = &v default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 9eced92b24a..cad7f796ce4 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -477,6 +477,49 @@ func TestDebugHandler(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestPickNoSessionHash(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + localTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + // Test with nil opts + result := b.Pick(target, nil, nil) + require.Nil(t, result) + + // Test with opts but nil session hash + optsNoHash := &PickOpts{sessionHash: nil} + result = b.Pick(target, nil, optsNoHash) + require.Nil(t, result) +} + func sessionHash(i uint64) *PickOpts { return &PickOpts{sessionHash: &i} } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 339ac2d08ea..ce4e85301bf 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -531,7 +531,7 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { DDLStrategy: defaultDDLStrategy, MigrationContext: "", SessionUUID: u.String(), - SessionHash: sessionHash, + SessionHash: &sessionHash, EnableSystemSettings: sysVarSetEnabled, } if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { diff --git a/proto/vtgate.proto b/proto/vtgate.proto index b61d229a736..8e9c6357267 100644 --- a/proto/vtgate.proto +++ b/proto/vtgate.proto @@ -165,7 +165,7 @@ message Session { // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same // tablet. - uint64 SessionHash = 29; + optional uint64 SessionHash = 29; } // PrepareData keeps the prepared statement and other information related for execution of it. From 0698c2bf5503170e045f841aa780cbb9d7f1142c Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 23 Aug 2025 18:02:07 -0400 Subject: [PATCH 12/67] Pass in invalid tablets to tablet balancer Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 4 +- go/vt/vtgate/balancer/balancer_test.go | 8 +-- go/vt/vtgate/balancer/hashring.go | 32 +++++++-- go/vt/vtgate/balancer/hashring_test.go | 22 +++--- go/vt/vtgate/balancer/session.go | 14 ++-- go/vt/vtgate/balancer/session_test.go | 97 ++++++++++++++++++++++---- go/vt/vtgate/tabletgateway.go | 2 +- 7 files changed, 131 insertions(+), 48 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 7fbabfa644c..586a2d1ad40 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -90,7 +90,7 @@ converge on the desired balanced query load. type TabletBalancer interface { // Pick is the main entry point to the balancer. Returns the best tablet out of the list // for a given query to maintain the desired balanced allocation over multiple executions. - Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth + Pick(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts *PickOpts) *discovery.TabletHealth // DebugHandler provides a summary of tablet balancer state DebugHandler(w http.ResponseWriter, r *http.Request) @@ -173,7 +173,7 @@ func (b *tabletBalancer) DebugHandler(w http.ResponseWriter, _ *http.Request) { // Given the total allocation for the set of tablets, choose the best target // by a weighted random sample so that over time the system will achieve the // desired balanced allocation. -func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ *PickOpts) *discovery.TabletHealth { +func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ map[string]bool, _ *PickOpts) *discovery.TabletHealth { numTablets := len(tablets) if numTablets == 0 { return nil diff --git a/go/vt/vtgate/balancer/balancer_test.go b/go/vt/vtgate/balancer/balancer_test.go index 78b6e708d50..b8e5fe58183 100644 --- a/go/vt/vtgate/balancer/balancer_test.go +++ b/go/vt/vtgate/balancer/balancer_test.go @@ -298,7 +298,7 @@ func TestBalancedPick(t *testing.T) { b := NewTabletBalancer(localCell, vtGateCells).(*tabletBalancer) for i := 0; i < N/len(vtGateCells); i++ { - th := b.Pick(target, tablets, nil) + th := b.Pick(target, tablets, nil, nil) if i == 0 { t.Logf("Target Flows %v, Balancer: %s\n", expectedPerCell, b.print()) } @@ -336,7 +336,7 @@ func TestTopologyChanged(t *testing.T) { tablets = tablets[0:2] for i := 0; i < N; i++ { - th := b.Pick(target, tablets, nil) + th := b.Pick(target, tablets, nil, nil) allocation, totalAllocation := b.getAllocation(target, tablets) assert.Equalf(t, ALLOCATION/2, totalAllocation, "totalAllocation mismatch %s", b.print()) @@ -346,7 +346,7 @@ func TestTopologyChanged(t *testing.T) { // Run again with the full topology. Now traffic should go to cell b for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil) + th := b.Pick(target, allTablets, nil, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -359,7 +359,7 @@ func TestTopologyChanged(t *testing.T) { newTablet := createTestTablet("b") allTablets[2] = newTablet for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil) + th := b.Pick(target, allTablets, nil, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go index 5267482b6f5..09ae50d903c 100644 --- a/go/vt/vtgate/balancer/hashring.go +++ b/go/vt/vtgate/balancer/hashring.go @@ -108,8 +108,8 @@ func (r *hashRing) remove(tablet *discovery.TabletHealth) { delete(r.tablets, tabletAlias(tablet)) } -// get returns the tablet for the given key. -func (r *hashRing) get(key string) *discovery.TabletHealth { +// get returns the tablet for the given key, ignoring invalid tablets. +func (r *hashRing) get(key string, invalidTablets map[string]bool) *discovery.TabletHealth { r.mu.RLock() defer r.mu.RUnlock() @@ -118,12 +118,12 @@ func (r *hashRing) get(key string) *discovery.TabletHealth { } hash := xxhash.Sum64String(key) - tablet := r.getHashed(hash) + tablet := r.getHashed(hash, invalidTablets) return tablet } -// getHashed returns the tablet for the given hash. -func (r *hashRing) getHashed(hash uint64) *discovery.TabletHealth { +// getHashed returns the tablet for the given hash, ignoring invalid tablets. +func (r *hashRing) getHashed(hash uint64, invalidTablets map[string]bool) *discovery.TabletHealth { r.mu.RLock() defer r.mu.RUnlock() @@ -131,14 +131,22 @@ func (r *hashRing) getHashed(hash uint64) *discovery.TabletHealth { return nil } - // Find the first node greater than or equal to this hash + // Find the first node greater than or equal to this hash, and isn't invalid i := sort.Search(len(r.nodes), func(i int) bool { - return r.nodes[i] >= hash + node := r.nodes[i] + + return !r.invalidNode(node, invalidTablets) && r.nodes[i] >= hash }) // Wrap around if needed if i == len(r.nodes) { i = 0 + + // If the first tablet is invalid, it means we couldn't find any valid tablets + node := r.nodes[i] + if r.invalidNode(node, invalidTablets) { + return nil + } } // Return the associated tablet @@ -196,3 +204,13 @@ func removeNodes(nodes []uint64, hashes map[uint64]struct{}) []uint64 { return nodes[:writeIdx] } + +// invalidNode returns whether the virtual node is associated with an invalid tablet. +func (r *hashRing) invalidNode(node uint64, invalidTablets map[string]bool) bool { + tablet := r.nodeMap[node] + + alias := topoproto.TabletAliasString(tablet.Tablet.Alias) + _, invalid := invalidTablets[alias] + + return invalid +} diff --git a/go/vt/vtgate/balancer/hashring_test.go b/go/vt/vtgate/balancer/hashring_test.go index 06ab740dd12..2a3e8c43592 100644 --- a/go/vt/vtgate/balancer/hashring_test.go +++ b/go/vt/vtgate/balancer/hashring_test.go @@ -142,20 +142,20 @@ func TestHashRingGet(t *testing.T) { tablet1 := createTestTabletForHashRing("cell1", 100) tablet2 := createTestTabletForHashRing("cell2", 200) - result := ring.get("test_key") + result := ring.get("test_key", nil) require.Nil(t, result) ring.add(tablet1) ring.sort() - result = ring.get("test_key") + result = ring.get("test_key", nil) require.NotNil(t, result) require.Equal(t, tablet1, result) ring.add(tablet2) ring.sort() - result = ring.get("test_key") + result = ring.get("test_key", nil) require.NotNil(t, result) // Empirically know that "test_key" hashes closest to tablet2 @@ -212,19 +212,19 @@ func TestHashRingAddRemoveSequence(t *testing.T) { key := "test_sequence" - initialTablet := ring.get(key) + initialTablet := ring.get(key, nil) require.NotNil(t, initialTablet) ring.remove(tablet2) ring.sort() - afterRemovalTablet := ring.get(key) + afterRemovalTablet := ring.get(key, nil) require.NotNil(t, afterRemovalTablet) ring.add(tablet2) ring.sort() - afterReaddTablet := ring.get(key) + afterReaddTablet := ring.get(key, nil) require.NotNil(t, afterReaddTablet) require.Equal(t, initialTablet, afterReaddTablet) @@ -243,7 +243,7 @@ func TestHashRingWrapAround(t *testing.T) { ring.nodeMap[3000] = tablet1 // Any large hash should wrap around to the first node - result := ring.get("this_should_wrap_around_with_large_hash") + result := ring.get("this_should_wrap_around_with_large_hash", nil) require.NotNil(t, result) require.Contains(t, []*discovery.TabletHealth{tablet1, tablet2}, result) } @@ -265,7 +265,7 @@ func TestHashRingRemoveAllTablets(t *testing.T) { require.Empty(t, ring.nodes) require.Empty(t, ring.nodeMap) - require.Nil(t, ring.get("any_key")) + require.Nil(t, ring.get("any_key", nil)) } func TestHashRingMultipleAddSameTablet(t *testing.T) { @@ -294,13 +294,13 @@ func TestHashRingGetAfterRemove(t *testing.T) { ring.sort() // Empirically know that this hashes closest to tablet3 - got := ring.get("key") + got := ring.get("key", nil) require.Equal(t, tablet3, got) // Remove tablet3 ring.remove(tablet3) - got = ring.get("key") + got = ring.get("key", nil) require.NotEqual(t, tablet3, got) } @@ -321,7 +321,7 @@ func TestHashRingConcurrentGetOperations(t *testing.T) { go func(i int) { defer wg.Done() key := fmt.Sprintf("concurrent_key_%d", i) - tablet := ring.get(key) + tablet := ring.get(key, nil) require.NotNil(t, tablet) }(i) } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 01331cb6308..011fff2707a 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -70,11 +70,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, hc discovery.Heal // // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. -// -// NOTE: this currently won't consider any invalid tablets. This means we'll keep returning the same -// invalid tablet on subsequent tries. We can improve this by maybe returning a random tablet (local -// cell preferred) when the session hash falls on an invalid tablet. -func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { +func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, invalidTablets map[string]bool, opts *PickOpts) *discovery.TabletHealth { if opts == nil || opts.sessionHash == nil { // No session hash. Returning nil here will allow the gateway to select a random // tablet instead. @@ -87,13 +83,13 @@ func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHeal defer b.mu.RUnlock() // Try to find a tablet in the local cell first - tablet := getFromRing(b.localRings, target, sessionHash) + tablet := getFromRing(b.localRings, target, invalidTablets, sessionHash) if tablet != nil { return tablet } // If we didn't find a tablet in the local cell, try external cells - tablet = getFromRing(b.externalRings, target, sessionHash) + tablet = getFromRing(b.externalRings, target, invalidTablets, sessionHash) return tablet } @@ -192,7 +188,7 @@ func (b *SessionBalancer) print() string { } // getFromRing gets a tablet from the respective ring for the given target and session hash. -func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, sessionHash uint64) *discovery.TabletHealth { +func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, invalidTablets map[string]bool, sessionHash uint64) *discovery.TabletHealth { key := discovery.KeyFromTarget(target) ring, exists := rings[key] @@ -200,5 +196,5 @@ func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target * return nil } - return ring.getHashed(sessionHash) + return ring.getHashed(sessionHash, invalidTablets) } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index cad7f796ce4..d13dcf530d6 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -28,6 +28,7 @@ import ( "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/topo/topoproto" ) func newSessionBalancer(t *testing.T) (*SessionBalancer, chan *discovery.TabletHealth) { @@ -63,7 +64,7 @@ func TestPickNoTablets(t *testing.T) { } opts := sessionHash(12345) - result := b.Pick(target, nil, opts) + result := b.Pick(target, nil, nil, opts) require.Nil(t, result) } @@ -121,16 +122,16 @@ func TestPickLocalOnly(t *testing.T) { // Pick for a specific session hash opts := sessionHash(12345) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) // Pick again with same session hash, should return same tablet - picked2 := b.Pick(target, nil, opts) + picked2 := b.Pick(target, nil, nil, opts) require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 opts = sessionHash(5018141287610575993) - picked3 := b.Pick(target, nil, opts) + picked3 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked3) require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked3))) } @@ -208,7 +209,7 @@ func TestPickPreferLocal(t *testing.T) { // Pick should prefer local cell opts := sessionHash(5018141287610575993) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) } @@ -267,7 +268,7 @@ func TestPickNoLocal(t *testing.T) { // Pick should return external cell since there are no local cells opts := sessionHash(12345) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) } @@ -324,7 +325,7 @@ func TestTabletNotServing(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := sessionHash(5018141287610575993) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) // Local tablet goes out of serving @@ -335,7 +336,7 @@ func TestTabletNotServing(t *testing.T) { time.Sleep(10 * time.Millisecond) // Should not pick the local tablet again - picked2 := b.Pick(target, nil, opts) + picked2 := b.Pick(target, nil, nil, opts) require.NotEqual(t, picked1, picked2) } @@ -371,7 +372,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := sessionHash(5018141287610575993) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) localTablet2 := &discovery.TabletHealth{ @@ -396,7 +397,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - picked2 := b.Pick(target, nil, opts) + picked2 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked2) require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) } @@ -433,7 +434,7 @@ func TestNewExternalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := sessionHash(5018141287610575993) - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) externalTablet2 := &discovery.TabletHealth{ @@ -458,7 +459,7 @@ func TestNewExternalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - picked2 := b.Pick(target, nil, opts) + picked2 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked2) require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) } @@ -511,15 +512,83 @@ func TestPickNoSessionHash(t *testing.T) { time.Sleep(10 * time.Millisecond) // Test with nil opts - result := b.Pick(target, nil, nil) + result := b.Pick(target, nil, nil, nil) require.Nil(t, result) // Test with opts but nil session hash optsNoHash := &PickOpts{sessionHash: nil} - result = b.Pick(target, nil, optsNoHash) + result = b.Pick(target, nil, nil, optsNoHash) require.Nil(t, result) } +func TestPickInvalidTablets(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + localTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + localTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet + hcChan <- localTablet2 + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + // Get a tablet regularly + opts := sessionHash(5018141287610575993) + tablet := b.Pick(target, nil, nil, opts) + require.NotNil(t, tablet) + + // Mark returned tablet as invalid, should return other tablet + invalidTablets := map[string]bool{topoproto.TabletAliasString(tablet.Tablet.Alias): true} + tablet2 := b.Pick(target, nil, invalidTablets, opts) + require.NotEqual(t, tablet, tablet2) + + // Mark both as invalid, should return nil + invalidTablets[topoproto.TabletAliasString(tablet2.Tablet.Alias)] = true + tablet3 := b.Pick(target, nil, invalidTablets, opts) + require.Nil(t, tablet3) +} + func sessionHash(i uint64) *PickOpts { return &PickOpts{sessionHash: &i} } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 3a3fdcfbfc4..d9e3b3e92e0 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -411,7 +411,7 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablet }) } - tablet = gw.balancer.Pick(target, tablets, nil) + tablet = gw.balancer.Pick(target, tablets, nil, nil) } if tablet != nil { From 7159adf41eeeb5ce2902f5ce7a48e31b3b64d728 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sun, 24 Aug 2025 11:06:35 -0400 Subject: [PATCH 13/67] Build initial hash rings Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 36 +++++++++++++++++++++++---- go/vt/vtgate/balancer/session_test.go | 9 +++---- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 011fff2707a..47eec090e3d 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -25,10 +25,16 @@ import ( "strings" "sync" + "github.com/DataDog/appsec-internal-go/log" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/srvtopo" ) +// tabletTypesToWatch are the tablet types that will be included in the hash rings. +var tabletTypesToWatch = []topodata.TabletType{topodata.TabletType_PRIMARY, topodata.TabletType_REPLICA, topodata.TabletType_BATCH} + // SessionBalancer implements the TabletBalancer interface. For a given session, // it will return the same tablet for its duration, with preference to tablets in // the local cell. @@ -51,7 +57,7 @@ type SessionBalancer struct { } // NewSessionBalancer creates a new session balancer. -func NewSessionBalancer(ctx context.Context, localCell string, hc discovery.HealthCheck) TabletBalancer { +func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) TabletBalancer { b := &SessionBalancer{ localCell: localCell, hc: hc, @@ -60,8 +66,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, hc discovery.Heal } // Set up health check subscription - hcChan := b.hc.Subscribe("SessionBalancer") - go b.watchHealthCheck(ctx, hcChan) + go b.watchHealthCheck(ctx, topoServer) return b } @@ -107,7 +112,26 @@ func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { } // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. -func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { +func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvtopo.Server) { + // Build initial hash rings + + // Find all the targets we're watching + targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, tabletTypesToWatch) + if err != nil { + log.Errorf("session balancer: failed to find all targets and keyspaces: %q", err) + return + } + + // Add each tablet to the hash ring + for _, target := range targets { + tablets := b.hc.GetHealthyTabletStats(target) + for _, tablet := range tablets { + b.onTabletHealthChange(tablet) + } + } + + // Start watching health check channel for future tablet health changes + hcChan := b.hc.Subscribe("SessionBalancer") for { select { case <-ctx.Done(): @@ -123,7 +147,9 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis } } -// onTabletHealthChange is the handler for tablet health events. +// onTabletHealthChange is the handler for tablet health events. If a tablet goes into serving, +// it is added to the appropriate (local or external) hash ring for its target. If it goes out +// of serving, it is removed from the hash ring. func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { b.mu.Lock() defer b.mu.Unlock() diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index d13dcf530d6..b476651fab2 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -28,6 +28,7 @@ import ( "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/vt/srvtopo/fakesrvtopo" "vitess.io/vitess/go/vt/topo/topoproto" ) @@ -36,7 +37,7 @@ func newSessionBalancer(t *testing.T) (*SessionBalancer, chan *discovery.TabletH ch := make(chan *discovery.TabletHealth, 10) hc := discovery.NewFakeHealthCheck(ch) - b := NewSessionBalancer(ctx, "local", hc) + b := NewSessionBalancer(ctx, "local", &fakesrvtopo.FakeSrvTopo{}, hc) sb := b.(*SessionBalancer) return sb, ch @@ -465,11 +466,7 @@ func TestNewExternalTablet(t *testing.T) { } func TestDebugHandler(t *testing.T) { - ctx := t.Context() - - ch := make(chan *discovery.TabletHealth, 10) - hc := discovery.NewFakeHealthCheck(ch) - b := NewSessionBalancer(ctx, "local", hc) + b, _ := newSessionBalancer(t) w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/debug", nil) From b77e75826c0061435c13caf19e2b04078e392a21 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sun, 24 Aug 2025 18:38:51 -0400 Subject: [PATCH 14/67] Set up health check subscription first to avoid missing changes Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 38 +++++++++++++-------------- go/vt/vtgate/balancer/session_test.go | 2 +- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 47eec090e3d..e382701c682 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -57,7 +57,7 @@ type SessionBalancer struct { } // NewSessionBalancer creates a new session balancer. -func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) TabletBalancer { +func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) (TabletBalancer, error) { b := &SessionBalancer{ localCell: localCell, hc: hc, @@ -68,7 +68,24 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // Set up health check subscription go b.watchHealthCheck(ctx, topoServer) - return b + // Build initial hash rings + + // Find all the targets we're watching + targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, tabletTypesToWatch) + if err != nil { + log.Errorf("session balancer: failed to find all targets and keyspaces: %q", err) + return nil, err + } + + // Add each tablet to the hash ring + for _, target := range targets { + tablets := b.hc.GetHealthyTabletStats(target) + for _, tablet := range tablets { + b.onTabletHealthChange(tablet) + } + } + + return b, nil } // Pick is the main entry point to the balancer. @@ -113,23 +130,6 @@ func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvtopo.Server) { - // Build initial hash rings - - // Find all the targets we're watching - targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, tabletTypesToWatch) - if err != nil { - log.Errorf("session balancer: failed to find all targets and keyspaces: %q", err) - return - } - - // Add each tablet to the hash ring - for _, target := range targets { - tablets := b.hc.GetHealthyTabletStats(target) - for _, tablet := range tablets { - b.onTabletHealthChange(tablet) - } - } - // Start watching health check channel for future tablet health changes hcChan := b.hc.Subscribe("SessionBalancer") for { diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index b476651fab2..37e48e2fd83 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -37,7 +37,7 @@ func newSessionBalancer(t *testing.T) (*SessionBalancer, chan *discovery.TabletH ch := make(chan *discovery.TabletHealth, 10) hc := discovery.NewFakeHealthCheck(ch) - b := NewSessionBalancer(ctx, "local", &fakesrvtopo.FakeSrvTopo{}, hc) + b, _ := NewSessionBalancer(ctx, "local", &fakesrvtopo.FakeSrvTopo{}, hc) sb := b.(*SessionBalancer) return sb, ch From b38c28ffe4267caac42e7b16a399c7cdd36a2cd3 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 25 Aug 2025 17:28:13 -0400 Subject: [PATCH 15/67] Hash uuid on pick Signed-off-by: Mohamed Hamza --- go/vt/proto/vtgate/vtgate.pb.go | 23 ++++------------- go/vt/proto/vtgate/vtgate_vtproto.pb.go | 34 ------------------------- go/vt/vtgate/balancer/balancer.go | 4 +-- go/vt/vtgate/balancer/hashring.go | 12 --------- go/vt/vtgate/balancer/session.go | 12 ++++----- go/vt/vtgate/balancer/session_test.go | 26 +++++++++---------- go/vt/vtgate/plugin_mysql_server.go | 2 -- proto/vtgate.proto | 4 --- 8 files changed, 26 insertions(+), 91 deletions(-) diff --git a/go/vt/proto/vtgate/vtgate.pb.go b/go/vt/proto/vtgate/vtgate.pb.go index 087950da386..cfbd3e73fd2 100644 --- a/go/vt/proto/vtgate/vtgate.pb.go +++ b/go/vt/proto/vtgate/vtgate.pb.go @@ -226,11 +226,8 @@ type Session struct { // MigrationContext MigrationContext string `protobuf:"bytes,27,opt,name=migration_context,json=migrationContext,proto3" json:"migration_context,omitempty"` ErrorUntilRollback bool `protobuf:"varint,28,opt,name=error_until_rollback,json=errorUntilRollback,proto3" json:"error_until_rollback,omitempty"` - // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same - // tablet. - SessionHash *uint64 `protobuf:"varint,29,opt,name=SessionHash,proto3,oneof" json:"SessionHash,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Session) Reset() { @@ -452,13 +449,6 @@ func (x *Session) GetErrorUntilRollback() bool { return false } -func (x *Session) GetSessionHash() uint64 { - if x != nil && x.SessionHash != nil { - return *x.SessionHash - } - return 0 -} - // PrepareData keeps the prepared statement and other information related for execution of it. type PrepareData struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1916,7 +1906,7 @@ var File_vtgate_proto protoreflect.FileDescriptor const file_vtgate_proto_rawDesc = "" + "\n" + - "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\x85\x10\n" + + "\fvtgate.proto\x12\x06vtgate\x1a\x10binlogdata.proto\x1a\vquery.proto\x1a\x0etopodata.proto\x1a\vvtrpc.proto\"\xce\x0f\n" + "\aSession\x12%\n" + "\x0ein_transaction\x18\x01 \x01(\bR\rinTransaction\x12C\n" + "\x0eshard_sessions\x18\x02 \x03(\v2\x1c.vtgate.Session.ShardSessionR\rshardSessions\x12\x1e\n" + @@ -1950,8 +1940,7 @@ const file_vtgate_proto_rawDesc = "" + "\rquery_timeout\x18\x19 \x01(\x03R\fqueryTimeout\x12R\n" + "\x11prepare_statement\x18\x1a \x03(\v2%.vtgate.Session.PrepareStatementEntryR\x10prepareStatement\x12+\n" + "\x11migration_context\x18\x1b \x01(\tR\x10migrationContext\x120\n" + - "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x12%\n" + - "\vSessionHash\x18\x1d \x01(\x04H\x00R\vSessionHash\x88\x01\x01\x1a\xf9\x01\n" + + "\x14error_until_rollback\x18\x1c \x01(\bR\x12errorUntilRollback\x1a\xf9\x01\n" + "\fShardSession\x12%\n" + "\x06target\x18\x01 \x01(\v2\r.query.TargetR\x06target\x12%\n" + "\x0etransaction_id\x18\x02 \x01(\x03R\rtransactionId\x128\n" + @@ -1971,8 +1960,7 @@ const file_vtgate_proto_rawDesc = "" + "\x05value\x18\x02 \x01(\x03R\x05value:\x028\x01\x1aX\n" + "\x15PrepareStatementEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.vtgate.PrepareDataR\x05value:\x028\x01B\x0e\n" + - "\f_SessionHashJ\x04\b\x03\x10\x04\"]\n" + + "\x05value\x18\x02 \x01(\v2\x13.vtgate.PrepareDataR\x05value:\x028\x01J\x04\b\x03\x10\x04\"]\n" + "\vPrepareData\x12+\n" + "\x11prepare_statement\x18\x01 \x01(\tR\x10prepareStatement\x12!\n" + "\fparams_count\x18\x02 \x01(\x05R\vparamsCount\"\xac\x01\n" + @@ -2207,7 +2195,6 @@ func file_vtgate_proto_init() { if File_vtgate_proto != nil { return } - file_vtgate_proto_msgTypes[0].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ diff --git a/go/vt/proto/vtgate/vtgate_vtproto.pb.go b/go/vt/proto/vtgate/vtgate_vtproto.pb.go index c2f5c6c8260..7e1a5679d7e 100644 --- a/go/vt/proto/vtgate/vtgate_vtproto.pb.go +++ b/go/vt/proto/vtgate/vtgate_vtproto.pb.go @@ -131,10 +131,6 @@ func (m *Session) CloneVT() *Session { } r.PrepareStatement = tmpContainer } - if rhs := m.SessionHash; rhs != nil { - tmpVal := *rhs - r.SessionHash = &tmpVal - } if len(m.unknownFields) > 0 { r.unknownFields = make([]byte, len(m.unknownFields)) copy(r.unknownFields, m.unknownFields) @@ -693,13 +689,6 @@ func (m *Session) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } - if m.SessionHash != nil { - i = protohelpers.EncodeVarint(dAtA, i, uint64(*m.SessionHash)) - i-- - dAtA[i] = 0x1 - i-- - dAtA[i] = 0xe8 - } if m.ErrorUntilRollback { i-- if m.ErrorUntilRollback { @@ -2467,9 +2456,6 @@ func (m *Session) SizeVT() (n int) { if m.ErrorUntilRollback { n += 3 } - if m.SessionHash != nil { - n += 2 + protohelpers.SizeOfVarint(uint64(*m.SessionHash)) - } n += len(m.unknownFields) return n } @@ -4284,26 +4270,6 @@ func (m *Session) UnmarshalVT(dAtA []byte) error { } } m.ErrorUntilRollback = bool(v != 0) - case 29: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field SessionHash", wireType) - } - var v uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return protohelpers.ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - v |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - m.SessionHash = &v default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 586a2d1ad40..1ec13244307 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -98,8 +98,8 @@ type TabletBalancer interface { // PickOpts are balancer options that are passed into Pick. type PickOpts struct { - // sessionHash is the hash of the current session UUID. - sessionHash *uint64 + // SessionUUID is the hash of the current session UUID. + SessionUUID *string } func NewTabletBalancer(localCell string, vtGateCells []string) TabletBalancer { diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go index 09ae50d903c..9b5c08d57cb 100644 --- a/go/vt/vtgate/balancer/hashring.go +++ b/go/vt/vtgate/balancer/hashring.go @@ -118,18 +118,6 @@ func (r *hashRing) get(key string, invalidTablets map[string]bool) *discovery.Ta } hash := xxhash.Sum64String(key) - tablet := r.getHashed(hash, invalidTablets) - return tablet -} - -// getHashed returns the tablet for the given hash, ignoring invalid tablets. -func (r *hashRing) getHashed(hash uint64, invalidTablets map[string]bool) *discovery.TabletHealth { - r.mu.RLock() - defer r.mu.RUnlock() - - if len(r.nodes) == 0 { - return nil - } // Find the first node greater than or equal to this hash, and isn't invalid i := sort.Search(len(r.nodes), func(i int) bool { diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index e382701c682..c15687c38fc 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -93,25 +93,25 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, invalidTablets map[string]bool, opts *PickOpts) *discovery.TabletHealth { - if opts == nil || opts.sessionHash == nil { + if opts == nil || opts.SessionUUID == nil { // No session hash. Returning nil here will allow the gateway to select a random // tablet instead. return nil } - sessionHash := *opts.sessionHash + sessionUUID := *opts.SessionUUID b.mu.RLock() defer b.mu.RUnlock() // Try to find a tablet in the local cell first - tablet := getFromRing(b.localRings, target, invalidTablets, sessionHash) + tablet := getFromRing(b.localRings, target, invalidTablets, sessionUUID) if tablet != nil { return tablet } // If we didn't find a tablet in the local cell, try external cells - tablet = getFromRing(b.externalRings, target, invalidTablets, sessionHash) + tablet = getFromRing(b.externalRings, target, invalidTablets, sessionUUID) return tablet } @@ -214,7 +214,7 @@ func (b *SessionBalancer) print() string { } // getFromRing gets a tablet from the respective ring for the given target and session hash. -func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, invalidTablets map[string]bool, sessionHash uint64) *discovery.TabletHealth { +func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, invalidTablets map[string]bool, sessionUUID string) *discovery.TabletHealth { key := discovery.KeyFromTarget(target) ring, exists := rings[key] @@ -222,5 +222,5 @@ func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target * return nil } - return ring.getHashed(sessionHash, invalidTablets) + return ring.get(sessionUUID, invalidTablets) } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 37e48e2fd83..aacff4d68ed 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -64,7 +64,7 @@ func TestPickNoTablets(t *testing.T) { Cell: "local", } - opts := sessionHash(12345) + opts := buildOpts("a") result := b.Pick(target, nil, nil, opts) require.Nil(t, result) } @@ -122,7 +122,7 @@ func TestPickLocalOnly(t *testing.T) { time.Sleep(10 * time.Millisecond) // Pick for a specific session hash - opts := sessionHash(12345) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) @@ -131,7 +131,7 @@ func TestPickLocalOnly(t *testing.T) { require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 - opts = sessionHash(5018141287610575993) + opts = buildOpts("c") picked3 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked3) require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked3))) @@ -209,7 +209,7 @@ func TestPickPreferLocal(t *testing.T) { time.Sleep(10 * time.Millisecond) // Pick should prefer local cell - opts := sessionHash(5018141287610575993) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) @@ -268,7 +268,7 @@ func TestPickNoLocal(t *testing.T) { time.Sleep(10 * time.Millisecond) // Pick should return external cell since there are no local cells - opts := sessionHash(12345) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) @@ -325,7 +325,7 @@ func TestTabletNotServing(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(10 * time.Millisecond) - opts := sessionHash(5018141287610575993) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) @@ -372,7 +372,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - opts := sessionHash(5018141287610575993) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) @@ -434,7 +434,7 @@ func TestNewExternalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - opts := sessionHash(5018141287610575993) + opts := buildOpts("a") picked1 := b.Pick(target, nil, nil, opts) require.NotNil(t, picked1) @@ -512,8 +512,8 @@ func TestPickNoSessionHash(t *testing.T) { result := b.Pick(target, nil, nil, nil) require.Nil(t, result) - // Test with opts but nil session hash - optsNoHash := &PickOpts{sessionHash: nil} + // Test with opts but nil session uuid + optsNoHash := &PickOpts{SessionUUID: nil} result = b.Pick(target, nil, nil, optsNoHash) require.Nil(t, result) } @@ -571,7 +571,7 @@ func TestPickInvalidTablets(t *testing.T) { time.Sleep(10 * time.Millisecond) // Get a tablet regularly - opts := sessionHash(5018141287610575993) + opts := buildOpts("a") tablet := b.Pick(target, nil, nil, opts) require.NotNil(t, tablet) @@ -586,6 +586,6 @@ func TestPickInvalidTablets(t *testing.T) { require.Nil(t, tablet3) } -func sessionHash(i uint64) *PickOpts { - return &PickOpts{sessionHash: &i} +func buildOpts(uuid string) *PickOpts { + return &PickOpts{SessionUUID: &uuid} } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index ce4e85301bf..09d503c51fa 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -519,7 +519,6 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { session, _ := c.ClientData.(*vtgatepb.Session) if session == nil { u, _ := uuid.NewUUID() - sessionHash := xxhash.Sum64String(u.String()) session = &vtgatepb.Session{ Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, @@ -531,7 +530,6 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { DDLStrategy: defaultDDLStrategy, MigrationContext: "", SessionUUID: u.String(), - SessionHash: &sessionHash, EnableSystemSettings: sysVarSetEnabled, } if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { diff --git a/proto/vtgate.proto b/proto/vtgate.proto index 8e9c6357267..06edb7feb62 100644 --- a/proto/vtgate.proto +++ b/proto/vtgate.proto @@ -162,10 +162,6 @@ message Session { string migration_context = 27; bool error_until_rollback = 28; - - // SessionHash is the xxhash of the Session UUID. Used to route sessions to the same - // tablet. - optional uint64 SessionHash = 29; } // PrepareData keeps the prepared statement and other information related for execution of it. From df48aa1c5643fefbadcd327ff1dc53ea892b9d2d Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 25 Aug 2025 17:33:04 -0400 Subject: [PATCH 16/67] Add invalid tablets to `PickOpts` Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 7 +++-- go/vt/vtgate/balancer/session.go | 6 ++--- go/vt/vtgate/balancer/session_test.go | 38 +++++++++++++-------------- go/vt/vtgate/tabletgateway.go | 2 +- 4 files changed, 28 insertions(+), 25 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 1ec13244307..d9d1e144d2d 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -90,7 +90,7 @@ converge on the desired balanced query load. type TabletBalancer interface { // Pick is the main entry point to the balancer. Returns the best tablet out of the list // for a given query to maintain the desired balanced allocation over multiple executions. - Pick(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts *PickOpts) *discovery.TabletHealth + Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth // DebugHandler provides a summary of tablet balancer state DebugHandler(w http.ResponseWriter, r *http.Request) @@ -98,6 +98,9 @@ type TabletBalancer interface { // PickOpts are balancer options that are passed into Pick. type PickOpts struct { + // InvalidTablets is a set of tablets that should not be picked. + InvalidTablets map[string]bool + // SessionUUID is the hash of the current session UUID. SessionUUID *string } @@ -173,7 +176,7 @@ func (b *tabletBalancer) DebugHandler(w http.ResponseWriter, _ *http.Request) { // Given the total allocation for the set of tablets, choose the best target // by a weighted random sample so that over time the system will achieve the // desired balanced allocation. -func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ map[string]bool, _ *PickOpts) *discovery.TabletHealth { +func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ *PickOpts) *discovery.TabletHealth { numTablets := len(tablets) if numTablets == 0 { return nil diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index c15687c38fc..888d4959e0b 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -92,7 +92,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. -func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, invalidTablets map[string]bool, opts *PickOpts) *discovery.TabletHealth { +func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { if opts == nil || opts.SessionUUID == nil { // No session hash. Returning nil here will allow the gateway to select a random // tablet instead. @@ -105,13 +105,13 @@ func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHeal defer b.mu.RUnlock() // Try to find a tablet in the local cell first - tablet := getFromRing(b.localRings, target, invalidTablets, sessionUUID) + tablet := getFromRing(b.localRings, target, opts.InvalidTablets, sessionUUID) if tablet != nil { return tablet } // If we didn't find a tablet in the local cell, try external cells - tablet = getFromRing(b.externalRings, target, invalidTablets, sessionUUID) + tablet = getFromRing(b.externalRings, target, opts.InvalidTablets, sessionUUID) return tablet } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index aacff4d68ed..980c05e508d 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -65,7 +65,7 @@ func TestPickNoTablets(t *testing.T) { } opts := buildOpts("a") - result := b.Pick(target, nil, nil, opts) + result := b.Pick(target, nil, opts) require.Nil(t, result) } @@ -123,16 +123,16 @@ func TestPickLocalOnly(t *testing.T) { // Pick for a specific session hash opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) // Pick again with same session hash, should return same tablet - picked2 := b.Pick(target, nil, nil, opts) + picked2 := b.Pick(target, nil, opts) require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 opts = buildOpts("c") - picked3 := b.Pick(target, nil, nil, opts) + picked3 := b.Pick(target, nil, opts) require.NotNil(t, picked3) require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked3))) } @@ -210,7 +210,7 @@ func TestPickPreferLocal(t *testing.T) { // Pick should prefer local cell opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) } @@ -269,7 +269,7 @@ func TestPickNoLocal(t *testing.T) { // Pick should return external cell since there are no local cells opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) } @@ -326,7 +326,7 @@ func TestTabletNotServing(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) // Local tablet goes out of serving @@ -337,7 +337,7 @@ func TestTabletNotServing(t *testing.T) { time.Sleep(10 * time.Millisecond) // Should not pick the local tablet again - picked2 := b.Pick(target, nil, nil, opts) + picked2 := b.Pick(target, nil, opts) require.NotEqual(t, picked1, picked2) } @@ -373,7 +373,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) localTablet2 := &discovery.TabletHealth{ @@ -398,7 +398,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - picked2 := b.Pick(target, nil, nil, opts) + picked2 := b.Pick(target, nil, opts) require.NotNil(t, picked2) require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) } @@ -435,7 +435,7 @@ func TestNewExternalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) opts := buildOpts("a") - picked1 := b.Pick(target, nil, nil, opts) + picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) externalTablet2 := &discovery.TabletHealth{ @@ -460,7 +460,7 @@ func TestNewExternalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - picked2 := b.Pick(target, nil, nil, opts) + picked2 := b.Pick(target, nil, opts) require.NotNil(t, picked2) require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) } @@ -509,12 +509,12 @@ func TestPickNoSessionHash(t *testing.T) { time.Sleep(10 * time.Millisecond) // Test with nil opts - result := b.Pick(target, nil, nil, nil) + result := b.Pick(target, nil, nil) require.Nil(t, result) // Test with opts but nil session uuid optsNoHash := &PickOpts{SessionUUID: nil} - result = b.Pick(target, nil, nil, optsNoHash) + result = b.Pick(target, nil, optsNoHash) require.Nil(t, result) } @@ -572,17 +572,17 @@ func TestPickInvalidTablets(t *testing.T) { // Get a tablet regularly opts := buildOpts("a") - tablet := b.Pick(target, nil, nil, opts) + tablet := b.Pick(target, nil, opts) require.NotNil(t, tablet) // Mark returned tablet as invalid, should return other tablet - invalidTablets := map[string]bool{topoproto.TabletAliasString(tablet.Tablet.Alias): true} - tablet2 := b.Pick(target, nil, invalidTablets, opts) + opts.InvalidTablets = map[string]bool{topoproto.TabletAliasString(tablet.Tablet.Alias): true} + tablet2 := b.Pick(target, nil, opts) require.NotEqual(t, tablet, tablet2) // Mark both as invalid, should return nil - invalidTablets[topoproto.TabletAliasString(tablet2.Tablet.Alias)] = true - tablet3 := b.Pick(target, nil, invalidTablets, opts) + opts.InvalidTablets[topoproto.TabletAliasString(tablet2.Tablet.Alias)] = true + tablet3 := b.Pick(target, nil, opts) require.Nil(t, tablet3) } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index d9e3b3e92e0..3a3fdcfbfc4 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -411,7 +411,7 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablet }) } - tablet = gw.balancer.Pick(target, tablets, nil, nil) + tablet = gw.balancer.Pick(target, tablets, nil) } if tablet != nil { From a18bab7b41ca37aada445693ef79487080133510 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 25 Aug 2025 17:43:41 -0400 Subject: [PATCH 17/67] Make session uuid not pointer Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 4 ++-- go/vt/vtgate/balancer/balancer_test.go | 8 ++++---- go/vt/vtgate/balancer/session.go | 10 +++------- go/vt/vtgate/balancer/session_test.go | 9 ++------- 4 files changed, 11 insertions(+), 20 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index d9d1e144d2d..ff49c69add6 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -101,8 +101,8 @@ type PickOpts struct { // InvalidTablets is a set of tablets that should not be picked. InvalidTablets map[string]bool - // SessionUUID is the hash of the current session UUID. - SessionUUID *string + // SessionUUID is the the current session UUID. + SessionUUID string } func NewTabletBalancer(localCell string, vtGateCells []string) TabletBalancer { diff --git a/go/vt/vtgate/balancer/balancer_test.go b/go/vt/vtgate/balancer/balancer_test.go index b8e5fe58183..78b6e708d50 100644 --- a/go/vt/vtgate/balancer/balancer_test.go +++ b/go/vt/vtgate/balancer/balancer_test.go @@ -298,7 +298,7 @@ func TestBalancedPick(t *testing.T) { b := NewTabletBalancer(localCell, vtGateCells).(*tabletBalancer) for i := 0; i < N/len(vtGateCells); i++ { - th := b.Pick(target, tablets, nil, nil) + th := b.Pick(target, tablets, nil) if i == 0 { t.Logf("Target Flows %v, Balancer: %s\n", expectedPerCell, b.print()) } @@ -336,7 +336,7 @@ func TestTopologyChanged(t *testing.T) { tablets = tablets[0:2] for i := 0; i < N; i++ { - th := b.Pick(target, tablets, nil, nil) + th := b.Pick(target, tablets, nil) allocation, totalAllocation := b.getAllocation(target, tablets) assert.Equalf(t, ALLOCATION/2, totalAllocation, "totalAllocation mismatch %s", b.print()) @@ -346,7 +346,7 @@ func TestTopologyChanged(t *testing.T) { // Run again with the full topology. Now traffic should go to cell b for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil, nil) + th := b.Pick(target, allTablets, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -359,7 +359,7 @@ func TestTopologyChanged(t *testing.T) { newTablet := createTestTablet("b") allTablets[2] = newTablet for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil, nil) + th := b.Pick(target, allTablets, nil) allocation, totalAllocation := b.getAllocation(target, allTablets) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 888d4959e0b..4a7e8e19c70 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -93,25 +93,21 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { - if opts == nil || opts.SessionUUID == nil { - // No session hash. Returning nil here will allow the gateway to select a random - // tablet instead. + if opts == nil { return nil } - sessionUUID := *opts.SessionUUID - b.mu.RLock() defer b.mu.RUnlock() // Try to find a tablet in the local cell first - tablet := getFromRing(b.localRings, target, opts.InvalidTablets, sessionUUID) + tablet := getFromRing(b.localRings, target, opts.InvalidTablets, opts.SessionUUID) if tablet != nil { return tablet } // If we didn't find a tablet in the local cell, try external cells - tablet = getFromRing(b.externalRings, target, opts.InvalidTablets, sessionUUID) + tablet = getFromRing(b.externalRings, target, opts.InvalidTablets, opts.SessionUUID) return tablet } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 980c05e508d..051de57d1d9 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -475,7 +475,7 @@ func TestDebugHandler(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } -func TestPickNoSessionHash(t *testing.T) { +func TestPickNoOpts(t *testing.T) { b, hcChan := newSessionBalancer(t) target := &querypb.Target{ @@ -511,11 +511,6 @@ func TestPickNoSessionHash(t *testing.T) { // Test with nil opts result := b.Pick(target, nil, nil) require.Nil(t, result) - - // Test with opts but nil session uuid - optsNoHash := &PickOpts{SessionUUID: nil} - result = b.Pick(target, nil, optsNoHash) - require.Nil(t, result) } func TestPickInvalidTablets(t *testing.T) { @@ -587,5 +582,5 @@ func TestPickInvalidTablets(t *testing.T) { } func buildOpts(uuid string) *PickOpts { - return &PickOpts{SessionUUID: &uuid} + return &PickOpts{SessionUUID: uuid} } From f9956e2af85ae18f9ff63f20adcbb8f084263876 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 11:05:09 -0400 Subject: [PATCH 18/67] Remove unused import Signed-off-by: Mohamed Hamza --- go/vt/vtgate/plugin_mysql_server.go | 1 - 1 file changed, 1 deletion(-) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 09d503c51fa..c50e766dae0 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -29,7 +29,6 @@ import ( "syscall" "time" - "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/spf13/pflag" From d60c5931828d0570a40d0604a121f7b51b0c8ceb Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 12:42:02 -0400 Subject: [PATCH 19/67] Remove old tablets when a tablet's target changes Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 52 +++++++++++-- go/vt/vtgate/balancer/session_test.go | 108 ++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 8 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 4a7e8e19c70..3c589bf9bbc 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -25,11 +25,12 @@ import ( "strings" "sync" - "github.com/DataDog/appsec-internal-go/log" "vitess.io/vitess/go/vt/discovery" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" + "vitess.io/vitess/go/vt/topo/topoproto" ) // tabletTypesToWatch are the tablet types that will be included in the hash rings. @@ -54,6 +55,9 @@ type SessionBalancer struct { // externalRings are the hash rings created for each target. It contains only tablets // external to localCell. externalRings map[discovery.KeyspaceShardTabletType]*hashRing + + // tabletMap is a map of all the tablets by alias currently in any of the hash rings. + tabletMap map[string]*discovery.TabletHealth } // NewSessionBalancer creates a new session balancer. @@ -63,6 +67,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop hc: hc, localRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), externalRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), + tabletMap: make(map[string]*discovery.TabletHealth), } // Set up health check subscription @@ -150,18 +155,49 @@ func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { b.mu.Lock() defer b.mu.Unlock() - var ring *hashRing - if tablet.Target.Cell == b.localCell { - ring = getOrCreateRing(b.localRings, tablet) - } else { - ring = getOrCreateRing(b.externalRings, tablet) + // Remove this tablet from other ring in case the target has changed. This can happen in + // a reparent for example, where the same tablet of target REPLICA now has a target of + // PRIMARY, and vice versa. + oldTablet, exists := b.tabletMap[topoproto.TabletAliasString(tablet.Tablet.Alias)] + if exists { + oldTarget := discovery.KeyFromTarget(oldTablet.Target) + newTarget := discovery.KeyFromTarget(tablet.Target) + + if oldTarget != newTarget { + b.removeFromRing(oldTablet) + } } if tablet.Serving { - ring.add(tablet) + b.addToRing(tablet) } else { - ring.remove(tablet) + b.removeFromRing(tablet) + } +} + +// addToRing adds a tablet to the appropriate (local or external) ring. +func (b *SessionBalancer) addToRing(tablet *discovery.TabletHealth) { + ring := b.getRing(tablet) + + ring.add(tablet) + b.tabletMap[topoproto.TabletAliasString(tablet.Tablet.Alias)] = tablet +} + +// removeFromRing removes a tablet from the appropriate (local or external) ring. +func (b *SessionBalancer) removeFromRing(tablet *discovery.TabletHealth) { + ring := b.getRing(tablet) + + ring.remove(tablet) + delete(b.tabletMap, topoproto.TabletAliasString(tablet.Tablet.Alias)) +} + +// getRing gets the appropriate (local or external) ring for the tablet. +func (b *SessionBalancer) getRing(tablet *discovery.TabletHealth) *hashRing { + if tablet.Target.Cell == b.localCell { + return getOrCreateRing(b.localRings, tablet) } + + return getOrCreateRing(b.externalRings, tablet) } // getOrCreateRing gets or creates a new ring for the given tablet. diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 051de57d1d9..0e1dbdd4297 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -581,6 +581,114 @@ func TestPickInvalidTablets(t *testing.T) { require.Nil(t, tablet3) } +func TestPickReparent(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + // Create a primary and a replica tablet + localTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + Cell: "local", + }, + Serving: true, + } + + externalTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + hcChan <- localTablet + hcChan <- externalTablet + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + // Get a tablet regularly + opts := buildOpts("a") + tablet := b.Pick(target, nil, opts) + require.NotNil(t, tablet) + require.Equal(t, uint32(101), tablet.Tablet.Alias.Uid) + + // Now perform a reparent and change the tablet types + localTablet = &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + externalTablet = &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + Cell: "external", + }, + Serving: true, + } + + hcChan <- localTablet + hcChan <- externalTablet + + // Give a moment for the worker to process the tablets + time.Sleep(10 * time.Millisecond) + + // Should pick the old primary/new replica instead + tablet = b.Pick(target, nil, opts) + require.NotNil(t, tablet) + require.Equal(t, uint32(100), tablet.Tablet.Alias.Uid) +} + func buildOpts(uuid string) *PickOpts { return &PickOpts{SessionUUID: uuid} } From 5833a54510eaa4402d1cf5e26c4659e19ba5423a Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 12:50:22 -0400 Subject: [PATCH 20/67] Fetch initial tablet state after subscribing to health check Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 3c589bf9bbc..86845140314 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -26,7 +26,6 @@ import ( "sync" "vitess.io/vitess/go/vt/discovery" - "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" @@ -71,15 +70,14 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop } // Set up health check subscription - go b.watchHealthCheck(ctx, topoServer) + hcChan := b.hc.Subscribe("SessionBalancer") // Build initial hash rings // Find all the targets we're watching targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, tabletTypesToWatch) if err != nil { - log.Errorf("session balancer: failed to find all targets and keyspaces: %q", err) - return nil, err + return nil, fmt.Errorf("session balancer: failed to find all targets and keyspaces: %w", err) } // Add each tablet to the hash ring @@ -90,6 +88,9 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop } } + // Start watcher to keep track of tablet health + go b.watchHealthCheck(ctx, topoServer, hcChan) + return b, nil } @@ -130,9 +131,8 @@ func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { } // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. -func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvtopo.Server) { +func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvtopo.Server, hcChan chan *discovery.TabletHealth) { // Start watching health check channel for future tablet health changes - hcChan := b.hc.Subscribe("SessionBalancer") for { select { case <-ctx.Done(): From 5bf666dbe076a785b6c263d999a103c1a9ae522a Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 16:15:41 -0400 Subject: [PATCH 21/67] Revert "Remove old tablets when a tablet's target changes" This reverts commit bd6a1b792d69b238df90b5bf979e0911087ada16. Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 51 ++---------- go/vt/vtgate/balancer/session_test.go | 108 -------------------------- 2 files changed, 8 insertions(+), 151 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 86845140314..94777d49634 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -25,11 +25,11 @@ import ( "strings" "sync" + "github.com/DataDog/appsec-internal-go/log" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" - "vitess.io/vitess/go/vt/topo/topoproto" ) // tabletTypesToWatch are the tablet types that will be included in the hash rings. @@ -54,9 +54,6 @@ type SessionBalancer struct { // externalRings are the hash rings created for each target. It contains only tablets // external to localCell. externalRings map[discovery.KeyspaceShardTabletType]*hashRing - - // tabletMap is a map of all the tablets by alias currently in any of the hash rings. - tabletMap map[string]*discovery.TabletHealth } // NewSessionBalancer creates a new session balancer. @@ -66,7 +63,6 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop hc: hc, localRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), externalRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), - tabletMap: make(map[string]*discovery.TabletHealth), } // Set up health check subscription @@ -155,49 +151,18 @@ func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { b.mu.Lock() defer b.mu.Unlock() - // Remove this tablet from other ring in case the target has changed. This can happen in - // a reparent for example, where the same tablet of target REPLICA now has a target of - // PRIMARY, and vice versa. - oldTablet, exists := b.tabletMap[topoproto.TabletAliasString(tablet.Tablet.Alias)] - if exists { - oldTarget := discovery.KeyFromTarget(oldTablet.Target) - newTarget := discovery.KeyFromTarget(tablet.Target) - - if oldTarget != newTarget { - b.removeFromRing(oldTablet) - } + var ring *hashRing + if tablet.Target.Cell == b.localCell { + ring = getOrCreateRing(b.localRings, tablet) + } else { + ring = getOrCreateRing(b.externalRings, tablet) } if tablet.Serving { - b.addToRing(tablet) + ring.add(tablet) } else { - b.removeFromRing(tablet) - } -} - -// addToRing adds a tablet to the appropriate (local or external) ring. -func (b *SessionBalancer) addToRing(tablet *discovery.TabletHealth) { - ring := b.getRing(tablet) - - ring.add(tablet) - b.tabletMap[topoproto.TabletAliasString(tablet.Tablet.Alias)] = tablet -} - -// removeFromRing removes a tablet from the appropriate (local or external) ring. -func (b *SessionBalancer) removeFromRing(tablet *discovery.TabletHealth) { - ring := b.getRing(tablet) - - ring.remove(tablet) - delete(b.tabletMap, topoproto.TabletAliasString(tablet.Tablet.Alias)) -} - -// getRing gets the appropriate (local or external) ring for the tablet. -func (b *SessionBalancer) getRing(tablet *discovery.TabletHealth) *hashRing { - if tablet.Target.Cell == b.localCell { - return getOrCreateRing(b.localRings, tablet) + ring.remove(tablet) } - - return getOrCreateRing(b.externalRings, tablet) } // getOrCreateRing gets or creates a new ring for the given tablet. diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 0e1dbdd4297..051de57d1d9 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -581,114 +581,6 @@ func TestPickInvalidTablets(t *testing.T) { require.Nil(t, tablet3) } -func TestPickReparent(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - // Create a primary and a replica tablet - localTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, - Cell: "local", - }, - Serving: true, - } - - externalTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", - }, - Serving: true, - } - - hcChan <- localTablet - hcChan <- externalTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - - target := &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - } - - // Get a tablet regularly - opts := buildOpts("a") - tablet := b.Pick(target, nil, opts) - require.NotNil(t, tablet) - require.Equal(t, uint32(101), tablet.Tablet.Alias.Uid) - - // Now perform a reparent and change the tablet types - localTablet = &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - externalTablet = &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, - Cell: "external", - }, - Serving: true, - } - - hcChan <- localTablet - hcChan <- externalTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - - // Should pick the old primary/new replica instead - tablet = b.Pick(target, nil, opts) - require.NotNil(t, tablet) - require.Equal(t, uint32(100), tablet.Tablet.Alias.Uid) -} - func buildOpts(uuid string) *PickOpts { return &PickOpts{SessionUUID: uuid} } From d8723b3acf1a30b2458430834ce1484298913aa1 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 16:12:21 -0400 Subject: [PATCH 22/67] Remove primary tablets from rings Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 14 ++-- go/vt/vtgate/balancer/session_test.go | 100 ++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 5 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 94777d49634..5df64268ac6 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -25,7 +25,6 @@ import ( "strings" "sync" - "github.com/DataDog/appsec-internal-go/log" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" @@ -33,7 +32,7 @@ import ( ) // tabletTypesToWatch are the tablet types that will be included in the hash rings. -var tabletTypesToWatch = []topodata.TabletType{topodata.TabletType_PRIMARY, topodata.TabletType_REPLICA, topodata.TabletType_BATCH} +var tabletTypesToWatch = map[topodata.TabletType]struct{}{topodata.TabletType_REPLICA: {}, topodata.TabletType_RDONLY: {}} // SessionBalancer implements the TabletBalancer interface. For a given session, // it will return the same tablet for its duration, with preference to tablets in @@ -71,7 +70,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // Build initial hash rings // Find all the targets we're watching - targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, tabletTypesToWatch) + targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, slices.Collect(maps.Keys(tabletTypesToWatch))) if err != nil { return nil, fmt.Errorf("session balancer: failed to find all targets and keyspaces: %w", err) } @@ -85,7 +84,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop } // Start watcher to keep track of tablet health - go b.watchHealthCheck(ctx, topoServer, hcChan) + go b.watchHealthCheck(ctx, hcChan) return b, nil } @@ -127,7 +126,7 @@ func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { } // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. -func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvtopo.Server, hcChan chan *discovery.TabletHealth) { +func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { // Start watching health check channel for future tablet health changes for { select { @@ -139,6 +138,11 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, topoServer srvto return } + // Ignore tablets we aren't supposed to watch + if _, ok := tabletTypesToWatch[tablet.Target.TabletType]; !ok { + return + } + b.onTabletHealthChange(tablet) } } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 051de57d1d9..611a646f8aa 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -581,6 +581,106 @@ func TestPickInvalidTablets(t *testing.T) { require.Nil(t, tablet3) } +func TestTabletTypesToWatch(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + // Valid tablet type + localTablet := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + // Valid tablet type + localTablet2 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_RDONLY, + Cell: "local", + }, + Serving: true, + } + + // Invalid tablet type + localTablet3 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + Cell: "local", + }, + Serving: true, + } + + // Invalid tablet type + localTablet4 := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_BACKUP, + Cell: "local", + }, + Serving: true, + } + + hcChan <- localTablet + hcChan <- localTablet2 + hcChan <- localTablet3 + hcChan <- localTablet4 + + // Give a moment for the worker to process the tablets + time.Sleep(100 * time.Millisecond) + + b.mu.RLock() + defer b.mu.RUnlock() + + require.Len(t, b.localRings, 2) + require.Len(t, b.externalRings, 0) + + for _, ring := range b.localRings { + for _, tablet := range ring.nodeMap { + require.Contains(t, tabletTypesToWatch, tablet.Target.TabletType) + } + } +} + func buildOpts(uuid string) *PickOpts { return &PickOpts{SessionUUID: uuid} } From 479cde47806f27e4d0297dabaa56d0ff232876f0 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 26 Aug 2025 16:21:42 -0400 Subject: [PATCH 23/67] undo formatting Signed-off-by: Mohamed Hamza --- go/vt/vtgate/plugin_mysql_server.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index c50e766dae0..5e6c2c689c8 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -181,8 +181,7 @@ var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions func startSpanTestable(ctx context.Context, query, label string, newSpan func(context.Context, string) (trace.Span, context.Context), - newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error), -) (trace.Span, context.Context, error) { + newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) { _, comments := sqlparser.SplitMarginComments(query) match := r.FindStringSubmatch(comments.Leading) span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) From 475408c5ba25c753b3714e65c538e9b0ad20f90f Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 27 Aug 2025 21:28:57 -0400 Subject: [PATCH 24/67] Pass session uuid to `withRetry` Update the signature of the wrapper func to accept a new `WrapOpts` struct, which currently contains `ExecuteOptions`, which now contains the session UUID so that it can be passed into `Pick` for the balancer. Signed-off-by: Mohamed Hamza --- go/flags/endtoend/vtgate.txt | 1 + go/vt/proto/query/query.pb.go | 18 +++-- go/vt/proto/query/query_vtproto.pb.go | 46 ++++++++++++ go/vt/vtgate/plugin_mysql_server.go | 4 +- go/vt/vtgate/tabletgateway.go | 48 ++++++++++--- go/vt/vtgate/vtgate.go | 6 +- go/vt/vttablet/queryservice/wrapped.go | 99 ++++++++++++++++++-------- proto/query.proto | 3 + web/vtadmin/src/proto/vtadmin.d.ts | 6 ++ web/vtadmin/src/proto/vtadmin.js | 23 ++++++ 10 files changed, 206 insertions(+), 48 deletions(-) diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index cb7c82eee04..9432be8e452 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -30,6 +30,7 @@ Flags: --alsologtostderr log to standard error as well as files --balancer-keyspaces strings When in balanced mode, a comma-separated list of keyspaces for which to use the balancer (optional) --balancer-vtgate-cells strings When in balanced mode, a comma-separated list of cells that contain vtgates (required) + --balancer-type When in balanced mode, selects the type of balancer to use. "balanced" balances connections evenly, "session" pins a connection to a given tablet for its duration. (default: "balanced") (optional) --bind-address string Bind address for the server. If empty, the server will listen on all available unicast and anycast IP addresses of the local system. --buffer-drain-concurrency int Maximum number of requests retried simultaneously. More concurrency will increase the load on the PRIMARY vttablet when draining the buffer. (default 1) --buffer-keyspace-shards string If not empty, limit buffering to these entries (comma separated). Entry format: keyspace or keyspace/shard. Requires --enable_buffer=true. diff --git a/go/vt/proto/query/query.pb.go b/go/vt/proto/query/query.pb.go index 73440c8cbcd..a1555f5aeb2 100644 --- a/go/vt/proto/query/query.pb.go +++ b/go/vt/proto/query/query.pb.go @@ -1402,8 +1402,10 @@ type ExecuteOptions struct { FetchLastInsertId bool `protobuf:"varint,18,opt,name=fetch_last_insert_id,json=fetchLastInsertId,proto3" json:"fetch_last_insert_id,omitempty"` // in_dml_execution indicates that the query is being executed as part of a DML execution. InDmlExecution bool `protobuf:"varint,19,opt,name=in_dml_execution,json=inDmlExecution,proto3" json:"in_dml_execution,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // SessionUUID is the UUID of the current session. + SessionUUID string `protobuf:"bytes,20,opt,name=SessionUUID,proto3" json:"SessionUUID,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExecuteOptions) Reset() { @@ -1550,6 +1552,13 @@ func (x *ExecuteOptions) GetInDmlExecution() bool { return false } +func (x *ExecuteOptions) GetSessionUUID() string { + if x != nil { + return x.SessionUUID + } + return "" +} + type isExecuteOptions_Timeout interface { isExecuteOptions_Timeout() } @@ -5817,7 +5826,7 @@ const file_query_proto_rawDesc = "" + "\x0ebind_variables\x18\x02 \x03(\v2$.query.BoundQuery.BindVariablesEntryR\rbindVariables\x1aU\n" + "\x12BindVariablesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\xb5\f\n" + + "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\xd7\f\n" + "\x0eExecuteOptions\x12M\n" + "\x0fincluded_fields\x18\x04 \x01(\x0e2$.query.ExecuteOptions.IncludedFieldsR\x0eincludedFields\x12*\n" + "\x11client_found_rows\x18\x05 \x01(\bR\x0fclientFoundRows\x12:\n" + @@ -5834,7 +5843,8 @@ const file_query_proto_rawDesc = "" + "\bpriority\x18\x10 \x01(\tR\bpriority\x125\n" + "\x15authoritative_timeout\x18\x11 \x01(\x03H\x00R\x14authoritativeTimeout\x12/\n" + "\x14fetch_last_insert_id\x18\x12 \x01(\bR\x11fetchLastInsertId\x12(\n" + - "\x10in_dml_execution\x18\x13 \x01(\bR\x0einDmlExecution\";\n" + + "\x10in_dml_execution\x18\x13 \x01(\bR\x0einDmlExecution\x12 \n" + + "\vSessionUUID\x18\x14 \x01(\tR\vSessionUUID\";\n" + "\x0eIncludedFields\x12\x11\n" + "\rTYPE_AND_NAME\x10\x00\x12\r\n" + "\tTYPE_ONLY\x10\x01\x12\a\n" + diff --git a/go/vt/proto/query/query_vtproto.pb.go b/go/vt/proto/query/query_vtproto.pb.go index b2a752a3ed4..9d372f22b4e 100644 --- a/go/vt/proto/query/query_vtproto.pb.go +++ b/go/vt/proto/query/query_vtproto.pb.go @@ -178,6 +178,7 @@ func (m *ExecuteOptions) CloneVT() *ExecuteOptions { r.Priority = m.Priority r.FetchLastInsertId = m.FetchLastInsertId r.InDmlExecution = m.InDmlExecution + r.SessionUUID = m.SessionUUID if rhs := m.TransactionAccessMode; rhs != nil { tmpContainer := make([]ExecuteOptions_TransactionAccessMode, len(rhs)) copy(tmpContainer, rhs) @@ -1896,6 +1897,15 @@ func (m *ExecuteOptions) MarshalToSizedBufferVT(dAtA []byte) (int, error) { } i -= size } + if len(m.SessionUUID) > 0 { + i -= len(m.SessionUUID) + copy(dAtA[i:], m.SessionUUID) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SessionUUID))) + i-- + dAtA[i] = 0x1 + i-- + dAtA[i] = 0xa2 + } if m.InDmlExecution { i-- if m.InDmlExecution { @@ -6212,6 +6222,10 @@ func (m *ExecuteOptions) SizeVT() (n int) { if m.InDmlExecution { n += 3 } + l = len(m.SessionUUID) + if l > 0 { + n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) + } n += len(m.unknownFields) return n } @@ -9000,6 +9014,38 @@ func (m *ExecuteOptions) UnmarshalVT(dAtA []byte) error { } } m.InDmlExecution = bool(v != 0) + case 20: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field SessionUUID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.SessionUUID = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 5e6c2c689c8..bbcea3f4125 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -181,7 +181,8 @@ var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions func startSpanTestable(ctx context.Context, query, label string, newSpan func(context.Context, string) (trace.Span, context.Context), - newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) { + newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error), +) (trace.Span, context.Context, error) { _, comments := sqlparser.SplitMarginComments(query) match := r.FindStringSubmatch(comments.Leading) span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) @@ -521,6 +522,7 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload), + SessionUUID: u.String(), // The collation field of ExecuteOption is set right before an execution. }, diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 3a3fdcfbfc4..30d13dee29d 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -61,6 +61,7 @@ var ( balancerEnabled bool balancerVtgateCells []string balancerKeyspaces []string + balancerType string logCollations = logutil.NewThrottledLogger("CollationInconsistent", 1*time.Minute) ) @@ -73,6 +74,7 @@ func init() { fs.BoolVar(&balancerEnabled, "enable-balancer", false, "Enable the tablet balancer to evenly spread query load for a given tablet type") fs.StringSliceVar(&balancerVtgateCells, "balancer-vtgate-cells", []string{}, "When in balanced mode, a comma-separated list of cells that contain vtgates (required)") fs.StringSliceVar(&balancerKeyspaces, "balancer-keyspaces", []string{}, "When in balanced mode, a comma-separated list of keyspaces for which to use the balancer (optional)") + fs.StringVar(&balancerType, "balancer-type", "balanced", `When in balanced mode, selects the type of balancer to use. "balanced" balances connections evenly, "session" pins a connection to a given tablet for its duration. (default: "balanced") (optional)`) }) } @@ -109,7 +111,7 @@ func createHealthCheck(ctx context.Context, retryDelay, timeout time.Duration, t } // NewTabletGateway creates and returns a new TabletGateway -func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, localCell string) *TabletGateway { +func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, localCell string) (*TabletGateway, error) { // hack to accommodate various users of gateway + tests if hc == nil { var topoServer *topo.Server @@ -130,11 +132,16 @@ func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtop statusAggregators: make(map[string]*TabletStatusAggregator), } gw.setupBuffering(ctx) + if balancerEnabled { - gw.setupBalancer(ctx) + err := gw.setupBalancer(ctx) + if err != nil { + return nil, fmt.Errorf("tablet gateway: failed to set up balancer: %w", err) + } } + gw.QueryService = queryservice.Wrap(nil, gw.withRetry) - return gw + return gw, nil } func (gw *TabletGateway) setupBuffering(ctx context.Context) { @@ -166,11 +173,29 @@ func (gw *TabletGateway) setupBuffering(ctx context.Context) { }(bufferCtx, ksChan, gw.buffer) } -func (gw *TabletGateway) setupBalancer(ctx context.Context) { +func (gw *TabletGateway) setupBalancer(ctx context.Context) error { if len(balancerVtgateCells) == 0 { log.Exitf("balancer-vtgate-cells is required for balanced mode") } - gw.balancer = balancer.NewTabletBalancer(gw.localCell, balancerVtgateCells) + + switch balancerType { + case "session": + balancer, err := balancer.NewSessionBalancer(ctx, gw.localCell, gw.srvTopoServer, gw.hc) + if err != nil { + return fmt.Errorf("failed to create session balancer: %w", err) + } + + gw.balancer = balancer + default: + if balancerType != "balanced" { + log.Warningf("Unrecognized balancer type %q, using default \"balanced\"", balancerType) + } + + balancerType = "balanced" + gw.balancer = balancer.NewTabletBalancer(gw.localCell, balancerVtgateCells) + } + + return nil } // QueryServiceByAlias satisfies the Gateway interface @@ -279,10 +304,10 @@ func (gw *TabletGateway) DebugBalancerHandler(w http.ResponseWriter, r *http.Req // withRetry also adds shard information to errors returned from the inner QueryService, so // withShardError should not be combined with withRetry. func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, _ queryservice.QueryService, - _ string, inTransaction bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), + _ string, opts *queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), ) error { // for transactions, we connect to a specific tablet instead of letting gateway choose one - if inTransaction && target.TabletType != topodatapb.TabletType_PRIMARY { + if opts.InTransaction && target.TabletType != topodatapb.TabletType_PRIMARY { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "tabletGateway's query service can only be used for non-transactional queries on replicas") } var tabletLastUsed *topodatapb.Tablet @@ -308,7 +333,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, // Note: We only buffer once and only "!inTransaction" queries i.e. // a) no transaction is necessary (e.g. critical reads) or // b) no transaction was created yet. - if gw.buffer != nil && !bufferedOnce && !inTransaction && target.TabletType == topodatapb.TabletType_PRIMARY { + if gw.buffer != nil && !bufferedOnce && !opts.InTransaction && target.TabletType == topodatapb.TabletType_PRIMARY { // The next call blocks if we should buffer during a failover. retryDone, bufferErr := gw.buffer.WaitForFailoverEnd(ctx, target.Keyspace, target.Shard, gw.kev, err) @@ -359,7 +384,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, break } - th := gw.getBalancerTablet(target, invalidTablets, tablets) + th := gw.getBalancerTablet(target, tablets, invalidTablets, opts) if th == nil { // do not override error from last attempt. if err == nil { @@ -393,7 +418,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, // getBalancerTablet selects a tablet for the given query target, using the configured balancer if enabled. Otherwise, it will // select a random tablet, with preference to the local cell. -func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablets map[string]bool, tablets []*discovery.TabletHealth) *discovery.TabletHealth { +func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts *queryservice.WrapOpts) *discovery.TabletHealth { var tablet *discovery.TabletHealth useBalancer := balancerEnabled @@ -411,7 +436,8 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, invalidTablet }) } - tablet = gw.balancer.Pick(target, tablets, nil) + opts := &balancer.PickOpts{SessionUUID: opts.Options.SessionUUID, InvalidTablets: invalidTablets} + tablet = gw.balancer.Pick(target, tablets, opts) } if tablet != nil { diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 02b1c0a1925..7ab4b16fc96 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -320,7 +320,11 @@ func Init( // Start with the gateway. If we can't reach the topology service, // we can't go on much further, so we log.Fatal out. // TabletGateway can create it's own healthcheck - gw := NewTabletGateway(ctx, hc, serv, cell) + gw, err := NewTabletGateway(ctx, hc, serv, cell) + if err != nil { + log.Fatalf("vtgate: failed to initalize tablet gateway: %w", err) + } + gw.RegisterStats() if err := gw.WaitForTablets(ctx, tabletTypesToWait); err != nil { log.Fatalf("tabletGateway.WaitForTablets failed: %v", err) diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index 2e31c66dba2..18c764cf91e 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -34,7 +34,14 @@ var _ QueryService = &wrappedService{} // The inner function returns err and canRetry. // If canRetry is true, the error is specific to the current vttablet and can be retried elsewhere. // The flag will be false if there was no error. -type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, QueryService) (canRetry bool, err error)) error +type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QueryService, name string, opts *WrapOpts, inner func(context.Context, *querypb.Target, QueryService) (canRetry bool, err error)) error + +// WrapOpts is the options passed to [WrapperFunc]. +type WrapOpts struct { + inTransaction bool + + options querypb.ExecuteOptions +} // Wrap returns a wrapped version of the original QueryService implementation. // This lets you avoid repeating boiler-plate code by consolidating it in the @@ -110,7 +117,8 @@ type wrappedService struct { } func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state TransactionState, err error) { - err = ws.wrapper(ctx, target, ws.impl, "Begin", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "Begin", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.Begin(ctx, target, options) return canRetry(ctx, innerErr), innerErr @@ -120,7 +128,8 @@ func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, opt func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - err := ws.wrapper(ctx, target, ws.impl, "Commit", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err := ws.wrapper(ctx, target, ws.impl, "Commit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Commit(ctx, target, transactionID) return canRetry(ctx, innerErr), innerErr @@ -133,7 +142,8 @@ func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, tr func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - err := ws.wrapper(ctx, target, ws.impl, "Rollback", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err := ws.wrapper(ctx, target, ws.impl, "Rollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Rollback(ctx, target, transactionID) return canRetry(ctx, innerErr), innerErr @@ -145,7 +155,8 @@ func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) error { - err := ws.wrapper(ctx, target, ws.impl, "Prepare", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err := ws.wrapper(ctx, target, ws.impl, "Prepare", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.Prepare(ctx, target, transactionID, dtid) return canRetry(ctx, innerErr), innerErr }) @@ -153,7 +164,8 @@ func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, t } func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Target, dtid string) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "CommitPrepared", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "CommitPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CommitPrepared(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr }) @@ -161,7 +173,8 @@ func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Ta } func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb.Target, dtid string, originalID int64) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "RollbackPrepared", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "RollbackPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.RollbackPrepared(ctx, target, dtid, originalID) return canRetry(ctx, innerErr), innerErr }) @@ -169,7 +182,8 @@ func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb. } func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb.Target, dtid string, participants []*querypb.Target) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "CreateTransaction", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "CreateTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CreateTransaction(ctx, target, dtid, participants) return canRetry(ctx, innerErr), innerErr }) @@ -177,7 +191,8 @@ func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb } func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (state querypb.StartCommitState, err error) { - err = ws.wrapper(ctx, target, ws.impl, "StartCommit", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "StartCommit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.StartCommit(ctx, target, transactionID, dtid) return canRetry(ctx, innerErr), innerErr @@ -186,7 +201,8 @@ func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Target, dtid string, transactionID int64) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "SetRollback", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "SetRollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.SetRollback(ctx, target, dtid, transactionID) return canRetry(ctx, innerErr), innerErr }) @@ -194,7 +210,8 @@ func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *querypb.Target, dtid string) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "ConcludeTransaction", true, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: true} + err = ws.wrapper(ctx, target, ws.impl, "ConcludeTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.ConcludeTransaction(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr }) @@ -202,7 +219,8 @@ func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *query } func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.Target, dtid string) (metadata *querypb.TransactionMetadata, err error) { - err = ws.wrapper(ctx, target, ws.impl, "ReadTransaction", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "ReadTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error metadata, innerErr = conn.ReadTransaction(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr @@ -211,7 +229,8 @@ func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.T } func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) (transactions []*querypb.TransactionMetadata, err error) { - err = ws.wrapper(ctx, target, ws.impl, "UnresolvedTransactions", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "UnresolvedTransactions", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error transactions, innerErr = conn.UnresolvedTransactions(ctx, target, abandonAgeSeconds) return canRetry(ctx, innerErr), innerErr @@ -221,7 +240,8 @@ func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *qu func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (qr *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 || reservedID != 0 - err = ws.wrapper(ctx, target, ws.impl, "Execute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err = ws.wrapper(ctx, target, ws.impl, "Execute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error qr, innerErr = conn.Execute(ctx, target, query, bindVars, transactionID, reservedID, options) // You cannot retry if you're in a transaction. @@ -234,7 +254,8 @@ func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, q // StreamExecute implements the QueryService interface func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { streamingStarted := false innerErr := conn.StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, func(qr *sqltypes.Result) error { streamingStarted = true @@ -249,7 +270,8 @@ func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Tar func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state TransactionState, qr *sqltypes.Result, err error) { inDedicatedConn := reservedID != 0 - err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, qr, innerErr = conn.BeginExecute(ctx, target, preQueries, query, bindVars, reservedID, options) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr @@ -260,7 +282,8 @@ func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Targ // BeginStreamExecute implements the QueryService interface func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state TransactionState, err error) { inDedicatedConn := reservedID != 0 - err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr @@ -269,14 +292,16 @@ func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *queryp } func (ws *wrappedService) MessageStream(ctx context.Context, target *querypb.Target, name string, callback func(*sqltypes.Result) error) error { - return ws.wrapper(ctx, target, ws.impl, "MessageStream", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, target, ws.impl, "MessageStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.MessageStream(ctx, target, name, callback) return canRetry(ctx, innerErr), innerErr }) } func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target, name string, ids []*querypb.Value) (count int64, err error) { - err = ws.wrapper(ctx, target, ws.impl, "MessageAck", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "MessageAck", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error count, innerErr = conn.MessageAck(ctx, target, name, ids) return canRetry(ctx, innerErr), innerErr @@ -285,35 +310,40 @@ func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target } func (ws *wrappedService) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error { - return ws.wrapper(ctx, request.Target, ws.impl, "VStream", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, request.Target, ws.impl, "VStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStream(ctx, request, send) return false, innerErr }) } func (ws *wrappedService) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error { - return ws.wrapper(ctx, request.Target, ws.impl, "VStreamRows", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, request.Target, ws.impl, "VStreamRows", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamRows(ctx, request, send) return false, innerErr }) } func (ws *wrappedService) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(response *binlogdatapb.VStreamTablesResponse) error) error { - return ws.wrapper(ctx, request.Target, ws.impl, "VStreamTables", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, request.Target, ws.impl, "VStreamTables", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamTables(ctx, request, send) return false, innerErr }) } func (ws *wrappedService) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error { - return ws.wrapper(ctx, target, ws.impl, "VStreamResults", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, target, ws.impl, "VStreamResults", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamResults(ctx, target, query, send) return false, innerErr }) } func (ws *wrappedService) StreamHealth(ctx context.Context, callback func(*querypb.StreamHealthResponse) error) error { - return ws.wrapper(ctx, nil, ws.impl, "StreamHealth", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, nil, ws.impl, "StreamHealth", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.StreamHealth(ctx, callback) return canRetry(ctx, innerErr), innerErr }) @@ -325,7 +355,8 @@ func (ws *wrappedService) HandlePanic(err *error) { // ReserveBeginExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state ReservedTransactionState, res *sqltypes.Result, err error) { - err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) return canRetry(ctx, err), err @@ -336,7 +367,8 @@ func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *query // ReserveBeginStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { - err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) return canRetry(ctx, innerErr), innerErr @@ -347,7 +379,8 @@ func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target // ReserveExecute implements the QueryService interface func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state ReservedState, res *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 - err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) return canRetry(ctx, err) && !inDedicatedConn, err @@ -359,7 +392,8 @@ func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Ta // ReserveStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedState, err error) { inDedicatedConn := transactionID != 0 - err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr @@ -369,14 +403,16 @@ func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *quer func (ws *wrappedService) Release(ctx context.Context, target *querypb.Target, transactionID, reservedID int64) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - return ws.wrapper(ctx, target, ws.impl, "Release", inDedicatedConn, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: inDedicatedConn} + return ws.wrapper(ctx, target, ws.impl, "Release", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Release. return false, conn.Release(ctx, target, transactionID, reservedID) }) } func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) (err error) { - err = ws.wrapper(ctx, target, ws.impl, "GetSchema", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + err = ws.wrapper(ctx, target, ws.impl, "GetSchema", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.GetSchema(ctx, target, tableType, tableNames, callback) return canRetry(ctx, innerErr), innerErr }) @@ -384,7 +420,8 @@ func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Close(ctx context.Context) error { - return ws.wrapper(ctx, nil, ws.impl, "Close", false, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { + opts := &WrapOpts{inTransaction: false} + return ws.wrapper(ctx, nil, ws.impl, "Close", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Close. return false, conn.Close(ctx) }) diff --git a/proto/query.proto b/proto/query.proto index b4ccb61af59..9b1035cc8ae 100644 --- a/proto/query.proto +++ b/proto/query.proto @@ -376,6 +376,9 @@ message ExecuteOptions { // in_dml_execution indicates that the query is being executed as part of a DML execution. bool in_dml_execution = 19; + + // SessionUUID is the UUID of the current session. + string SessionUUID = 20; } // Field describes a single column returned by a query diff --git a/web/vtadmin/src/proto/vtadmin.d.ts b/web/vtadmin/src/proto/vtadmin.d.ts index 67e8dd462da..b96a190d6d1 100644 --- a/web/vtadmin/src/proto/vtadmin.d.ts +++ b/web/vtadmin/src/proto/vtadmin.d.ts @@ -41676,6 +41676,9 @@ export namespace query { /** ExecuteOptions in_dml_execution */ in_dml_execution?: (boolean|null); + + /** ExecuteOptions SessionUUID */ + SessionUUID?: (string|null); } /** Represents an ExecuteOptions. */ @@ -41732,6 +41735,9 @@ export namespace query { /** ExecuteOptions in_dml_execution. */ public in_dml_execution: boolean; + /** ExecuteOptions SessionUUID. */ + public SessionUUID: string; + /** ExecuteOptions timeout. */ public timeout?: "authoritative_timeout"; diff --git a/web/vtadmin/src/proto/vtadmin.js b/web/vtadmin/src/proto/vtadmin.js index 86d78c1c718..2d8db5081a9 100644 --- a/web/vtadmin/src/proto/vtadmin.js +++ b/web/vtadmin/src/proto/vtadmin.js @@ -99385,6 +99385,7 @@ export const query = $root.query = (() => { * @property {number|Long|null} [authoritative_timeout] ExecuteOptions authoritative_timeout * @property {boolean|null} [fetch_last_insert_id] ExecuteOptions fetch_last_insert_id * @property {boolean|null} [in_dml_execution] ExecuteOptions in_dml_execution + * @property {string|null} [SessionUUID] ExecuteOptions SessionUUID */ /** @@ -99523,6 +99524,14 @@ export const query = $root.query = (() => { */ ExecuteOptions.prototype.in_dml_execution = false; + /** + * ExecuteOptions SessionUUID. + * @member {string} SessionUUID + * @memberof query.ExecuteOptions + * @instance + */ + ExecuteOptions.prototype.SessionUUID = ""; + // OneOf field names bound to virtual getters and setters let $oneOfFields; @@ -99595,6 +99604,8 @@ export const query = $root.query = (() => { writer.uint32(/* id 18, wireType 0 =*/144).bool(message.fetch_last_insert_id); if (message.in_dml_execution != null && Object.hasOwnProperty.call(message, "in_dml_execution")) writer.uint32(/* id 19, wireType 0 =*/152).bool(message.in_dml_execution); + if (message.SessionUUID != null && Object.hasOwnProperty.call(message, "SessionUUID")) + writer.uint32(/* id 20, wireType 2 =*/162).string(message.SessionUUID); return writer; }; @@ -99696,6 +99707,10 @@ export const query = $root.query = (() => { message.in_dml_execution = reader.bool(); break; } + case 20: { + message.SessionUUID = reader.string(); + break; + } default: reader.skipType(tag & 7); break; @@ -99830,6 +99845,9 @@ export const query = $root.query = (() => { if (message.in_dml_execution != null && message.hasOwnProperty("in_dml_execution")) if (typeof message.in_dml_execution !== "boolean") return "in_dml_execution: boolean expected"; + if (message.SessionUUID != null && message.hasOwnProperty("SessionUUID")) + if (!$util.isString(message.SessionUUID)) + return "SessionUUID: string expected"; return null; }; @@ -100046,6 +100064,8 @@ export const query = $root.query = (() => { message.fetch_last_insert_id = Boolean(object.fetch_last_insert_id); if (object.in_dml_execution != null) message.in_dml_execution = Boolean(object.in_dml_execution); + if (object.SessionUUID != null) + message.SessionUUID = String(object.SessionUUID); return message; }; @@ -100082,6 +100102,7 @@ export const query = $root.query = (() => { object.priority = ""; object.fetch_last_insert_id = false; object.in_dml_execution = false; + object.SessionUUID = ""; } if (message.included_fields != null && message.hasOwnProperty("included_fields")) object.included_fields = options.enums === String ? $root.query.ExecuteOptions.IncludedFields[message.included_fields] === undefined ? message.included_fields : $root.query.ExecuteOptions.IncludedFields[message.included_fields] : message.included_fields; @@ -100125,6 +100146,8 @@ export const query = $root.query = (() => { object.fetch_last_insert_id = message.fetch_last_insert_id; if (message.in_dml_execution != null && message.hasOwnProperty("in_dml_execution")) object.in_dml_execution = message.in_dml_execution; + if (message.SessionUUID != null && message.hasOwnProperty("SessionUUID")) + object.SessionUUID = message.SessionUUID; return object; }; From 98a7c8ba6b15762534847e79c8c0971c7d59992c Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 28 Aug 2025 06:32:43 -0400 Subject: [PATCH 25/67] Fix new types Signed-off-by: Mohamed Hamza --- go/vt/vtexplain/vtexplain_vtgate.go | 2 +- go/vt/vtexplain/vtexplain_vttablet.go | 3 +- go/vt/vtgate/tabletgateway.go | 2 +- .../queryservice/fakes/error_query_service.go | 5 +- go/vt/vttablet/queryservice/wrapped.go | 64 +++++++++---------- 5 files changed, 37 insertions(+), 39 deletions(-) diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go index 7c838da4330..28bcaa6ea6d 100644 --- a/go/vt/vtexplain/vtexplain_vtgate.go +++ b/go/vt/vtexplain/vtexplain_vtgate.go @@ -86,7 +86,7 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, ts *topo.Server, v } func (vte *VTExplain) newFakeResolver(ctx context.Context, opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver { - gw := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell) + gw, _ := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell) _ = gw.WaitForTablets(ctx, []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) txMode := vtgatepb.TransactionMode_MULTI diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 7918162192a..7c475f2b12a 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -127,7 +127,7 @@ func (vte *VTExplain) newTablet(ctx context.Context, env *vtenv.Environment, opt tablet.QueryService = queryservice.Wrap( nil, - func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { + func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts *queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { return fmt.Errorf("explainTablet does not implement %s", name) }, ) @@ -408,7 +408,6 @@ func newTabletEnvironment(ddls []sqlparser.DDLStatement, opts *Options, collatio Rows: [][]sqltypes.Value{}, }, "create table if not exists `_vt`.dt_participant(\n dtid varbinary(512),\n\tid bigint,\n\tkeyspace varchar(256),\n\tshard varchar(256),\n primary key(dtid, id)\n\t) engine=InnoDB": { - Fields: []*querypb.Field{{ Type: sqltypes.Uint64, Charset: collations.CollationBinaryID, diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 30d13dee29d..77531555e9b 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -460,7 +460,7 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // withShardError adds shard information to errors returned from the inner QueryService. func (gw *TabletGateway) withShardError(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, - _ string, _ bool, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), + _ string, _ *queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), ) error { _, err := inner(ctx, target, conn) return NewShardError(err, target) diff --git a/go/vt/vttablet/queryservice/fakes/error_query_service.go b/go/vt/vttablet/queryservice/fakes/error_query_service.go index ff0eef9ea59..bc10b6f9e7e 100644 --- a/go/vt/vttablet/queryservice/fakes/error_query_service.go +++ b/go/vt/vttablet/queryservice/fakes/error_query_service.go @@ -17,9 +17,8 @@ limitations under the License. package fakes import ( - "fmt" - "context" + "fmt" "vitess.io/vitess/go/vt/vttablet/queryservice" @@ -29,7 +28,7 @@ import ( // ErrorQueryService is an object that returns an error for all methods. var ErrorQueryService = queryservice.Wrap( nil, - func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, inTransaction bool, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { + func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts *queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { return fmt.Errorf("ErrorQueryService does not implement any method") }, ) diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index 18c764cf91e..9e545dc529a 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -38,9 +38,9 @@ type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QuerySer // WrapOpts is the options passed to [WrapperFunc]. type WrapOpts struct { - inTransaction bool + InTransaction bool - options querypb.ExecuteOptions + Options querypb.ExecuteOptions } // Wrap returns a wrapped version of the original QueryService implementation. @@ -117,7 +117,7 @@ type wrappedService struct { } func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state TransactionState, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "Begin", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.Begin(ctx, target, options) @@ -128,7 +128,7 @@ func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, opt func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Commit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Commit(ctx, target, transactionID) @@ -142,7 +142,7 @@ func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, tr func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Rollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Rollback(ctx, target, transactionID) @@ -155,7 +155,7 @@ func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) error { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Prepare", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.Prepare(ctx, target, transactionID, dtid) return canRetry(ctx, innerErr), innerErr @@ -164,7 +164,7 @@ func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, t } func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Target, dtid string) (err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "CommitPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CommitPrepared(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr @@ -173,7 +173,7 @@ func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Ta } func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb.Target, dtid string, originalID int64) (err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "RollbackPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.RollbackPrepared(ctx, target, dtid, originalID) return canRetry(ctx, innerErr), innerErr @@ -182,7 +182,7 @@ func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb. } func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb.Target, dtid string, participants []*querypb.Target) (err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "CreateTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CreateTransaction(ctx, target, dtid, participants) return canRetry(ctx, innerErr), innerErr @@ -191,7 +191,7 @@ func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb } func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (state querypb.StartCommitState, err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "StartCommit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.StartCommit(ctx, target, transactionID, dtid) @@ -201,7 +201,7 @@ func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Target, dtid string, transactionID int64) (err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "SetRollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.SetRollback(ctx, target, dtid, transactionID) return canRetry(ctx, innerErr), innerErr @@ -210,7 +210,7 @@ func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *querypb.Target, dtid string) (err error) { - opts := &WrapOpts{inTransaction: true} + opts := &WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "ConcludeTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.ConcludeTransaction(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr @@ -219,7 +219,7 @@ func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *query } func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.Target, dtid string) (metadata *querypb.TransactionMetadata, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReadTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error metadata, innerErr = conn.ReadTransaction(ctx, target, dtid) @@ -229,7 +229,7 @@ func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.T } func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) (transactions []*querypb.TransactionMetadata, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "UnresolvedTransactions", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error transactions, innerErr = conn.UnresolvedTransactions(ctx, target, abandonAgeSeconds) @@ -240,7 +240,7 @@ func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *qu func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (qr *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "Execute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error qr, innerErr = conn.Execute(ctx, target, query, bindVars, transactionID, reservedID, options) @@ -254,7 +254,7 @@ func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, q // StreamExecute implements the QueryService interface func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { streamingStarted := false innerErr := conn.StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, func(qr *sqltypes.Result) error { @@ -270,7 +270,7 @@ func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Tar func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state TransactionState, qr *sqltypes.Result, err error) { inDedicatedConn := reservedID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, qr, innerErr = conn.BeginExecute(ctx, target, preQueries, query, bindVars, reservedID, options) @@ -282,7 +282,7 @@ func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Targ // BeginStreamExecute implements the QueryService interface func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state TransactionState, err error) { inDedicatedConn := reservedID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) @@ -292,7 +292,7 @@ func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *queryp } func (ws *wrappedService) MessageStream(ctx context.Context, target *querypb.Target, name string, callback func(*sqltypes.Result) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, target, ws.impl, "MessageStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.MessageStream(ctx, target, name, callback) return canRetry(ctx, innerErr), innerErr @@ -300,7 +300,7 @@ func (ws *wrappedService) MessageStream(ctx context.Context, target *querypb.Tar } func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target, name string, ids []*querypb.Value) (count int64, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "MessageAck", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error count, innerErr = conn.MessageAck(ctx, target, name, ids) @@ -310,7 +310,7 @@ func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target } func (ws *wrappedService) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStream(ctx, request, send) return false, innerErr @@ -318,7 +318,7 @@ func (ws *wrappedService) VStream(ctx context.Context, request *binlogdatapb.VSt } func (ws *wrappedService) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStreamRows", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamRows(ctx, request, send) return false, innerErr @@ -326,7 +326,7 @@ func (ws *wrappedService) VStreamRows(ctx context.Context, request *binlogdatapb } func (ws *wrappedService) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(response *binlogdatapb.VStreamTablesResponse) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStreamTables", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamTables(ctx, request, send) return false, innerErr @@ -334,7 +334,7 @@ func (ws *wrappedService) VStreamTables(ctx context.Context, request *binlogdata } func (ws *wrappedService) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, target, ws.impl, "VStreamResults", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamResults(ctx, target, query, send) return false, innerErr @@ -342,7 +342,7 @@ func (ws *wrappedService) VStreamResults(ctx context.Context, target *querypb.Ta } func (ws *wrappedService) StreamHealth(ctx context.Context, callback func(*querypb.StreamHealthResponse) error) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, nil, ws.impl, "StreamHealth", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.StreamHealth(ctx, callback) return canRetry(ctx, innerErr), innerErr @@ -355,7 +355,7 @@ func (ws *wrappedService) HandlePanic(err *error) { // ReserveBeginExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state ReservedTransactionState, res *sqltypes.Result, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) @@ -367,7 +367,7 @@ func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *query // ReserveBeginStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) @@ -379,7 +379,7 @@ func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target // ReserveExecute implements the QueryService interface func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state ReservedState, res *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) @@ -392,7 +392,7 @@ func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Ta // ReserveStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedState, err error) { inDedicatedConn := transactionID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) @@ -403,7 +403,7 @@ func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *quer func (ws *wrappedService) Release(ctx context.Context, target *querypb.Target, transactionID, reservedID int64) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{inTransaction: inDedicatedConn} + opts := &WrapOpts{InTransaction: inDedicatedConn} return ws.wrapper(ctx, target, ws.impl, "Release", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Release. return false, conn.Release(ctx, target, transactionID, reservedID) @@ -411,7 +411,7 @@ func (ws *wrappedService) Release(ctx context.Context, target *querypb.Target, t } func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) (err error) { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "GetSchema", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.GetSchema(ctx, target, tableType, tableNames, callback) return canRetry(ctx, innerErr), innerErr @@ -420,7 +420,7 @@ func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Close(ctx context.Context) error { - opts := &WrapOpts{inTransaction: false} + opts := &WrapOpts{InTransaction: false} return ws.wrapper(ctx, nil, ws.impl, "Close", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Close. return false, conn.Close(ctx) From f59a14f908d70aabc7d76463c466d030aadb5122 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 28 Aug 2025 07:19:38 -0400 Subject: [PATCH 26/67] Fix tests Signed-off-by: Mohamed Hamza --- go/vt/vtgate/legacy_scatter_conn_test.go | 2 +- go/vt/vtgate/tabletgateway_flaky_test.go | 6 +++--- go/vt/vtgate/tabletgateway_test.go | 10 +++++----- go/vt/vtgate/vstream_manager_test.go | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/go/vt/vtgate/legacy_scatter_conn_test.go b/go/vt/vtgate/legacy_scatter_conn_test.go index ec345e4308e..9ff2ddcf1d3 100644 --- a/go/vt/vtgate/legacy_scatter_conn_test.go +++ b/go/vt/vtgate/legacy_scatter_conn_test.go @@ -621,7 +621,7 @@ func newTestScatterConn(ctx context.Context, hc discovery.HealthCheck, serv srvt // The topo.Server is used to start watching the cells described // in '-cells_to_watch' command line parameter, which is // empty by default. So it's unused in this test, set to nil. - gw := NewTabletGateway(ctx, hc, serv, cell) + gw, _ := NewTabletGateway(ctx, hc, serv, cell) tc := NewTxConn(gw, &StaticConfig{ TxMode: vtgatepb.TransactionMode_MULTI, }) diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index 124997bea9e..f8ab3f45934 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -59,7 +59,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // add a primary tablet which is serving @@ -162,7 +162,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // add a primary tablet which is serving @@ -292,7 +292,7 @@ func TestInconsistentStateDetectedBuffering(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) tg.retryCount = 0 diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index b318cb84981..7b65e6a947a 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -112,7 +112,7 @@ func TestTabletGatewayShuffleTablets(t *testing.T) { hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg := NewTabletGateway(ctx, hc, ts, "local") + tg, _ := NewTabletGateway(ctx, hc, ts, "local") defer tg.Close(ctx) ts1 := &discovery.TabletHealth{ @@ -186,7 +186,7 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) { } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) _ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) @@ -221,7 +221,7 @@ func testTabletGatewayGenericHelper(t *testing.T, ctx context.Context, f func(ct } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // no tablet want := []string{"target: ks.0.replica", `no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA`} @@ -309,7 +309,7 @@ func testTabletGatewayTransact(t *testing.T, ctx context.Context, f func(ctx con } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg := NewTabletGateway(ctx, hc, ts, "cell") + tg, _ := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // retry error - no retry @@ -350,7 +350,7 @@ func verifyShardErrors(t *testing.T, err error, wantErrors []string, wantCode vt // TestWithRetry tests the functionality of withRetry function in different circumstances. func TestWithRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &econtext.FakeTopoServer{}, "cell") + tg, _ := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &econtext.FakeTopoServer{}, "cell") tg.kev = discovery.NewKeyspaceEventWatcher(ctx, tg.srvTopoServer, tg.hc, tg.localCell) defer func() { cancel() diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index f5af68806cb..2527c63832e 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -1984,7 +1984,7 @@ func TestVStreamManagerHealthCheckResponseHandling(t *testing.T) { } func newTestVStreamManager(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, cell string) *vstreamManager { - gw := NewTabletGateway(ctx, hc, serv, cell) + gw, _ := NewTabletGateway(ctx, hc, serv, cell) srvResolver := srvtopo.NewResolver(serv, gw, cell) return newVStreamManager(srvResolver, serv, cell) } From 2620b76320e3e0055168a49519072cece5fcc925 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 28 Aug 2025 19:13:57 -0400 Subject: [PATCH 27/67] Pass `WrapOpts` by value Signed-off-by: Mohamed Hamza --- go/vt/vtexplain/vtexplain_vttablet.go | 2 +- go/vt/vtgate/tabletgateway.go | 6 +- go/vt/vtgate/tabletgateway_test.go | 2 +- .../queryservice/fakes/error_query_service.go | 2 +- go/vt/vttablet/queryservice/wrapped.go | 62 +++++++++---------- 5 files changed, 37 insertions(+), 37 deletions(-) diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 7c475f2b12a..cf33b3acd95 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -127,7 +127,7 @@ func (vte *VTExplain) newTablet(ctx context.Context, env *vtenv.Environment, opt tablet.QueryService = queryservice.Wrap( nil, - func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts *queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { + func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { return fmt.Errorf("explainTablet does not implement %s", name) }, ) diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 77531555e9b..53758cd75ad 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -304,7 +304,7 @@ func (gw *TabletGateway) DebugBalancerHandler(w http.ResponseWriter, r *http.Req // withRetry also adds shard information to errors returned from the inner QueryService, so // withShardError should not be combined with withRetry. func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, _ queryservice.QueryService, - _ string, opts *queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), + _ string, opts queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), ) error { // for transactions, we connect to a specific tablet instead of letting gateway choose one if opts.InTransaction && target.TabletType != topodatapb.TabletType_PRIMARY { @@ -418,7 +418,7 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, // getBalancerTablet selects a tablet for the given query target, using the configured balancer if enabled. Otherwise, it will // select a random tablet, with preference to the local cell. -func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts *queryservice.WrapOpts) *discovery.TabletHealth { +func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts queryservice.WrapOpts) *discovery.TabletHealth { var tablet *discovery.TabletHealth useBalancer := balancerEnabled @@ -460,7 +460,7 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // withShardError adds shard information to errors returned from the inner QueryService. func (gw *TabletGateway) withShardError(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, - _ string, _ *queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), + _ string, _ queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), ) error { _, err := inner(ctx, target, conn) return NewShardError(err, target) diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index 7b65e6a947a..0ecca0a8348 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -392,7 +392,7 @@ func TestWithRetry(t *testing.T) { } for _, tt := range testcases { t.Run(tt.name, func(t *testing.T) { - err := tg.withRetry(ctx, tt.target, nil, "", tt.inTransaction, tt.inner) + err := tg.withRetry(ctx, tt.target, nil, "", queryservice.WrapOpts{InTransaction: tt.inTransaction}, tt.inner) if tt.expectedErr == "" { require.NoError(t, err) } else { diff --git a/go/vt/vttablet/queryservice/fakes/error_query_service.go b/go/vt/vttablet/queryservice/fakes/error_query_service.go index bc10b6f9e7e..97d8d505e3b 100644 --- a/go/vt/vttablet/queryservice/fakes/error_query_service.go +++ b/go/vt/vttablet/queryservice/fakes/error_query_service.go @@ -28,7 +28,7 @@ import ( // ErrorQueryService is an object that returns an error for all methods. var ErrorQueryService = queryservice.Wrap( nil, - func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts *queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { + func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService, name string, opts queryservice.WrapOpts, inner func(context.Context, *querypb.Target, queryservice.QueryService) (bool, error)) error { return fmt.Errorf("ErrorQueryService does not implement any method") }, ) diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index 9e545dc529a..205510b02b4 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -34,7 +34,7 @@ var _ QueryService = &wrappedService{} // The inner function returns err and canRetry. // If canRetry is true, the error is specific to the current vttablet and can be retried elsewhere. // The flag will be false if there was no error. -type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QueryService, name string, opts *WrapOpts, inner func(context.Context, *querypb.Target, QueryService) (canRetry bool, err error)) error +type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QueryService, name string, opts WrapOpts, inner func(context.Context, *querypb.Target, QueryService) (canRetry bool, err error)) error // WrapOpts is the options passed to [WrapperFunc]. type WrapOpts struct { @@ -117,7 +117,7 @@ type wrappedService struct { } func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state TransactionState, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "Begin", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.Begin(ctx, target, options) @@ -128,7 +128,7 @@ func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, opt func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Commit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Commit(ctx, target, transactionID) @@ -142,7 +142,7 @@ func (ws *wrappedService) Commit(ctx context.Context, target *querypb.Target, tr func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) { var rID int64 - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Rollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error rID, innerErr = conn.Rollback(ctx, target, transactionID) @@ -155,7 +155,7 @@ func (ws *wrappedService) Rollback(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) error { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err := ws.wrapper(ctx, target, ws.impl, "Prepare", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.Prepare(ctx, target, transactionID, dtid) return canRetry(ctx, innerErr), innerErr @@ -164,7 +164,7 @@ func (ws *wrappedService) Prepare(ctx context.Context, target *querypb.Target, t } func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Target, dtid string) (err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "CommitPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CommitPrepared(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr @@ -173,7 +173,7 @@ func (ws *wrappedService) CommitPrepared(ctx context.Context, target *querypb.Ta } func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb.Target, dtid string, originalID int64) (err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "RollbackPrepared", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.RollbackPrepared(ctx, target, dtid, originalID) return canRetry(ctx, innerErr), innerErr @@ -182,7 +182,7 @@ func (ws *wrappedService) RollbackPrepared(ctx context.Context, target *querypb. } func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb.Target, dtid string, participants []*querypb.Target) (err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "CreateTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.CreateTransaction(ctx, target, dtid, participants) return canRetry(ctx, innerErr), innerErr @@ -191,7 +191,7 @@ func (ws *wrappedService) CreateTransaction(ctx context.Context, target *querypb } func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Target, transactionID int64, dtid string) (state querypb.StartCommitState, err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "StartCommit", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.StartCommit(ctx, target, transactionID, dtid) @@ -201,7 +201,7 @@ func (ws *wrappedService) StartCommit(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Target, dtid string, transactionID int64) (err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "SetRollback", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.SetRollback(ctx, target, dtid, transactionID) return canRetry(ctx, innerErr), innerErr @@ -210,7 +210,7 @@ func (ws *wrappedService) SetRollback(ctx context.Context, target *querypb.Targe } func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *querypb.Target, dtid string) (err error) { - opts := &WrapOpts{InTransaction: true} + opts := WrapOpts{InTransaction: true} err = ws.wrapper(ctx, target, ws.impl, "ConcludeTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.ConcludeTransaction(ctx, target, dtid) return canRetry(ctx, innerErr), innerErr @@ -219,7 +219,7 @@ func (ws *wrappedService) ConcludeTransaction(ctx context.Context, target *query } func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.Target, dtid string) (metadata *querypb.TransactionMetadata, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReadTransaction", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error metadata, innerErr = conn.ReadTransaction(ctx, target, dtid) @@ -229,7 +229,7 @@ func (ws *wrappedService) ReadTransaction(ctx context.Context, target *querypb.T } func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) (transactions []*querypb.TransactionMetadata, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "UnresolvedTransactions", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error transactions, innerErr = conn.UnresolvedTransactions(ctx, target, abandonAgeSeconds) @@ -240,7 +240,7 @@ func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *qu func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (qr *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "Execute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error qr, innerErr = conn.Execute(ctx, target, query, bindVars, transactionID, reservedID, options) @@ -254,7 +254,7 @@ func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, q // StreamExecute implements the QueryService interface func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { streamingStarted := false innerErr := conn.StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, func(qr *sqltypes.Result) error { @@ -270,7 +270,7 @@ func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Tar func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state TransactionState, qr *sqltypes.Result, err error) { inDedicatedConn := reservedID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, qr, innerErr = conn.BeginExecute(ctx, target, preQueries, query, bindVars, reservedID, options) @@ -282,7 +282,7 @@ func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Targ // BeginStreamExecute implements the QueryService interface func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state TransactionState, err error) { inDedicatedConn := reservedID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) @@ -292,7 +292,7 @@ func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *queryp } func (ws *wrappedService) MessageStream(ctx context.Context, target *querypb.Target, name string, callback func(*sqltypes.Result) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, target, ws.impl, "MessageStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.MessageStream(ctx, target, name, callback) return canRetry(ctx, innerErr), innerErr @@ -300,7 +300,7 @@ func (ws *wrappedService) MessageStream(ctx context.Context, target *querypb.Tar } func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target, name string, ids []*querypb.Value) (count int64, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "MessageAck", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error count, innerErr = conn.MessageAck(ctx, target, name, ids) @@ -310,7 +310,7 @@ func (ws *wrappedService) MessageAck(ctx context.Context, target *querypb.Target } func (ws *wrappedService) VStream(ctx context.Context, request *binlogdatapb.VStreamRequest, send func([]*binlogdatapb.VEvent) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStream", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStream(ctx, request, send) return false, innerErr @@ -318,7 +318,7 @@ func (ws *wrappedService) VStream(ctx context.Context, request *binlogdatapb.VSt } func (ws *wrappedService) VStreamRows(ctx context.Context, request *binlogdatapb.VStreamRowsRequest, send func(*binlogdatapb.VStreamRowsResponse) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStreamRows", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamRows(ctx, request, send) return false, innerErr @@ -326,7 +326,7 @@ func (ws *wrappedService) VStreamRows(ctx context.Context, request *binlogdatapb } func (ws *wrappedService) VStreamTables(ctx context.Context, request *binlogdatapb.VStreamTablesRequest, send func(response *binlogdatapb.VStreamTablesResponse) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, request.Target, ws.impl, "VStreamTables", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamTables(ctx, request, send) return false, innerErr @@ -334,7 +334,7 @@ func (ws *wrappedService) VStreamTables(ctx context.Context, request *binlogdata } func (ws *wrappedService) VStreamResults(ctx context.Context, target *querypb.Target, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, target, ws.impl, "VStreamResults", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.VStreamResults(ctx, target, query, send) return false, innerErr @@ -342,7 +342,7 @@ func (ws *wrappedService) VStreamResults(ctx context.Context, target *querypb.Ta } func (ws *wrappedService) StreamHealth(ctx context.Context, callback func(*querypb.StreamHealthResponse) error) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, nil, ws.impl, "StreamHealth", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.StreamHealth(ctx, callback) return canRetry(ctx, innerErr), innerErr @@ -355,7 +355,7 @@ func (ws *wrappedService) HandlePanic(err *error) { // ReserveBeginExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state ReservedTransactionState, res *sqltypes.Result, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) @@ -367,7 +367,7 @@ func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *query // ReserveBeginStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) @@ -379,7 +379,7 @@ func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target // ReserveExecute implements the QueryService interface func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state ReservedState, res *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) @@ -392,7 +392,7 @@ func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Ta // ReserveStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedState, err error) { inDedicatedConn := transactionID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) @@ -403,7 +403,7 @@ func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *quer func (ws *wrappedService) Release(ctx context.Context, target *querypb.Target, transactionID, reservedID int64) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := &WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn} return ws.wrapper(ctx, target, ws.impl, "Release", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Release. return false, conn.Release(ctx, target, transactionID, reservedID) @@ -411,7 +411,7 @@ func (ws *wrappedService) Release(ctx context.Context, target *querypb.Target, t } func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, tableType querypb.SchemaTableType, tableNames []string, callback func(schemaRes *querypb.GetSchemaResponse) error) (err error) { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} err = ws.wrapper(ctx, target, ws.impl, "GetSchema", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { innerErr := conn.GetSchema(ctx, target, tableType, tableNames, callback) return canRetry(ctx, innerErr), innerErr @@ -420,7 +420,7 @@ func (ws *wrappedService) GetSchema(ctx context.Context, target *querypb.Target, } func (ws *wrappedService) Close(ctx context.Context) error { - opts := &WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false} return ws.wrapper(ctx, nil, ws.impl, "Close", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { // No point retrying Close. return false, conn.Close(ctx) From e62c88f5a475963729b1135efe01ea00bb2cd07d Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 28 Aug 2025 19:21:01 -0400 Subject: [PATCH 28/67] Pass `PickOpts` by value Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 4 ++-- go/vt/vtgate/balancer/balancer_test.go | 8 ++++---- go/vt/vtgate/balancer/session.go | 4 ++-- go/vt/vtgate/balancer/session_test.go | 8 ++++---- go/vt/vtgate/tabletgateway.go | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index ff49c69add6..dc1620495f7 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -90,7 +90,7 @@ converge on the desired balanced query load. type TabletBalancer interface { // Pick is the main entry point to the balancer. Returns the best tablet out of the list // for a given query to maintain the desired balanced allocation over multiple executions. - Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth + Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth // DebugHandler provides a summary of tablet balancer state DebugHandler(w http.ResponseWriter, r *http.Request) @@ -176,7 +176,7 @@ func (b *tabletBalancer) DebugHandler(w http.ResponseWriter, _ *http.Request) { // Given the total allocation for the set of tablets, choose the best target // by a weighted random sample so that over time the system will achieve the // desired balanced allocation. -func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ *PickOpts) *discovery.TabletHealth { +func (b *tabletBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ PickOpts) *discovery.TabletHealth { numTablets := len(tablets) if numTablets == 0 { return nil diff --git a/go/vt/vtgate/balancer/balancer_test.go b/go/vt/vtgate/balancer/balancer_test.go index 78b6e708d50..ce07fc2f525 100644 --- a/go/vt/vtgate/balancer/balancer_test.go +++ b/go/vt/vtgate/balancer/balancer_test.go @@ -298,7 +298,7 @@ func TestBalancedPick(t *testing.T) { b := NewTabletBalancer(localCell, vtGateCells).(*tabletBalancer) for i := 0; i < N/len(vtGateCells); i++ { - th := b.Pick(target, tablets, nil) + th := b.Pick(target, tablets, PickOpts{}) if i == 0 { t.Logf("Target Flows %v, Balancer: %s\n", expectedPerCell, b.print()) } @@ -336,7 +336,7 @@ func TestTopologyChanged(t *testing.T) { tablets = tablets[0:2] for i := 0; i < N; i++ { - th := b.Pick(target, tablets, nil) + th := b.Pick(target, tablets, PickOpts{}) allocation, totalAllocation := b.getAllocation(target, tablets) assert.Equalf(t, ALLOCATION/2, totalAllocation, "totalAllocation mismatch %s", b.print()) @@ -346,7 +346,7 @@ func TestTopologyChanged(t *testing.T) { // Run again with the full topology. Now traffic should go to cell b for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil) + th := b.Pick(target, allTablets, PickOpts{}) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -359,7 +359,7 @@ func TestTopologyChanged(t *testing.T) { newTablet := createTestTablet("b") allTablets[2] = newTablet for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, nil) + th := b.Pick(target, allTablets, PickOpts{}) allocation, totalAllocation := b.getAllocation(target, allTablets) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 5df64268ac6..1ff9696fe2b 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -93,8 +93,8 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop // // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. -func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts *PickOpts) *discovery.TabletHealth { - if opts == nil { +func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { + if opts.SessionUUID == "" { return nil } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 611a646f8aa..5102b39c5f8 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -508,8 +508,8 @@ func TestPickNoOpts(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(10 * time.Millisecond) - // Test with nil opts - result := b.Pick(target, nil, nil) + // Test with empty opts + result := b.Pick(target, nil, PickOpts{}) require.Nil(t, result) } @@ -681,6 +681,6 @@ func TestTabletTypesToWatch(t *testing.T) { } } -func buildOpts(uuid string) *PickOpts { - return &PickOpts{SessionUUID: uuid} +func buildOpts(uuid string) PickOpts { + return PickOpts{SessionUUID: uuid} } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 53758cd75ad..ed6aec2fef6 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -436,7 +436,7 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di }) } - opts := &balancer.PickOpts{SessionUUID: opts.Options.SessionUUID, InvalidTablets: invalidTablets} + opts := balancer.PickOpts{SessionUUID: opts.Options.SessionUUID, InvalidTablets: invalidTablets} tablet = gw.balancer.Pick(target, tablets, opts) } From 0db53cde5d2b8518ea33290e33584bb89cbdd8cb Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 28 Aug 2025 19:27:48 -0400 Subject: [PATCH 29/67] Change `ExecuteOptions` in `WrapOpts` as a pointer Signed-off-by: Mohamed Hamza --- go/vt/vtgate/tabletgateway.go | 7 ++++++- go/vt/vttablet/queryservice/wrapped.go | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index ed6aec2fef6..2dd77e34a11 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -436,7 +436,12 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di }) } - opts := balancer.PickOpts{SessionUUID: opts.Options.SessionUUID, InvalidTablets: invalidTablets} + var sessionUUID string + if opts.Options != nil { + sessionUUID = opts.Options.SessionUUID + } + + opts := balancer.PickOpts{SessionUUID: sessionUUID, InvalidTablets: invalidTablets} tablet = gw.balancer.Pick(target, tablets, opts) } diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index 205510b02b4..a00a6897617 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -40,7 +40,7 @@ type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QuerySer type WrapOpts struct { InTransaction bool - Options querypb.ExecuteOptions + Options *querypb.ExecuteOptions } // Wrap returns a wrapped version of the original QueryService implementation. From cde392c5fae001d445515f8532e82ed5d299fac8 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 29 Aug 2025 10:31:42 -0400 Subject: [PATCH 30/67] Fix `balancer-type` help text Signed-off-by: Mohamed Hamza --- go/flags/endtoend/vtgate.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index 9432be8e452..2d688754540 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -29,8 +29,8 @@ Flags: --allowed-tablet-types strings Specifies the tablet types this vtgate is allowed to route queries to. Should be provided as a comma-separated set of tablet types. --alsologtostderr log to standard error as well as files --balancer-keyspaces strings When in balanced mode, a comma-separated list of keyspaces for which to use the balancer (optional) + --balancer-type string When in balanced mode, selects the type of balancer to use. "balanced" balances connections evenly, "session" pins a connection to a given tablet for its duration. (default: "balanced") (optional) (default "balanced") --balancer-vtgate-cells strings When in balanced mode, a comma-separated list of cells that contain vtgates (required) - --balancer-type When in balanced mode, selects the type of balancer to use. "balanced" balances connections evenly, "session" pins a connection to a given tablet for its duration. (default: "balanced") (optional) --bind-address string Bind address for the server. If empty, the server will listen on all available unicast and anycast IP addresses of the local system. --buffer-drain-concurrency int Maximum number of requests retried simultaneously. More concurrency will increase the load on the PRIMARY vttablet when draining the buffer. (default 1) --buffer-keyspace-shards string If not empty, limit buffering to these entries (comma separated). Entry format: keyspace or keyspace/shard. Requires --enable_buffer=true. From eadf1ae3a0ec0eda981b8486e240d32fb6a0a14a Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 30 Aug 2025 10:38:25 -0400 Subject: [PATCH 31/67] Fix some bugs Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 11 +++++------ go/vt/vttablet/queryservice/wrapped.go | 18 +++++++++--------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 1ff9696fe2b..8b79b96dd54 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -26,6 +26,7 @@ import ( "sync" "vitess.io/vitess/go/vt/discovery" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" @@ -57,6 +58,8 @@ type SessionBalancer struct { // NewSessionBalancer creates a new session balancer. func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) (TabletBalancer, error) { + log.Info("session balancer: creating new session balancer") + b := &SessionBalancer{ localCell: localCell, hc: hc, @@ -127,7 +130,6 @@ func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { // watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { - // Start watching health check channel for future tablet health changes for { select { case <-ctx.Done(): @@ -135,12 +137,12 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis return case tablet := <-hcChan: if tablet == nil { - return + continue } // Ignore tablets we aren't supposed to watch if _, ok := tabletTypesToWatch[tablet.Target.TabletType]; !ok { - return + continue } b.onTabletHealthChange(tablet) @@ -184,9 +186,6 @@ func getOrCreateRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tabl // print returns a string representation of the session balancer state for debugging. func (b *SessionBalancer) print() string { - b.mu.RLock() - defer b.mu.RUnlock() - sb := strings.Builder{} sb.WriteString("Local rings:\n") diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index a00a6897617..1837fef1af8 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -117,7 +117,7 @@ type wrappedService struct { } func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state TransactionState, err error) { - opts := WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false, Options: options} err = ws.wrapper(ctx, target, ws.impl, "Begin", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.Begin(ctx, target, options) @@ -240,7 +240,7 @@ func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *qu func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (qr *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err = ws.wrapper(ctx, target, ws.impl, "Execute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error qr, innerErr = conn.Execute(ctx, target, query, bindVars, transactionID, reservedID, options) @@ -254,7 +254,7 @@ func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, q // StreamExecute implements the QueryService interface func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { streamingStarted := false innerErr := conn.StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, func(qr *sqltypes.Result) error { @@ -270,7 +270,7 @@ func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Tar func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state TransactionState, qr *sqltypes.Result, err error) { inDedicatedConn := reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, qr, innerErr = conn.BeginExecute(ctx, target, preQueries, query, bindVars, reservedID, options) @@ -282,7 +282,7 @@ func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Targ // BeginStreamExecute implements the QueryService interface func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state TransactionState, err error) { inDedicatedConn := reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) @@ -355,7 +355,7 @@ func (ws *wrappedService) HandlePanic(err *error) { // ReserveBeginExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state ReservedTransactionState, res *sqltypes.Result, err error) { - opts := WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false, Options: options} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) @@ -367,7 +367,7 @@ func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *query // ReserveBeginStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { - opts := WrapOpts{InTransaction: false} + opts := WrapOpts{InTransaction: false, Options: options} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) @@ -379,7 +379,7 @@ func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target // ReserveExecute implements the QueryService interface func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state ReservedState, res *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error state, res, err = conn.ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) @@ -392,7 +392,7 @@ func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Ta // ReserveStreamExecute implements the QueryService interface func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedState, err error) { inDedicatedConn := transactionID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn} + opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error state, innerErr = conn.ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) From 8fae28c212411497c47272df6faa92e8a435e12a Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 30 Aug 2025 11:56:31 -0400 Subject: [PATCH 32/67] Get cell from tablet alias rather than target Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 8b79b96dd54..7e28e514237 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -158,7 +158,7 @@ func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { defer b.mu.Unlock() var ring *hashRing - if tablet.Target.Cell == b.localCell { + if tablet.Tablet.Alias.Cell == b.localCell { ring = getOrCreateRing(b.localRings, tablet) } else { ring = getOrCreateRing(b.externalRings, tablet) From cd30d9d3e0d2977cd814b764424c8a34d92d8f78 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 30 Aug 2025 11:56:49 -0400 Subject: [PATCH 33/67] Initial e2e test Signed-off-by: Mohamed Hamza --- .../vtgate/sessionbalancer/main_test.go | 79 +++++++++++++++++ .../vtgate/sessionbalancer/session_test.go | 88 +++++++++++++++++++ .../vtgate/sessionbalancer/uschema.sql | 9 ++ 3 files changed, 176 insertions(+) create mode 100644 go/test/endtoend/vtgate/sessionbalancer/main_test.go create mode 100644 go/test/endtoend/vtgate/sessionbalancer/session_test.go create mode 100644 go/test/endtoend/vtgate/sessionbalancer/uschema.sql diff --git a/go/test/endtoend/vtgate/sessionbalancer/main_test.go b/go/test/endtoend/vtgate/sessionbalancer/main_test.go new file mode 100644 index 00000000000..f79588219c3 --- /dev/null +++ b/go/test/endtoend/vtgate/sessionbalancer/main_test.go @@ -0,0 +1,79 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessionbalancer + +import ( + _ "embed" + "flag" + "os" + "testing" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" +) + +var ( + clusterInstance *cluster.LocalProcessCluster + vtParams mysql.ConnParams + uks = "uks" + cell = "test_misc" + + //go:embed uschema.sql + uschemaSQL string +) + +func TestMain(m *testing.M) { + flag.Parse() + + exitCode := func() int { + clusterInstance = cluster.NewCluster(cell, "localhost") + defer clusterInstance.Teardown() + + // Start topo server + err := clusterInstance.StartTopo() + if err != nil { + return 1 + } + + // Enable session balancer in vtgate + clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, + "--enable-balancer", + "--balancer-vtgate-cells", clusterInstance.Cell, + "--balancer-type", "session") + + // Start keyspace with multiple tablets per shard + keyspace := &cluster.Keyspace{ + Name: uks, + SchemaSQL: uschemaSQL, + } + err = clusterInstance.StartUnshardedKeyspace(*keyspace, 2, false) + if err != nil { + return 1 + } + + // Start vtgate + err = clusterInstance.StartVtgate() + if err != nil { + return 1 + } + + vtParams = clusterInstance.GetVTParams(uks) + + return m.Run() + }() + os.Exit(exitCode) +} diff --git a/go/test/endtoend/vtgate/sessionbalancer/session_test.go b/go/test/endtoend/vtgate/sessionbalancer/session_test.go new file mode 100644 index 00000000000..dc0c50a6374 --- /dev/null +++ b/go/test/endtoend/vtgate/sessionbalancer/session_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sessionbalancer + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql" +) + +// TestSessionBalancer validates that the session balancer consistently routes +// queries for the same session to the same tablet. +func TestSessionBalancer(t *testing.T) { + // Get connections that route to different tablets + conn1, conn2, id1, id2 := connections(t) + + defer conn1.Close() + defer conn2.Close() + + // Validate that each connection consistently returns the same server ID + for range 20 { + newID1 := serverID(t, conn1) + require.Equal(t, id1, newID1) + + newID2 := serverID(t, conn2) + require.Equal(t, id2, newID2) + + require.NotEqual(t, newID2, newID1) + } +} + +// connections returns two connections that should route to different servers. +func connections(t *testing.T) (conn1, conn2 *mysql.Conn, id1, id2 string) { + t.Helper() + + // Keep creating connections until we find two that route to different server IDs + vtParams.DbName = uks + "@replica" + for { + conn1, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + + conn2, err = mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + + id1 = serverID(t, conn1) + id2 = serverID(t, conn2) + + // Break if we found connections with different server IDs + if id1 != id2 { + return conn1, conn2, id1, id2 + } + + // If not, close the connections and try again + conn1.Close() + conn2.Close() + } +} + +// serverID runs a `SELECT @@server_id` on the given connection and returns the +// server's ID. +func serverID(t *testing.T, conn *mysql.Conn) string { + t.Helper() + + result1, err := conn.ExecuteFetch("SELECT @@server_uuid", 1, false) + require.NoError(t, err) + + tablet1Bytes, err := result1.Rows[0][0].ToBytes() + require.NoError(t, err) + + return string(tablet1Bytes) +} diff --git a/go/test/endtoend/vtgate/sessionbalancer/uschema.sql b/go/test/endtoend/vtgate/sessionbalancer/uschema.sql new file mode 100644 index 00000000000..017fe6e0178 --- /dev/null +++ b/go/test/endtoend/vtgate/sessionbalancer/uschema.sql @@ -0,0 +1,9 @@ +CREATE TABLE t1 ( + id int PRIMARY KEY, + name varchar(255) +); + +CREATE TABLE t2 ( + id int PRIMARY KEY, + value varchar(255) +); \ No newline at end of file From 37b7e627c2a1c4c431a516a97147eb040b11e698 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sat, 30 Aug 2025 22:07:38 -0400 Subject: [PATCH 34/67] Add more e2e tests Signed-off-by: Mohamed Hamza --- .../vtgate/sessionbalancer/main_test.go | 79 ------- .../vtgate/sessionbalancer/session_test.go | 205 +++++++++++++++--- .../vtgate/sessionbalancer/uschema.sql | 5 - 3 files changed, 180 insertions(+), 109 deletions(-) delete mode 100644 go/test/endtoend/vtgate/sessionbalancer/main_test.go diff --git a/go/test/endtoend/vtgate/sessionbalancer/main_test.go b/go/test/endtoend/vtgate/sessionbalancer/main_test.go deleted file mode 100644 index f79588219c3..00000000000 --- a/go/test/endtoend/vtgate/sessionbalancer/main_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sessionbalancer - -import ( - _ "embed" - "flag" - "os" - "testing" - - "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/test/endtoend/cluster" -) - -var ( - clusterInstance *cluster.LocalProcessCluster - vtParams mysql.ConnParams - uks = "uks" - cell = "test_misc" - - //go:embed uschema.sql - uschemaSQL string -) - -func TestMain(m *testing.M) { - flag.Parse() - - exitCode := func() int { - clusterInstance = cluster.NewCluster(cell, "localhost") - defer clusterInstance.Teardown() - - // Start topo server - err := clusterInstance.StartTopo() - if err != nil { - return 1 - } - - // Enable session balancer in vtgate - clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, - "--enable-balancer", - "--balancer-vtgate-cells", clusterInstance.Cell, - "--balancer-type", "session") - - // Start keyspace with multiple tablets per shard - keyspace := &cluster.Keyspace{ - Name: uks, - SchemaSQL: uschemaSQL, - } - err = clusterInstance.StartUnshardedKeyspace(*keyspace, 2, false) - if err != nil { - return 1 - } - - // Start vtgate - err = clusterInstance.StartVtgate() - if err != nil { - return 1 - } - - vtParams = clusterInstance.GetVTParams(uks) - - return m.Run() - }() - os.Exit(exitCode) -} diff --git a/go/test/endtoend/vtgate/sessionbalancer/session_test.go b/go/test/endtoend/vtgate/sessionbalancer/session_test.go index dc0c50a6374..281dfa963af 100644 --- a/go/test/endtoend/vtgate/sessionbalancer/session_test.go +++ b/go/test/endtoend/vtgate/sessionbalancer/session_test.go @@ -18,67 +18,201 @@ package sessionbalancer import ( "context" + _ "embed" + "fmt" + "slices" "testing" + "time" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" ) +const ( + cell = "test_misc" + keyspace = "uks" +) + +//go:embed uschema.sql +var uschemaSQL string + +func createCluster(t *testing.T, replicaCount int) (*cluster.LocalProcessCluster, *mysql.ConnParams) { + t.Helper() + + // Create a new clusterInstance + clusterInstance := cluster.NewCluster(cell, "localhost") + + // Start topo server + err := clusterInstance.StartTopo() + require.NoError(t, err, "Failed to start topo server") + + // Enable session balancer in vtgate + clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, + "--enable-balancer", + "--balancer-vtgate-cells", clusterInstance.Cell, + "--balancer-type", "session") + + ks := cluster.Keyspace{ + Name: keyspace, + SchemaSQL: uschemaSQL, + } + err = clusterInstance.StartUnshardedKeyspace(ks, replicaCount, false) + require.NoError(t, err, "Failed to start keyspace") + + err = clusterInstance.StartVtgate() + require.NoError(t, err, "Failed to start vtgate") + + vtParams := clusterInstance.GetVTParams(keyspace) + + return clusterInstance, &vtParams +} + // TestSessionBalancer validates that the session balancer consistently routes // queries for the same session to the same tablet. func TestSessionBalancer(t *testing.T) { + cluster, vtParams := createCluster(t, 2) + defer cluster.Teardown() + + // Get connections that route to different tablets + conns, ids := connections(t, vtParams, "replica", 2) + + for _, conn := range conns { + defer conn.Close() + } + + // Validate that each connection consistently returns the same server ID + for range 20 { + for i, conn := range conns { + id := serverID(t, conn) + require.Equal(t, ids[i], id) + } + } +} + +// TestSessionBalancerRemoveTablet validates that when a tablet is killed, +// connections that were using that tablet get rerouted to remaining tablets. +func TestSessionBalancerRemoveTablet(t *testing.T) { + cluster, vtParams := createCluster(t, 2) + defer cluster.Teardown() + // Get connections that route to different tablets - conn1, conn2, id1, id2 := connections(t) + conns, _ := connections(t, vtParams, "replica", 2) + conn1, conn2 := conns[0], conns[1] defer conn1.Close() defer conn2.Close() - // Validate that each connection consistently returns the same server ID + tablets := tablets(t, cluster, "replica") + require.NotNil(t, tablets) + require.Len(t, tablets, 2) + + err := tablets[0].VttabletProcess.TearDown() + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + for range 20 { newID1 := serverID(t, conn1) - require.Equal(t, id1, newID1) - newID2 := serverID(t, conn2) - require.Equal(t, id2, newID2) - require.NotEqual(t, newID2, newID1) + require.Equal(t, newID1, newID2) + } +} + +// TestSessionBalancerAddTablet validates that when a new tablet is added, +// new connections get routed to the new tablet. +func TestSessionBalancerAddTablet(t *testing.T) { + cluster, vtParams := createCluster(t, 3) + defer cluster.Teardown() + + // Get 3 connections that route to different tablets + conns, ids := connections(t, vtParams, "replica", 3) + for _, conn := range conns { + defer conn.Close() + } + + tablets := tablets(t, cluster, "replica") + require.NotNil(t, tablets) + require.Len(t, tablets, 3) + + // Start with only 2 tablets serving + tablet := tablets[2] + err := tablet.VttabletProcess.TearDown() + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + // Find the connection that moved + var conn *mysql.Conn + for i, c := range conns { + newID := serverID(t, c) + if newID != ids[i] { + conn = c + break + } + } + + require.NotNil(t, conn, "One connection should've moved tablets") + + // Start up the tablet again + err = tablet.RestartOnlyTablet() + require.NoError(t, err) + + time.Sleep(10 * time.Millisecond) + + // All connections should route to the same IDs again + for range 20 { + for i, conn := range conns { + id := ids[i] + newID := serverID(t, conn) + require.Equal(t, id, newID) + } } } -// connections returns two connections that should route to different servers. -func connections(t *testing.T) (conn1, conn2 *mysql.Conn, id1, id2 string) { +// connections returns the specified number of connections that should route to different tablets. +func connections(t *testing.T, vtParams *mysql.ConnParams, tabletType string, numConnections int) ([]*mysql.Conn, []string) { t.Helper() - // Keep creating connections until we find two that route to different server IDs - vtParams.DbName = uks + "@replica" - for { - conn1, err := mysql.Connect(context.Background(), &vtParams) - require.NoError(t, err) + vtParams.DbName = fmt.Sprintf("%s@%s", keyspace, tabletType) + + conns := make([]*mysql.Conn, 0, numConnections) + ids := make([]string, 0, numConnections) - conn2, err = mysql.Connect(context.Background(), &vtParams) + // Keep creating connections until we have the required number with different server IDs + for range 20 { + conn, err := mysql.Connect(context.Background(), vtParams) require.NoError(t, err) - id1 = serverID(t, conn1) - id2 = serverID(t, conn2) + id := serverID(t, conn) + newID := !slices.Contains(ids, id) + + // If we found a new tablet, add it to the list of connections + if newID { + conns = append(conns, conn) + ids = append(ids, id) + + if len(conns) == numConnections { + return conns, ids + } - // Break if we found connections with different server IDs - if id1 != id2 { - return conn1, conn2, id1, id2 + continue } - // If not, close the connections and try again - conn1.Close() - conn2.Close() + conn.Close() } + + t.Fatalf("could not create %d connections with different tablet connections", numConnections) + return nil, nil } -// serverID runs a `SELECT @@server_id` on the given connection and returns the -// server's ID. +// serverID runs a `SELECT @@server_id` on the given connection and returns the server's ID. func serverID(t *testing.T, conn *mysql.Conn) string { t.Helper() - result1, err := conn.ExecuteFetch("SELECT @@server_uuid", 1, false) + result1, err := conn.ExecuteFetch("SELECT @@server_id", 1, false) require.NoError(t, err) tablet1Bytes, err := result1.Rows[0][0].ToBytes() @@ -86,3 +220,24 @@ func serverID(t *testing.T, conn *mysql.Conn) string { return string(tablet1Bytes) } + +func tablets(t *testing.T, clusterInstance *cluster.LocalProcessCluster, tabletType string) []*cluster.Vttablet { + t.Helper() + + if len(clusterInstance.Keyspaces) == 0 { + return nil + } + + if len(clusterInstance.Keyspaces[0].Shards) == 0 { + return nil + } + + tablets := make([]*cluster.Vttablet, 0, 2) + for _, tablet := range clusterInstance.Keyspaces[0].Shards[0].Vttablets { + if tablet.Type == tabletType { + tablets = append(tablets, tablet) + } + } + + return tablets +} diff --git a/go/test/endtoend/vtgate/sessionbalancer/uschema.sql b/go/test/endtoend/vtgate/sessionbalancer/uschema.sql index 017fe6e0178..61b680a618f 100644 --- a/go/test/endtoend/vtgate/sessionbalancer/uschema.sql +++ b/go/test/endtoend/vtgate/sessionbalancer/uschema.sql @@ -2,8 +2,3 @@ CREATE TABLE t1 ( id int PRIMARY KEY, name varchar(255) ); - -CREATE TABLE t2 ( - id int PRIMARY KEY, - value varchar(255) -); \ No newline at end of file From 0a8ce4bc31b2d08e77b12a3fd1ec4bdd86a65e29 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 3 Oct 2025 14:26:47 -0400 Subject: [PATCH 35/67] undo auto fmt Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index dc1620495f7..d5ed58aefeb 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -314,7 +314,7 @@ func (b *tabletBalancer) allocateFlows(allTablets []*discovery.TabletHealth) *ta // to avoid truncating the integer values. shiftFlow := overAllocatedFlow * currentFlow * underAllocatedFlow / a.Inflows[overAllocatedCell] / unbalancedFlow - // fmt.Printf("shift %d %s %s -> %s (over %d current %d in %d under %d unbalanced %d) \n", shiftFlow, vtgateCell, overAllocatedCell, underAllocatedCell, + //fmt.Printf("shift %d %s %s -> %s (over %d current %d in %d under %d unbalanced %d) \n", shiftFlow, vtgateCell, overAllocatedCell, underAllocatedCell, // overAllocatedFlow, currentFlow, a.Inflows[overAllocatedCell], underAllocatedFlow, unbalancedFlow) a.Outflows[vtgateCell][overAllocatedCell] -= shiftFlow From bdc2ba5410ee13f122a2b607ee7c6db8cf783d03 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Sun, 5 Oct 2025 11:39:54 -0400 Subject: [PATCH 36/67] remove tablet from old ring if its target changes Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 48 +++++++++- go/vt/vtgate/balancer/session_test.go | 122 ++++++++++++++++++++++++++ 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 7e28e514237..60ea15a9cc2 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -26,7 +26,6 @@ import ( "sync" "vitess.io/vitess/go/vt/discovery" - "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" @@ -54,17 +53,21 @@ type SessionBalancer struct { // externalRings are the hash rings created for each target. It contains only tablets // external to localCell. externalRings map[discovery.KeyspaceShardTabletType]*hashRing + + // tablets keeps track of the latest state of each tablet, keyed by tablet alias. This is + // used to check whether a tablet's target has changed, and needs to be removed from its old + // hash ring. + tablets map[string]*discovery.TabletHealth } // NewSessionBalancer creates a new session balancer. func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) (TabletBalancer, error) { - log.Info("session balancer: creating new session balancer") - b := &SessionBalancer{ localCell: localCell, hc: hc, localRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), externalRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), + tablets: make(map[string]*discovery.TabletHealth), } // Set up health check subscription @@ -140,6 +143,8 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis continue } + b.removeOldTablet(tablet) + // Ignore tablets we aren't supposed to watch if _, ok := tabletTypesToWatch[tablet.Target.TabletType]; !ok { continue @@ -150,6 +155,39 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis } } +// removeOldTablet removes the entry for a tablet in its old hash ring if its target has changed. For example, if a +// reparent happens and a replica is now a primary, we need to remove it from the replica hash ring. +func (b *SessionBalancer) removeOldTablet(tablet *discovery.TabletHealth) { + alias := tabletAlias(tablet) + prevTablet, ok := b.tablets[alias] + if !ok { + return + } + + prevTarget := prevTablet.Target + + // If this tablet's target changed, remove it from its old hash ring. + targetChanged := prevTarget.TabletType != tablet.Target.TabletType || prevTarget.Keyspace != tablet.Target.Keyspace || prevTarget.Shard != tablet.Target.Shard + if !targetChanged { + return + } + + prevKey := discovery.KeyFromTarget(prevTablet.Target) + + var ring map[discovery.KeyspaceShardTabletType]*hashRing + if tablet.Tablet.Alias.Cell == b.localCell { + ring = b.localRings + } else { + ring = b.externalRings + } + + if ring == nil || ring[prevKey] == nil { + return + } + + ring[prevKey].remove(prevTablet) +} + // onTabletHealthChange is the handler for tablet health events. If a tablet goes into serving, // it is added to the appropriate (local or external) hash ring for its target. If it goes out // of serving, it is removed from the hash ring. @@ -166,8 +204,12 @@ func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { if tablet.Serving { ring.add(tablet) + + alias := tabletAlias(tablet) + b.tablets[alias] = tablet } else { ring.remove(tablet) + delete(b.tablets, tabletAlias(tablet)) } } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 5102b39c5f8..0413e6a2c8e 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -681,6 +681,128 @@ func TestTabletTypesToWatch(t *testing.T) { } } +func TestTabletTargetChanges(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + replica := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, + } + + primary := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + Cell: "local", + }, + Serving: true, + } + + hcChan <- replica + + // Give a moment for the worker to process the tablets + time.Sleep(100 * time.Millisecond) + + require.Len(t, b.localRings, 1, b.print()) + require.Len(t, b.localRings[discovery.KeyFromTarget(replica.Target)].tablets, 1, b.print()) + + require.Len(t, b.externalRings, 0, b.print()) + + // Reparent happens, tablet is now a primary + hcChan <- primary + + // Give a moment for the worker to process the tablets + time.Sleep(100 * time.Millisecond) + + require.Len(t, b.localRings, 1, b.print()) + require.Len(t, b.localRings[discovery.KeyFromTarget(replica.Target)].tablets, 0, b.print()) + + require.Len(t, b.externalRings, 0, b.print()) +} + +func TestExternalTabletTargetChanges(t *testing.T) { + b, hcChan := newSessionBalancer(t) + + replica := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, + } + + primary := &discovery.TabletHealth{ + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_PRIMARY, + Cell: "external", + }, + Serving: true, + } + + hcChan <- replica + + // Give a moment for the worker to process the tablets + time.Sleep(100 * time.Millisecond) + + require.Len(t, b.externalRings, 1, b.print()) + require.Len(t, b.externalRings[discovery.KeyFromTarget(replica.Target)].tablets, 1, b.print()) + + require.Len(t, b.localRings, 0, b.print()) + + // Reparent happens, tablet is now a primary + hcChan <- primary + + // Give a moment for the worker to process the tablets + time.Sleep(100 * time.Millisecond) + + require.Len(t, b.externalRings, 1, b.print()) + require.Len(t, b.externalRings[discovery.KeyFromTarget(replica.Target)].tablets, 0, b.print()) + + require.Len(t, b.localRings, 0, b.print()) +} + func buildOpts(uuid string) PickOpts { return PickOpts{SessionUUID: uuid} } From fa182033907be87acc3171148be5a41925fc8b2e Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 7 Oct 2025 10:21:53 -0400 Subject: [PATCH 37/67] Switch to rendezvous hashing Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/hashring.go | 204 --------------- go/vt/vtgate/balancer/hashring_test.go | 330 ------------------------- go/vt/vtgate/balancer/session.go | 263 +++++++++++--------- go/vt/vtgate/balancer/session_test.go | 61 ++--- 4 files changed, 168 insertions(+), 690 deletions(-) delete mode 100644 go/vt/vtgate/balancer/hashring.go delete mode 100644 go/vt/vtgate/balancer/hashring_test.go diff --git a/go/vt/vtgate/balancer/hashring.go b/go/vt/vtgate/balancer/hashring.go deleted file mode 100644 index 9b5c08d57cb..00000000000 --- a/go/vt/vtgate/balancer/hashring.go +++ /dev/null @@ -1,204 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package balancer - -import ( - "slices" - "sort" - "strconv" - "sync" - - "github.com/cespare/xxhash/v2" - - "vitess.io/vitess/go/vt/discovery" - "vitess.io/vitess/go/vt/topo/topoproto" -) - -// defaultVirtualNodes is the default number of virtual nodes to use in -// the hash ring. -const defaultVirtualNodes = 16 - -// hashRing represents a hash ring of tablets. -type hashRing struct { - mu sync.RWMutex - - // nodes is the sorted list of virtual nodes. - nodes []uint64 - - // numVirtualNodes is the number of virtual nodes each member of - // the hash ring has. - numVirtualNodes int - - // nodeMap is a map from a tablet's hash to the tablet. - nodeMap map[uint64]*discovery.TabletHealth - - // tablets is a "set" of all the tablets currently in the hash ring (by alias). - tablets map[string]struct{} -} - -// newHashRing returns a new hash ring with the default number of virtual nodes. -func newHashRing() *hashRing { - return &hashRing{ - numVirtualNodes: defaultVirtualNodes, - nodeMap: make(map[uint64]*discovery.TabletHealth), - tablets: make(map[string]struct{}), - } -} - -// add adds a tablet to the hash ring. -func (r *hashRing) add(tablet *discovery.TabletHealth) { - if r.contains(tablet) { - return - } - - // Build the tablet's hashes before locking - hashes := make([]uint64, 0, r.numVirtualNodes) - for i := range r.numVirtualNodes { - hash := buildHash(tablet, i) - hashes = append(hashes, hash) - } - - r.mu.Lock() - defer r.mu.Unlock() - - for _, hash := range hashes { - r.nodes = append(r.nodes, hash) - r.nodeMap[hash] = tablet - } - - slices.Sort(r.nodes) - r.tablets[tabletAlias(tablet)] = struct{}{} -} - -// remove removes a tablet from the hash ring. -func (r *hashRing) remove(tablet *discovery.TabletHealth) { - if !r.contains(tablet) { - return - } - - // Build the tablet's hashes before locking - hashes := make(map[uint64]struct{}, r.numVirtualNodes) - for i := range r.numVirtualNodes { - hash := buildHash(tablet, i) - hashes[hash] = struct{}{} - } - - r.mu.Lock() - defer r.mu.Unlock() - - for hash := range hashes { - delete(r.nodeMap, hash) - } - - r.nodes = removeNodes(r.nodes, hashes) - delete(r.tablets, tabletAlias(tablet)) -} - -// get returns the tablet for the given key, ignoring invalid tablets. -func (r *hashRing) get(key string, invalidTablets map[string]bool) *discovery.TabletHealth { - r.mu.RLock() - defer r.mu.RUnlock() - - if len(r.nodes) == 0 { - return nil - } - - hash := xxhash.Sum64String(key) - - // Find the first node greater than or equal to this hash, and isn't invalid - i := sort.Search(len(r.nodes), func(i int) bool { - node := r.nodes[i] - - return !r.invalidNode(node, invalidTablets) && r.nodes[i] >= hash - }) - - // Wrap around if needed - if i == len(r.nodes) { - i = 0 - - // If the first tablet is invalid, it means we couldn't find any valid tablets - node := r.nodes[i] - if r.invalidNode(node, invalidTablets) { - return nil - } - } - - // Return the associated tablet - node := r.nodes[i] - tablet := r.nodeMap[node] - - return tablet -} - -// contains checks if a tablet exists in the hash ring. -func (r *hashRing) contains(tablet *discovery.TabletHealth) bool { - r.mu.RLock() - defer r.mu.RUnlock() - - _, exists := r.tablets[tabletAlias(tablet)] - return exists -} - -// sort sorts the list of nodes. -func (r *hashRing) sort() { - r.mu.Lock() - defer r.mu.Unlock() - - slices.Sort(r.nodes) -} - -// buildHash builds a virtual node hash. -func buildHash(tablet *discovery.TabletHealth, node int) uint64 { - key := tabletAlias(tablet) + "#" + strconv.Itoa(node) - hash := xxhash.Sum64String(key) - - return hash -} - -// tabletAlias returns the tablet's alias as a string. -func tabletAlias(tablet *discovery.TabletHealth) string { - return topoproto.TabletAliasString(tablet.Tablet.Alias) -} - -// removeNodes removes the nodes in the set of hashes from the given list of nodes. -func removeNodes(nodes []uint64, hashes map[uint64]struct{}) []uint64 { - // Update the node list in-place - - writeIdx := 0 - for _, node := range nodes { - // Check if this node belongs to the tablet being removed - _, isTabletNode := hashes[node] - if isTabletNode { - continue - } - - nodes[writeIdx] = node - writeIdx++ - } - - return nodes[:writeIdx] -} - -// invalidNode returns whether the virtual node is associated with an invalid tablet. -func (r *hashRing) invalidNode(node uint64, invalidTablets map[string]bool) bool { - tablet := r.nodeMap[node] - - alias := topoproto.TabletAliasString(tablet.Tablet.Alias) - _, invalid := invalidTablets[alias] - - return invalid -} diff --git a/go/vt/vtgate/balancer/hashring_test.go b/go/vt/vtgate/balancer/hashring_test.go deleted file mode 100644 index 2a3e8c43592..00000000000 --- a/go/vt/vtgate/balancer/hashring_test.go +++ /dev/null @@ -1,330 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package balancer - -import ( - "fmt" - "slices" - "strconv" - "sync" - "testing" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/vt/discovery" - querypb "vitess.io/vitess/go/vt/proto/query" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "vitess.io/vitess/go/vt/topo" -) - -func createTestTabletForHashRing(cell string, uid uint32) *discovery.TabletHealth { - tablet := topo.NewTablet(uid, cell, strconv.FormatUint(uint64(uid), 10)) - tablet.PortMap["vt"] = 1 - tablet.PortMap["grpc"] = 2 - tablet.Keyspace = "test_keyspace" - tablet.Shard = "0" - - return &discovery.TabletHealth{ - Tablet: tablet, - Target: &querypb.Target{Keyspace: "test_keyspace", Shard: "0", TabletType: topodatapb.TabletType_REPLICA}, - Serving: true, - Stats: nil, - PrimaryTermStartTime: 0, - } -} - -func TestNewHashRing(t *testing.T) { - ring := newHashRing() - - require.NotNil(t, ring) - require.Equal(t, defaultVirtualNodes, ring.numVirtualNodes) - require.NotNil(t, ring.nodeMap) - require.Empty(t, ring.nodes) - require.Empty(t, ring.nodeMap) -} - -func TestHashRingAdd(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - ring.add(tablet1) - - require.Len(t, ring.nodes, defaultVirtualNodes) - require.Len(t, ring.nodeMap, defaultVirtualNodes) - - for _, hash := range ring.nodes { - require.Equal(t, tablet1, ring.nodeMap[hash]) - } - - ring.add(tablet2) - - require.Len(t, ring.nodes, 2*defaultVirtualNodes) - require.Len(t, ring.nodeMap, 2*defaultVirtualNodes) - - tablet1Count := 0 - tablet2Count := 0 - for _, tablet := range ring.nodeMap { - switch tablet { - case tablet1: - tablet1Count++ - case tablet2: - tablet2Count++ - } - } - require.Equal(t, defaultVirtualNodes, tablet1Count) - require.Equal(t, defaultVirtualNodes, tablet2Count) -} - -func TestHashRingAddDuplicate(t *testing.T) { - ring := newHashRing() - tablet := createTestTabletForHashRing("cell1", 100) - - ring.add(tablet) - originalLen := len(ring.nodes) - - ring.add(tablet) - - require.Len(t, ring.nodes, originalLen) - require.Len(t, ring.nodeMap, originalLen) -} - -func TestHashRingRemove(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - ring.add(tablet1) - ring.add(tablet2) - - require.Len(t, ring.nodes, 2*defaultVirtualNodes) - - ring.remove(tablet1) - - require.Len(t, ring.nodes, defaultVirtualNodes) - require.Len(t, ring.nodeMap, defaultVirtualNodes) - - for _, tablet := range ring.nodeMap { - require.Equal(t, tablet2, tablet) - } -} - -func TestHashRingRemoveNonExistent(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - ring.add(tablet1) - originalLen := len(ring.nodes) - - ring.remove(tablet2) - - require.Len(t, ring.nodes, originalLen) - require.Len(t, ring.nodeMap, originalLen) -} - -func TestHashRingGet(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - result := ring.get("test_key", nil) - require.Nil(t, result) - - ring.add(tablet1) - ring.sort() - - result = ring.get("test_key", nil) - require.NotNil(t, result) - require.Equal(t, tablet1, result) - - ring.add(tablet2) - ring.sort() - - result = ring.get("test_key", nil) - require.NotNil(t, result) - - // Empirically know that "test_key" hashes closest to tablet2 - require.Equal(t, tablet2, result) -} - -func TestHashRingSort(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - ring.add(tablet1) - ring.add(tablet2) - - originalNodes := make([]uint64, len(ring.nodes)) - copy(originalNodes, ring.nodes) - - ring.sort() - - require.True(t, slices.IsSorted(ring.nodes)) - require.Equal(t, len(originalNodes), len(ring.nodes)) - - for _, node := range originalNodes { - require.Contains(t, ring.nodes, node) - } -} - -func TestBuildHash(t *testing.T) { - tablet := createTestTabletForHashRing("cell1", 100) - - hash1 := buildHash(tablet, 0) - hash2 := buildHash(tablet, 1) - hash3 := buildHash(tablet, 0) - - require.NotEqual(t, hash1, hash2) - require.Equal(t, hash1, hash3) - - tablet2 := createTestTabletForHashRing("cell2", 200) - hash4 := buildHash(tablet2, 0) - - require.NotEqual(t, hash1, hash4) -} - -func TestHashRingAddRemoveSequence(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - tablet3 := createTestTabletForHashRing("cell3", 300) - - ring.add(tablet1) - ring.add(tablet2) - ring.add(tablet3) - ring.sort() - - key := "test_sequence" - - initialTablet := ring.get(key, nil) - require.NotNil(t, initialTablet) - - ring.remove(tablet2) - ring.sort() - - afterRemovalTablet := ring.get(key, nil) - require.NotNil(t, afterRemovalTablet) - - ring.add(tablet2) - ring.sort() - - afterReaddTablet := ring.get(key, nil) - require.NotNil(t, afterReaddTablet) - - require.Equal(t, initialTablet, afterReaddTablet) -} - -func TestHashRingWrapAround(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - - // Create a synthetic scenario where we know the hash will be larger - ring.nodes = []uint64{1000, 2000, 3000} // Small values - ring.nodeMap = make(map[uint64]*discovery.TabletHealth) - ring.nodeMap[1000] = tablet1 - ring.nodeMap[2000] = tablet2 - ring.nodeMap[3000] = tablet1 - - // Any large hash should wrap around to the first node - result := ring.get("this_should_wrap_around_with_large_hash", nil) - require.NotNil(t, result) - require.Contains(t, []*discovery.TabletHealth{tablet1, tablet2}, result) -} - -func TestHashRingRemoveAllTablets(t *testing.T) { - ring := newHashRing() - tablets := make([]*discovery.TabletHealth, 3) - - for i := range 3 { - tablets[i] = createTestTabletForHashRing(fmt.Sprintf("cell%d", i), uint32(100+i)) - ring.add(tablets[i]) - } - - ring.sort() - - for _, tablet := range tablets { - ring.remove(tablet) - } - - require.Empty(t, ring.nodes) - require.Empty(t, ring.nodeMap) - require.Nil(t, ring.get("any_key", nil)) -} - -func TestHashRingMultipleAddSameTablet(t *testing.T) { - ring := newHashRing() - tablet := createTestTabletForHashRing("cell1", 100) - - // Add the same tablet multiple times - for range 5 { - ring.add(tablet) - } - - // Should still only have defaultVirtualNodes entries - require.Len(t, ring.nodes, defaultVirtualNodes) - require.Len(t, ring.nodeMap, defaultVirtualNodes) -} - -func TestHashRingGetAfterRemove(t *testing.T) { - ring := newHashRing() - tablet1 := createTestTabletForHashRing("cell1", 100) - tablet2 := createTestTabletForHashRing("cell2", 200) - tablet3 := createTestTabletForHashRing("cell3", 300) - - ring.add(tablet1) - ring.add(tablet2) - ring.add(tablet3) - ring.sort() - - // Empirically know that this hashes closest to tablet3 - got := ring.get("key", nil) - require.Equal(t, tablet3, got) - - // Remove tablet3 - ring.remove(tablet3) - - got = ring.get("key", nil) - require.NotEqual(t, tablet3, got) -} - -func TestHashRingConcurrentGetOperations(t *testing.T) { - ring := newHashRing() - tablets := make([]*discovery.TabletHealth, 5) - - for i := range 5 { - tablets[i] = createTestTabletForHashRing(fmt.Sprintf("cell%d", i), uint32(100+i)) - ring.add(tablets[i]) - } - ring.sort() - - var wg sync.WaitGroup - numGoroutines := 1000 - wg.Add(numGoroutines) - for i := range numGoroutines { - go func(i int) { - defer wg.Done() - key := fmt.Sprintf("concurrent_key_%d", i) - tablet := ring.get(key, nil) - require.NotNil(t, tablet) - }(i) - } - - wg.Wait() -} diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 60ea15a9cc2..181c4f9a6b8 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -25,10 +25,13 @@ import ( "strings" "sync" + "github.com/cespare/xxhash/v2" + "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/srvtopo" + "vitess.io/vitess/go/vt/topo/topoproto" ) // tabletTypesToWatch are the tablet types that will be included in the hash rings. @@ -46,28 +49,29 @@ type SessionBalancer struct { mu sync.RWMutex - // localRings are the hash rings created for each target. It contains only tablets - // local to localCell. - localRings map[discovery.KeyspaceShardTabletType]*hashRing + // localTablets is a map of tablets in the local cell for each target. + localTablets map[discovery.KeyspaceShardTabletType]TabletSet - // externalRings are the hash rings created for each target. It contains only tablets - // external to localCell. - externalRings map[discovery.KeyspaceShardTabletType]*hashRing + // externalTablets is a map of tablets external to this cell for each target. + externalTablets map[discovery.KeyspaceShardTabletType]TabletSet - // tablets keeps track of the latest state of each tablet, keyed by tablet alias. This is - // used to check whether a tablet's target has changed, and needs to be removed from its old - // hash ring. - tablets map[string]*discovery.TabletHealth + // tablets keeps track of the latest state of each tablet, keyed by alias. This + // is used to remove tablets from old targets when their target changes (a + // PlannedReparentShard for example). + tablets TabletSet } +// TabletSet represents a set of tablets, keyed by alias. +type TabletSet map[string]*discovery.TabletHealth + // NewSessionBalancer creates a new session balancer. func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) (TabletBalancer, error) { b := &SessionBalancer{ - localCell: localCell, - hc: hc, - localRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), - externalRings: make(map[discovery.KeyspaceShardTabletType]*hashRing), - tablets: make(map[string]*discovery.TabletHealth), + localCell: localCell, + hc: hc, + localTablets: make(map[discovery.KeyspaceShardTabletType]TabletSet), + externalTablets: make(map[discovery.KeyspaceShardTabletType]TabletSet), + tablets: make(TabletSet), } // Set up health check subscription @@ -95,43 +99,7 @@ func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtop return b, nil } -// Pick is the main entry point to the balancer. -// -// For a given session, it will return the same tablet for its duration, with preference to tablets -// in the local cell. -func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { - if opts.SessionUUID == "" { - return nil - } - - b.mu.RLock() - defer b.mu.RUnlock() - - // Try to find a tablet in the local cell first - tablet := getFromRing(b.localRings, target, opts.InvalidTablets, opts.SessionUUID) - if tablet != nil { - return tablet - } - - // If we didn't find a tablet in the local cell, try external cells - tablet = getFromRing(b.externalRings, target, opts.InvalidTablets, opts.SessionUUID) - return tablet -} - -// DebugHandler provides a summary of the session balancer state. -func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - fmt.Fprintf(w, "Session balancer\n") - fmt.Fprintf(w, "================\n") - fmt.Fprintf(w, "Local cell: %s\n\n", b.localCell) - - b.mu.RLock() - defer b.mu.RUnlock() - - fmt.Fprint(w, b.print()) -} - -// watchHealthCheck watches the health check channel for tablet health changes, and updates hash rings accordingly. +// watchHealthCheck watches the health check channel for tablet health changes and updates the set of tablets accordingly. func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { for { select { @@ -143,6 +111,7 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis continue } + // Remove tablet from old target if it has changed b.removeOldTablet(tablet) // Ignore tablets we aren't supposed to watch @@ -155,114 +124,170 @@ func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *dis } } -// removeOldTablet removes the entry for a tablet in its old hash ring if its target has changed. For example, if a -// reparent happens and a replica is now a primary, we need to remove it from the replica hash ring. -func (b *SessionBalancer) removeOldTablet(tablet *discovery.TabletHealth) { - alias := tabletAlias(tablet) - prevTablet, ok := b.tablets[alias] - if !ok { - return - } - - prevTarget := prevTablet.Target +// removeOldTablet removes the entry for a tablet in its old target if its target has changed. For example, if a +// reparent happens and a replica is now a primary, we need to remove it from the list of tablets for the replica +// target. +func (b *SessionBalancer) removeOldTablet(newTablet *discovery.TabletHealth) { + b.mu.Lock() + defer b.mu.Unlock() - // If this tablet's target changed, remove it from its old hash ring. - targetChanged := prevTarget.TabletType != tablet.Target.TabletType || prevTarget.Keyspace != tablet.Target.Keyspace || prevTarget.Shard != tablet.Target.Shard - if !targetChanged { + alias := tabletAlias(newTablet) + prevTablet, exists := b.tablets[alias] + if !exists { return } - prevKey := discovery.KeyFromTarget(prevTablet.Target) + prevTargetKey := discovery.KeyFromTarget(prevTablet.Target) + currentTargetKey := discovery.KeyFromTarget(newTablet.Target) - var ring map[discovery.KeyspaceShardTabletType]*hashRing - if tablet.Tablet.Alias.Cell == b.localCell { - ring = b.localRings - } else { - ring = b.externalRings - } - - if ring == nil || ring[prevKey] == nil { + // If this tablet's target changed, remove it from its old target. + if prevTargetKey == currentTargetKey { return } - ring[prevKey].remove(prevTablet) + b.removeTablet(b.localTablets, prevTargetKey, prevTablet) + b.removeTablet(b.externalTablets, prevTargetKey, prevTablet) } // onTabletHealthChange is the handler for tablet health events. If a tablet goes into serving, // it is added to the appropriate (local or external) hash ring for its target. If it goes out // of serving, it is removed from the hash ring. -func (b *SessionBalancer) onTabletHealthChange(tablet *discovery.TabletHealth) { +func (b *SessionBalancer) onTabletHealthChange(newTablet *discovery.TabletHealth) { b.mu.Lock() defer b.mu.Unlock() - var ring *hashRing - if tablet.Tablet.Alias.Cell == b.localCell { - ring = getOrCreateRing(b.localRings, tablet) + var tablets map[discovery.KeyspaceShardTabletType]TabletSet + if newTablet.Tablet.Alias.Cell == b.localCell { + tablets = b.localTablets } else { - ring = getOrCreateRing(b.externalRings, tablet) + tablets = b.externalTablets } - if tablet.Serving { - ring.add(tablet) + alias := tabletAlias(newTablet) + targetKey := discovery.KeyFromTarget(newTablet.Target) - alias := tabletAlias(tablet) - b.tablets[alias] = tablet + if newTablet.Serving { + b.addTablet(tablets, targetKey, newTablet) + b.tablets[alias] = newTablet } else { - ring.remove(tablet) - delete(b.tablets, tabletAlias(tablet)) + b.removeTablet(tablets, targetKey, newTablet) + delete(b.tablets, alias) } } -// getOrCreateRing gets or creates a new ring for the given tablet. -func getOrCreateRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, tablet *discovery.TabletHealth) *hashRing { - key := discovery.KeyFromTarget(tablet.Target) +// addTablet adds a tablet to the target in the given list of tablets. +func (b *SessionBalancer) addTablet(tablets map[discovery.KeyspaceShardTabletType]TabletSet, targetKey discovery.KeyspaceShardTabletType, tablet *discovery.TabletHealth) { + target, ok := tablets[targetKey] + if !ok { + // Create the set if one has not been created for this target yet + tablets[targetKey] = make(TabletSet) + target = tablets[targetKey] + } - ring, exists := rings[key] - if !exists { - ring = newHashRing() - rings[key] = ring + alias := tabletAlias(tablet) + target[alias] = tablet +} + +// removeTablet removes the tablet from the target in the given list of tablets. +func (b *SessionBalancer) removeTablet(tablets map[discovery.KeyspaceShardTabletType]TabletSet, targetKey discovery.KeyspaceShardTabletType, tablet *discovery.TabletHealth) { + alias := tabletAlias(tablet) + delete(tablets[targetKey], alias) +} + +// Pick is the main entry point to the balancer. +// +// For a given session, it will return the same tablet for its duration, with preference to tablets +// in the local cell. +func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { + if opts.SessionUUID == "" { + return nil + } + + b.mu.RLock() + defer b.mu.RUnlock() + + targetKey := discovery.KeyFromTarget(target) + + // Try to find a tablet in the local cell first + tablet := pick(b.localTablets[targetKey], opts) + if tablet != nil { + return tablet + } + + // If we didn't find a tablet in the local cell, try external cells + tablet = pick(b.externalTablets[targetKey], opts) + return tablet +} + +// pick picks the highest weight valid tablet from the set of tablets. +func pick(tablets TabletSet, opts PickOpts) *discovery.TabletHealth { + var maxWeight uint64 + var maxTablet *discovery.TabletHealth + + for alias, tablet := range tablets { + invalid := opts.InvalidTablets[alias] + if invalid { + continue + } + + weight := weight(alias, opts.SessionUUID) + if tablet == nil || weight > maxWeight { + maxWeight = weight + maxTablet = tablet + } } - return ring + return maxTablet +} + +// weight computes the weight of a tablet by hashing its alias and the session UUID together. +func weight(alias string, sessionUUID string) uint64 { + return xxhash.Sum64String(alias + "#" + sessionUUID) +} + +// tabletAlias returns the tablet's alias as a string. +func tabletAlias(tablet *discovery.TabletHealth) string { + return topoproto.TabletAliasString(tablet.Tablet.Alias) +} + +// DebugHandler provides a summary of the session balancer state. +func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprint(w, b.print()) } // print returns a string representation of the session balancer state for debugging. func (b *SessionBalancer) print() string { + b.mu.RLock() + defer b.mu.RUnlock() + sb := strings.Builder{} - sb.WriteString("Local rings:\n") - if len(b.localRings) == 0 { - sb.WriteString("\tNo local rings\n") - } + sb.WriteString("Session balancer\n") + sb.WriteString("================\n") + sb.WriteString(fmt.Sprintf("Local cell: %s\n\n", b.localCell)) - for target, ring := range b.localRings { - sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) - sb.WriteString(fmt.Sprintf("\t\tNode count: %d\n", len(ring.nodes))) - sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(ring.tablets)))) - } + sb.WriteString("Local tablets:\n") - sb.WriteString("External rings:\n") - if len(b.externalRings) == 0 { - sb.WriteString("\tNo external rings\n") - } + for target, tablets := range b.localTablets { + if len(tablets) == 0 { + continue + } - for target, ring := range b.externalRings { sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) - sb.WriteString(fmt.Sprintf("\t\tNode count: %d\n", len(ring.nodes))) - sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(ring.tablets)))) + sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(tablets)))) } - return sb.String() -} + sb.WriteString("External tablets:\n") -// getFromRing gets a tablet from the respective ring for the given target and session hash. -func getFromRing(rings map[discovery.KeyspaceShardTabletType]*hashRing, target *querypb.Target, invalidTablets map[string]bool, sessionUUID string) *discovery.TabletHealth { - key := discovery.KeyFromTarget(target) + for target, tablets := range b.externalTablets { + if len(tablets) == 0 { + continue + } - ring, exists := rings[key] - if !exists { - return nil + sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) + sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(tablets)))) } - return ring.get(sessionUUID, invalidTablets) + return sb.String() } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 0413e6a2c8e..825cdc78dc5 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -18,8 +18,6 @@ package balancer import ( "fmt" - "net/http" - "net/http/httptest" "testing" "time" @@ -48,10 +46,9 @@ func TestNewSessionBalancer(t *testing.T) { require.Equal(t, "local", b.localCell) require.NotNil(t, b.hc) - require.NotNil(t, b.localRings) - require.Len(t, b.localRings, 0) - require.NotNil(t, b.externalRings) - require.Len(t, b.externalRings, 0) + require.NotNil(t, b.localTablets) + require.NotNil(t, b.externalTablets) + require.NotNil(t, b.tablets) } func TestPickNoTablets(t *testing.T) { @@ -121,7 +118,7 @@ func TestPickLocalOnly(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(10 * time.Millisecond) - // Pick for a specific session hash + // Pick for a specific session UUID opts := buildOpts("a") picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) @@ -131,10 +128,10 @@ func TestPickLocalOnly(t *testing.T) { require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 - opts = buildOpts("c") + opts = buildOpts("b") picked3 := b.Pick(target, nil, opts) require.NotNil(t, picked3) - require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked3))) + require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s for both", tabletAlias(picked3))) } func TestPickPreferLocal(t *testing.T) { @@ -372,7 +369,7 @@ func TestNewLocalTablet(t *testing.T) { time.Sleep(10 * time.Millisecond) - opts := buildOpts("a") + opts := buildOpts("b") picked1 := b.Pick(target, nil, opts) require.NotNil(t, picked1) @@ -465,16 +462,6 @@ func TestNewExternalTablet(t *testing.T) { require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) } -func TestDebugHandler(t *testing.T) { - b, _ := newSessionBalancer(t) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/debug", nil) - - b.DebugHandler(w, r) - require.Equal(t, http.StatusOK, w.Code) -} - func TestPickNoOpts(t *testing.T) { b, hcChan := newSessionBalancer(t) @@ -671,17 +658,17 @@ func TestTabletTypesToWatch(t *testing.T) { b.mu.RLock() defer b.mu.RUnlock() - require.Len(t, b.localRings, 2) - require.Len(t, b.externalRings, 0) + require.Len(t, b.localTablets, 2) + require.Len(t, b.externalTablets, 0) - for _, ring := range b.localRings { - for _, tablet := range ring.nodeMap { + for _, target := range b.localTablets { + for _, tablet := range target { require.Contains(t, tabletTypesToWatch, tablet.Target.TabletType) } } } -func TestTabletTargetChanges(t *testing.T) { +func TestLocalTabletTargetChanges(t *testing.T) { b, hcChan := newSessionBalancer(t) replica := &discovery.TabletHealth{ @@ -725,10 +712,10 @@ func TestTabletTargetChanges(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(100 * time.Millisecond) - require.Len(t, b.localRings, 1, b.print()) - require.Len(t, b.localRings[discovery.KeyFromTarget(replica.Target)].tablets, 1, b.print()) + require.Len(t, b.localTablets, 1, b.print()) + require.Len(t, b.localTablets[discovery.KeyFromTarget(replica.Target)], 1, b.print()) - require.Len(t, b.externalRings, 0, b.print()) + require.Len(t, b.externalTablets, 0, b.print()) // Reparent happens, tablet is now a primary hcChan <- primary @@ -736,10 +723,10 @@ func TestTabletTargetChanges(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(100 * time.Millisecond) - require.Len(t, b.localRings, 1, b.print()) - require.Len(t, b.localRings[discovery.KeyFromTarget(replica.Target)].tablets, 0, b.print()) + require.Len(t, b.localTablets, 1, b.print()) + require.Len(t, b.localTablets[discovery.KeyFromTarget(replica.Target)], 0, b.print()) - require.Len(t, b.externalRings, 0, b.print()) + require.Len(t, b.externalTablets, 0, b.print()) } func TestExternalTabletTargetChanges(t *testing.T) { @@ -786,10 +773,10 @@ func TestExternalTabletTargetChanges(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(100 * time.Millisecond) - require.Len(t, b.externalRings, 1, b.print()) - require.Len(t, b.externalRings[discovery.KeyFromTarget(replica.Target)].tablets, 1, b.print()) + require.Len(t, b.externalTablets, 1, b.print()) + require.Len(t, b.externalTablets[discovery.KeyFromTarget(replica.Target)], 1, b.print()) - require.Len(t, b.localRings, 0, b.print()) + require.Len(t, b.localTablets, 0, b.print()) // Reparent happens, tablet is now a primary hcChan <- primary @@ -797,10 +784,10 @@ func TestExternalTabletTargetChanges(t *testing.T) { // Give a moment for the worker to process the tablets time.Sleep(100 * time.Millisecond) - require.Len(t, b.externalRings, 1, b.print()) - require.Len(t, b.externalRings[discovery.KeyFromTarget(replica.Target)].tablets, 0, b.print()) + require.Len(t, b.externalTablets, 1, b.print()) + require.Len(t, b.externalTablets[discovery.KeyFromTarget(replica.Target)], 0, b.print()) - require.Len(t, b.localRings, 0, b.print()) + require.Len(t, b.localTablets, 0, b.print()) } func buildOpts(uuid string) PickOpts { From b4ecebcac3d38cd0f896f339c42467aa9cbb4468 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 7 Oct 2025 18:12:02 -0400 Subject: [PATCH 38/67] Remove health check Signed-off-by: Mohamed Hamza --- go/vt/vtexplain/vtexplain_vtgate.go | 2 +- go/vt/vtgate/balancer/session.go | 254 +------- go/vt/vtgate/balancer/session_test.go | 798 +++++------------------ go/vt/vtgate/legacy_scatter_conn_test.go | 2 +- go/vt/vtgate/tabletgateway.go | 24 +- go/vt/vtgate/tabletgateway_flaky_test.go | 6 +- go/vt/vtgate/tabletgateway_test.go | 10 +- go/vt/vtgate/vstream_manager_test.go | 4 +- go/vt/vtgate/vtgate.go | 6 +- 9 files changed, 227 insertions(+), 879 deletions(-) diff --git a/go/vt/vtexplain/vtexplain_vtgate.go b/go/vt/vtexplain/vtexplain_vtgate.go index 28bcaa6ea6d..7c838da4330 100644 --- a/go/vt/vtexplain/vtexplain_vtgate.go +++ b/go/vt/vtexplain/vtexplain_vtgate.go @@ -86,7 +86,7 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, ts *topo.Server, v } func (vte *VTExplain) newFakeResolver(ctx context.Context, opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver { - gw, _ := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell) + gw := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell) _ = gw.WaitForTablets(ctx, []topodatapb.TabletType{topodatapb.TabletType_REPLICA}) txMode := vtgatepb.TransactionMode_MULTI diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 181c4f9a6b8..1f5083c84b7 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -17,227 +17,72 @@ limitations under the License. package balancer import ( - "context" "fmt" - "maps" "net/http" - "slices" - "strings" - "sync" "github.com/cespare/xxhash/v2" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/proto/topodata" - "vitess.io/vitess/go/vt/srvtopo" "vitess.io/vitess/go/vt/topo/topoproto" ) -// tabletTypesToWatch are the tablet types that will be included in the hash rings. -var tabletTypesToWatch = map[topodata.TabletType]struct{}{topodata.TabletType_REPLICA: {}, topodata.TabletType_RDONLY: {}} - // SessionBalancer implements the TabletBalancer interface. For a given session, // it will return the same tablet for its duration, with preference to tablets in // the local cell. type SessionBalancer struct { // localCell is the cell the gateway is currently running in. localCell string - - // hc is the tablet health check. - hc discovery.HealthCheck - - mu sync.RWMutex - - // localTablets is a map of tablets in the local cell for each target. - localTablets map[discovery.KeyspaceShardTabletType]TabletSet - - // externalTablets is a map of tablets external to this cell for each target. - externalTablets map[discovery.KeyspaceShardTabletType]TabletSet - - // tablets keeps track of the latest state of each tablet, keyed by alias. This - // is used to remove tablets from old targets when their target changes (a - // PlannedReparentShard for example). - tablets TabletSet } -// TabletSet represents a set of tablets, keyed by alias. -type TabletSet map[string]*discovery.TabletHealth - // NewSessionBalancer creates a new session balancer. -func NewSessionBalancer(ctx context.Context, localCell string, topoServer srvtopo.Server, hc discovery.HealthCheck) (TabletBalancer, error) { - b := &SessionBalancer{ - localCell: localCell, - hc: hc, - localTablets: make(map[discovery.KeyspaceShardTabletType]TabletSet), - externalTablets: make(map[discovery.KeyspaceShardTabletType]TabletSet), - tablets: make(TabletSet), - } - - // Set up health check subscription - hcChan := b.hc.Subscribe("SessionBalancer") - - // Build initial hash rings - - // Find all the targets we're watching - targets, _, err := srvtopo.FindAllTargetsAndKeyspaces(ctx, topoServer, b.localCell, discovery.KeyspacesToWatch, slices.Collect(maps.Keys(tabletTypesToWatch))) - if err != nil { - return nil, fmt.Errorf("session balancer: failed to find all targets and keyspaces: %w", err) - } - - // Add each tablet to the hash ring - for _, target := range targets { - tablets := b.hc.GetHealthyTabletStats(target) - for _, tablet := range tablets { - b.onTabletHealthChange(tablet) - } - } - - // Start watcher to keep track of tablet health - go b.watchHealthCheck(ctx, hcChan) - - return b, nil -} - -// watchHealthCheck watches the health check channel for tablet health changes and updates the set of tablets accordingly. -func (b *SessionBalancer) watchHealthCheck(ctx context.Context, hcChan chan *discovery.TabletHealth) { - for { - select { - case <-ctx.Done(): - b.hc.Unsubscribe(hcChan) - return - case tablet := <-hcChan: - if tablet == nil { - continue - } - - // Remove tablet from old target if it has changed - b.removeOldTablet(tablet) - - // Ignore tablets we aren't supposed to watch - if _, ok := tabletTypesToWatch[tablet.Target.TabletType]; !ok { - continue - } - - b.onTabletHealthChange(tablet) - } - } -} - -// removeOldTablet removes the entry for a tablet in its old target if its target has changed. For example, if a -// reparent happens and a replica is now a primary, we need to remove it from the list of tablets for the replica -// target. -func (b *SessionBalancer) removeOldTablet(newTablet *discovery.TabletHealth) { - b.mu.Lock() - defer b.mu.Unlock() - - alias := tabletAlias(newTablet) - prevTablet, exists := b.tablets[alias] - if !exists { - return - } - - prevTargetKey := discovery.KeyFromTarget(prevTablet.Target) - currentTargetKey := discovery.KeyFromTarget(newTablet.Target) - - // If this tablet's target changed, remove it from its old target. - if prevTargetKey == currentTargetKey { - return - } - - b.removeTablet(b.localTablets, prevTargetKey, prevTablet) - b.removeTablet(b.externalTablets, prevTargetKey, prevTablet) -} - -// onTabletHealthChange is the handler for tablet health events. If a tablet goes into serving, -// it is added to the appropriate (local or external) hash ring for its target. If it goes out -// of serving, it is removed from the hash ring. -func (b *SessionBalancer) onTabletHealthChange(newTablet *discovery.TabletHealth) { - b.mu.Lock() - defer b.mu.Unlock() - - var tablets map[discovery.KeyspaceShardTabletType]TabletSet - if newTablet.Tablet.Alias.Cell == b.localCell { - tablets = b.localTablets - } else { - tablets = b.externalTablets - } - - alias := tabletAlias(newTablet) - targetKey := discovery.KeyFromTarget(newTablet.Target) - - if newTablet.Serving { - b.addTablet(tablets, targetKey, newTablet) - b.tablets[alias] = newTablet - } else { - b.removeTablet(tablets, targetKey, newTablet) - delete(b.tablets, alias) - } -} - -// addTablet adds a tablet to the target in the given list of tablets. -func (b *SessionBalancer) addTablet(tablets map[discovery.KeyspaceShardTabletType]TabletSet, targetKey discovery.KeyspaceShardTabletType, tablet *discovery.TabletHealth) { - target, ok := tablets[targetKey] - if !ok { - // Create the set if one has not been created for this target yet - tablets[targetKey] = make(TabletSet) - target = tablets[targetKey] - } - - alias := tabletAlias(tablet) - target[alias] = tablet -} - -// removeTablet removes the tablet from the target in the given list of tablets. -func (b *SessionBalancer) removeTablet(tablets map[discovery.KeyspaceShardTabletType]TabletSet, targetKey discovery.KeyspaceShardTabletType, tablet *discovery.TabletHealth) { - alias := tabletAlias(tablet) - delete(tablets[targetKey], alias) +func NewSessionBalancer(localCell string) TabletBalancer { + return &SessionBalancer{localCell: localCell} } // Pick is the main entry point to the balancer. // // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. -func (b *SessionBalancer) Pick(target *querypb.Target, _ []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { +func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { if opts.SessionUUID == "" { return nil } - b.mu.RLock() - defer b.mu.RUnlock() + // Find the highest weight local and external tablets + var maxLocalWeight, maxExternalWeight uint64 + var maxLocalTablet, maxExternalTablet *discovery.TabletHealth - targetKey := discovery.KeyFromTarget(target) - - // Try to find a tablet in the local cell first - tablet := pick(b.localTablets[targetKey], opts) - if tablet != nil { - return tablet - } + for _, tablet := range tablets { + alias := tabletAlias(tablet) - // If we didn't find a tablet in the local cell, try external cells - tablet = pick(b.externalTablets[targetKey], opts) - return tablet -} - -// pick picks the highest weight valid tablet from the set of tablets. -func pick(tablets TabletSet, opts PickOpts) *discovery.TabletHealth { - var maxWeight uint64 - var maxTablet *discovery.TabletHealth - - for alias, tablet := range tablets { - invalid := opts.InvalidTablets[alias] - if invalid { + // Ignore invalid tablets + if _, invalid := opts.InvalidTablets[alias]; invalid { continue } weight := weight(alias, opts.SessionUUID) - if tablet == nil || weight > maxWeight { - maxWeight = weight - maxTablet = tablet + + if b.isLocal(tablet) && ((maxLocalTablet == nil) || (weight > maxLocalWeight)) { + maxLocalWeight = weight + maxLocalTablet = tablet } + + // We can consider all tablets here since we'd only use this if there were no + // valid local tablets (meaning we'd have only considered external tablets anyway). + if (maxExternalTablet == nil) || (weight > maxExternalWeight) { + maxExternalWeight = weight + maxExternalTablet = tablet + } + } + + // If we found a valid local tablet, use that + if maxLocalTablet != nil { + return maxLocalTablet } - return maxTablet + // Otherwise, use the max external tablet (if it exists) + return maxExternalTablet } // weight computes the weight of a tablet by hashing its alias and the session UUID together. @@ -250,44 +95,13 @@ func tabletAlias(tablet *discovery.TabletHealth) string { return topoproto.TabletAliasString(tablet.Tablet.Alias) } +// isLocal returns true if the tablet is in the local cell. +func (b *SessionBalancer) isLocal(tablet *discovery.TabletHealth) bool { + return tablet.Tablet.Alias.Cell == b.localCell +} + // DebugHandler provides a summary of the session balancer state. func (b *SessionBalancer) DebugHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - fmt.Fprint(w, b.print()) -} - -// print returns a string representation of the session balancer state for debugging. -func (b *SessionBalancer) print() string { - b.mu.RLock() - defer b.mu.RUnlock() - - sb := strings.Builder{} - - sb.WriteString("Session balancer\n") - sb.WriteString("================\n") - sb.WriteString(fmt.Sprintf("Local cell: %s\n\n", b.localCell)) - - sb.WriteString("Local tablets:\n") - - for target, tablets := range b.localTablets { - if len(tablets) == 0 { - continue - } - - sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) - sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(tablets)))) - } - - sb.WriteString("External tablets:\n") - - for target, tablets := range b.externalTablets { - if len(tablets) == 0 { - continue - } - - sb.WriteString(fmt.Sprintf("\t - Target: %s\n", target)) - sb.WriteString(fmt.Sprintf("\t\tTablets: %+v\n", slices.Collect(maps.Keys(tablets)))) - } - - return sb.String() + fmt.Fprintf(w, "Local cell: %s", b.localCell) } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 825cdc78dc5..0152b65759a 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -19,40 +19,26 @@ package balancer import ( "fmt" "testing" - "time" "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "vitess.io/vitess/go/vt/srvtopo/fakesrvtopo" "vitess.io/vitess/go/vt/topo/topoproto" ) -func newSessionBalancer(t *testing.T) (*SessionBalancer, chan *discovery.TabletHealth) { - ctx := t.Context() +func newSessionBalancer(t *testing.T) *SessionBalancer { + t.Helper() - ch := make(chan *discovery.TabletHealth, 10) - hc := discovery.NewFakeHealthCheck(ch) - b, _ := NewSessionBalancer(ctx, "local", &fakesrvtopo.FakeSrvTopo{}, hc) + b := NewSessionBalancer("local") sb := b.(*SessionBalancer) - return sb, ch -} - -func TestNewSessionBalancer(t *testing.T) { - b, _ := newSessionBalancer(t) - - require.Equal(t, "local", b.localCell) - require.NotNil(t, b.hc) - require.NotNil(t, b.localTablets) - require.NotNil(t, b.externalTablets) - require.NotNil(t, b.tablets) + return sb } func TestPickNoTablets(t *testing.T) { - b, _ := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -67,7 +53,7 @@ func TestPickNoTablets(t *testing.T) { } func TestPickLocalOnly(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -76,66 +62,62 @@ func TestPickLocalOnly(t *testing.T) { Cell: "local", } - localTablet1 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, - } - localTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, } - hcChan <- localTablet1 - hcChan <- localTablet2 - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - // Pick for a specific session UUID opts := buildOpts("a") - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, tablets, opts) require.NotNil(t, picked1) // Pick again with same session hash, should return same tablet - picked2 := b.Pick(target, nil, opts) + picked2 := b.Pick(target, tablets, opts) require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 opts = buildOpts("b") - picked3 := b.Pick(target, nil, opts) + picked3 := b.Pick(target, tablets, opts) require.NotNil(t, picked3) require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s for both", tabletAlias(picked3))) } func TestPickPreferLocal(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -144,76 +126,71 @@ func TestPickPreferLocal(t *testing.T) { Cell: "local", } - localTablet1 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, - } - localTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, - } - externalTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 200, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", + }, + Serving: true, }, - Serving: true, } - hcChan <- localTablet1 - hcChan <- localTablet2 - hcChan <- externalTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - // Pick should prefer local cell opts := buildOpts("a") - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, tablets, opts) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) } func TestPickNoLocal(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -222,248 +199,52 @@ func TestPickNoLocal(t *testing.T) { Cell: "local", } - externalTablet1 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 200, + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", - }, - Serving: true, - } - - externalTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 201, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", - }, - Serving: true, - } - - hcChan <- externalTablet1 - hcChan <- externalTablet2 - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - - // Pick should return external cell since there are no local cells - opts := buildOpts("a") - picked1 := b.Pick(target, nil, opts) - require.NotNil(t, picked1) - require.Equal(t, "external", picked1.Target.Cell) -} - -func TestTabletNotServing(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - target := &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - } - - localTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, + Serving: true, + }, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 201, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - externalTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 200, + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "external", }, - Keyspace: "keyspace", - Shard: "0", + Serving: true, }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", - }, - Serving: true, - } - - hcChan <- localTablet - hcChan <- externalTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - - opts := buildOpts("a") - picked1 := b.Pick(target, nil, opts) - require.NotNil(t, picked1) - - // Local tablet goes out of serving - localTablet.Serving = false - hcChan <- localTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - - // Should not pick the local tablet again - picked2 := b.Pick(target, nil, opts) - require.NotEqual(t, picked1, picked2) -} - -func TestNewLocalTablet(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - target := &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - } - - localTablet1 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - hcChan <- localTablet1 - - time.Sleep(10 * time.Millisecond) - - opts := buildOpts("b") - picked1 := b.Pick(target, nil, opts) - require.NotNil(t, picked1) - - localTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, } - hcChan <- localTablet2 - - time.Sleep(10 * time.Millisecond) - - picked2 := b.Pick(target, nil, opts) - require.NotNil(t, picked2) - require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) -} - -func TestNewExternalTablet(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - target := &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - } - - externalTablet1 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - hcChan <- externalTablet1 - - time.Sleep(10 * time.Millisecond) - + // Pick should return external cell since there are no local cells opts := buildOpts("a") - picked1 := b.Pick(target, nil, opts) + picked1 := b.Pick(target, tablets, opts) require.NotNil(t, picked1) - - externalTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - hcChan <- externalTablet2 - - time.Sleep(10 * time.Millisecond) - - picked2 := b.Pick(target, nil, opts) - require.NotNil(t, picked2) - require.NotEqual(t, picked1, picked2, fmt.Sprintf("expected different tablets, got %s", tabletAlias(picked2))) + require.Equal(t, "external", picked1.Target.Cell) } func TestPickNoOpts(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -472,36 +253,33 @@ func TestPickNoOpts(t *testing.T) { Cell: "local", } - localTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, } - hcChan <- localTablet - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - // Test with empty opts - result := b.Pick(target, nil, PickOpts{}) + result := b.Pick(target, tablets, PickOpts{}) require.Nil(t, result) } func TestPickInvalidTablets(t *testing.T) { - b, hcChan := newSessionBalancer(t) + b := newSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -510,286 +288,60 @@ func TestPickInvalidTablets(t *testing.T) { Cell: "local", } - localTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, - } - localTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", + Target: &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + }, + Serving: true, }, - Serving: true, } - hcChan <- localTablet - hcChan <- localTablet2 - - // Give a moment for the worker to process the tablets - time.Sleep(10 * time.Millisecond) - // Get a tablet regularly opts := buildOpts("a") - tablet := b.Pick(target, nil, opts) + tablet := b.Pick(target, tablets, opts) require.NotNil(t, tablet) // Mark returned tablet as invalid, should return other tablet opts.InvalidTablets = map[string]bool{topoproto.TabletAliasString(tablet.Tablet.Alias): true} - tablet2 := b.Pick(target, nil, opts) + tablet2 := b.Pick(target, tablets, opts) require.NotEqual(t, tablet, tablet2) // Mark both as invalid, should return nil opts.InvalidTablets[topoproto.TabletAliasString(tablet2.Tablet.Alias)] = true - tablet3 := b.Pick(target, nil, opts) + tablet3 := b.Pick(target, tablets, opts) require.Nil(t, tablet3) } -func TestTabletTypesToWatch(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - // Valid tablet type - localTablet := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - // Valid tablet type - localTablet2 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_RDONLY, - Cell: "local", - }, - Serving: true, - } - - // Invalid tablet type - localTablet3 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, - Cell: "local", - }, - Serving: true, - } - - // Invalid tablet type - localTablet4 := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_BACKUP, - Cell: "local", - }, - Serving: true, - } - - hcChan <- localTablet - hcChan <- localTablet2 - hcChan <- localTablet3 - hcChan <- localTablet4 - - // Give a moment for the worker to process the tablets - time.Sleep(100 * time.Millisecond) - - b.mu.RLock() - defer b.mu.RUnlock() - - require.Len(t, b.localTablets, 2) - require.Len(t, b.externalTablets, 0) - - for _, target := range b.localTablets { - for _, tablet := range target { - require.Contains(t, tabletTypesToWatch, tablet.Target.TabletType) - } - } -} - -func TestLocalTabletTargetChanges(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - replica := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - }, - Serving: true, - } - - primary := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, - Cell: "local", - }, - Serving: true, - } - - hcChan <- replica - - // Give a moment for the worker to process the tablets - time.Sleep(100 * time.Millisecond) - - require.Len(t, b.localTablets, 1, b.print()) - require.Len(t, b.localTablets[discovery.KeyFromTarget(replica.Target)], 1, b.print()) - - require.Len(t, b.externalTablets, 0, b.print()) - - // Reparent happens, tablet is now a primary - hcChan <- primary - - // Give a moment for the worker to process the tablets - time.Sleep(100 * time.Millisecond) - - require.Len(t, b.localTablets, 1, b.print()) - require.Len(t, b.localTablets[discovery.KeyFromTarget(replica.Target)], 0, b.print()) - - require.Len(t, b.externalTablets, 0, b.print()) -} - -func TestExternalTabletTargetChanges(t *testing.T) { - b, hcChan := newSessionBalancer(t) - - replica := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "external", - }, - Serving: true, - } - - primary := &discovery.TabletHealth{ - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - Target: &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_PRIMARY, - Cell: "external", - }, - Serving: true, - } - - hcChan <- replica - - // Give a moment for the worker to process the tablets - time.Sleep(100 * time.Millisecond) - - require.Len(t, b.externalTablets, 1, b.print()) - require.Len(t, b.externalTablets[discovery.KeyFromTarget(replica.Target)], 1, b.print()) - - require.Len(t, b.localTablets, 0, b.print()) - - // Reparent happens, tablet is now a primary - hcChan <- primary - - // Give a moment for the worker to process the tablets - time.Sleep(100 * time.Millisecond) - - require.Len(t, b.externalTablets, 1, b.print()) - require.Len(t, b.externalTablets[discovery.KeyFromTarget(replica.Target)], 0, b.print()) - - require.Len(t, b.localTablets, 0, b.print()) -} - func buildOpts(uuid string) PickOpts { return PickOpts{SessionUUID: uuid} } diff --git a/go/vt/vtgate/legacy_scatter_conn_test.go b/go/vt/vtgate/legacy_scatter_conn_test.go index 9ff2ddcf1d3..ec345e4308e 100644 --- a/go/vt/vtgate/legacy_scatter_conn_test.go +++ b/go/vt/vtgate/legacy_scatter_conn_test.go @@ -621,7 +621,7 @@ func newTestScatterConn(ctx context.Context, hc discovery.HealthCheck, serv srvt // The topo.Server is used to start watching the cells described // in '-cells_to_watch' command line parameter, which is // empty by default. So it's unused in this test, set to nil. - gw, _ := NewTabletGateway(ctx, hc, serv, cell) + gw := NewTabletGateway(ctx, hc, serv, cell) tc := NewTxConn(gw, &StaticConfig{ TxMode: vtgatepb.TransactionMode_MULTI, }) diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 2dd77e34a11..37ea620552a 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -111,7 +111,7 @@ func createHealthCheck(ctx context.Context, retryDelay, timeout time.Duration, t } // NewTabletGateway creates and returns a new TabletGateway -func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, localCell string) (*TabletGateway, error) { +func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, localCell string) *TabletGateway { // hack to accommodate various users of gateway + tests if hc == nil { var topoServer *topo.Server @@ -132,16 +132,11 @@ func NewTabletGateway(ctx context.Context, hc discovery.HealthCheck, serv srvtop statusAggregators: make(map[string]*TabletStatusAggregator), } gw.setupBuffering(ctx) - if balancerEnabled { - err := gw.setupBalancer(ctx) - if err != nil { - return nil, fmt.Errorf("tablet gateway: failed to set up balancer: %w", err) - } + gw.setupBalancer() } - gw.QueryService = queryservice.Wrap(nil, gw.withRetry) - return gw, nil + return gw } func (gw *TabletGateway) setupBuffering(ctx context.Context) { @@ -173,29 +168,22 @@ func (gw *TabletGateway) setupBuffering(ctx context.Context) { }(bufferCtx, ksChan, gw.buffer) } -func (gw *TabletGateway) setupBalancer(ctx context.Context) error { +func (gw *TabletGateway) setupBalancer() { if len(balancerVtgateCells) == 0 { log.Exitf("balancer-vtgate-cells is required for balanced mode") } switch balancerType { case "session": - balancer, err := balancer.NewSessionBalancer(ctx, gw.localCell, gw.srvTopoServer, gw.hc) - if err != nil { - return fmt.Errorf("failed to create session balancer: %w", err) - } - - gw.balancer = balancer + gw.balancer = balancer.NewSessionBalancer(gw.localCell) default: if balancerType != "balanced" { log.Warningf("Unrecognized balancer type %q, using default \"balanced\"", balancerType) + balancerType = "balanced" } - balancerType = "balanced" gw.balancer = balancer.NewTabletBalancer(gw.localCell, balancerVtgateCells) } - - return nil } // QueryServiceByAlias satisfies the Gateway interface diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index f8ab3f45934..124997bea9e 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -59,7 +59,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // add a primary tablet which is serving @@ -162,7 +162,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // add a primary tablet which is serving @@ -292,7 +292,7 @@ func TestInconsistentStateDetectedBuffering(t *testing.T) { // create a new fake health check. We want to check the buffering code which uses Subscribe, so we must also pass a channel hc := discovery.NewFakeHealthCheck(make(chan *discovery.TabletHealth)) // create a new tablet gateway - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) tg.retryCount = 0 diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index 0ecca0a8348..9812de52e9e 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -112,7 +112,7 @@ func TestTabletGatewayShuffleTablets(t *testing.T) { hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg, _ := NewTabletGateway(ctx, hc, ts, "local") + tg := NewTabletGateway(ctx, hc, ts, "local") defer tg.Close(ctx) ts1 := &discovery.TabletHealth{ @@ -186,7 +186,7 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) { } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) _ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) @@ -221,7 +221,7 @@ func testTabletGatewayGenericHelper(t *testing.T, ctx context.Context, f func(ct } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // no tablet want := []string{"target: ks.0.replica", `no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA`} @@ -309,7 +309,7 @@ func testTabletGatewayTransact(t *testing.T, ctx context.Context, f func(ctx con } hc := discovery.NewFakeHealthCheck(nil) ts := &econtext.FakeTopoServer{} - tg, _ := NewTabletGateway(ctx, hc, ts, "cell") + tg := NewTabletGateway(ctx, hc, ts, "cell") defer tg.Close(ctx) // retry error - no retry @@ -350,7 +350,7 @@ func verifyShardErrors(t *testing.T, err error, wantErrors []string, wantCode vt // TestWithRetry tests the functionality of withRetry function in different circumstances. func TestWithRetry(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - tg, _ := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &econtext.FakeTopoServer{}, "cell") + tg := NewTabletGateway(ctx, discovery.NewFakeHealthCheck(nil), &econtext.FakeTopoServer{}, "cell") tg.kev = discovery.NewKeyspaceEventWatcher(ctx, tg.srvTopoServer, tg.hc, tg.localCell) defer func() { cancel() diff --git a/go/vt/vtgate/vstream_manager_test.go b/go/vt/vtgate/vstream_manager_test.go index 2527c63832e..e1cf2fabf4e 100644 --- a/go/vt/vtgate/vstream_manager_test.go +++ b/go/vt/vtgate/vstream_manager_test.go @@ -813,7 +813,6 @@ func TestVStreamRetriableErrors(t *testing.T) { } }) } - } func TestVStreamShouldNotSendSourceHeartbeats(t *testing.T) { @@ -1520,7 +1519,6 @@ func TestResolveVStreamParams(t *testing.T) { require.Equal(t, minimizeSkew, flags2.MinimizeSkew) }) } - } func TestVStreamIdleHeartbeat(t *testing.T) { @@ -1984,7 +1982,7 @@ func TestVStreamManagerHealthCheckResponseHandling(t *testing.T) { } func newTestVStreamManager(ctx context.Context, hc discovery.HealthCheck, serv srvtopo.Server, cell string) *vstreamManager { - gw, _ := NewTabletGateway(ctx, hc, serv, cell) + gw := NewTabletGateway(ctx, hc, serv, cell) srvResolver := srvtopo.NewResolver(serv, gw, cell) return newVStreamManager(srvResolver, serv, cell) } diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 7ab4b16fc96..02b1c0a1925 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -320,11 +320,7 @@ func Init( // Start with the gateway. If we can't reach the topology service, // we can't go on much further, so we log.Fatal out. // TabletGateway can create it's own healthcheck - gw, err := NewTabletGateway(ctx, hc, serv, cell) - if err != nil { - log.Fatalf("vtgate: failed to initalize tablet gateway: %w", err) - } - + gw := NewTabletGateway(ctx, hc, serv, cell) gw.RegisterStats() if err := gw.WaitForTablets(ctx, tabletTypesToWait); err != nil { log.Fatalf("tabletGateway.WaitForTablets failed: %v", err) From 01b5bb5f3e415f3a77a2564360c0e50aa184649e Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 9 Dec 2025 15:10:15 -0500 Subject: [PATCH 39/67] Update balancer.go Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 0d38f1ea892..209abd6a559 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -151,6 +151,7 @@ type PickOpts struct { // - "prefer-cell": Flow-based balancer that maintains cell affinity while balancing load // - See the RFC here: https://github.com/vitessio/vitess/issues/12241 // - "random": Random balancer that uniformly distributes load without cell affinity +// - "session": Session balancer that pins a session to the same tablet for the duration of the session. // // Note: "cell" mode is handled by the gateway and does not create a balancer instance. // operates as a round robin inside of the vtgate's cell From b7583267b39fd37c869f72e6cb33d77419afb585 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 9 Dec 2025 15:27:08 -0500 Subject: [PATCH 40/67] more conflict changes Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 15 ++++++++------- go/vt/vtgate/balancer/random_balancer.go | 2 +- go/vt/vtgate/balancer/random_balancer_test.go | 8 ++++---- go/vt/vtgate/balancer/session.go | 4 ++-- go/vt/vtgate/balancer/session_test.go | 16 ++++++++-------- 5 files changed, 23 insertions(+), 22 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 209abd6a559..24b9ecc9b3f 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -96,6 +96,7 @@ const ( ModeCell ModePreferCell ModeRandom + ModeSession ) func ParseMode(ms string) Mode { @@ -106,6 +107,8 @@ func ParseMode(ms string) Mode { return ModePreferCell case "random": return ModeRandom + case "session": + return ModeSession default: return ModeInvalid } @@ -119,13 +122,15 @@ func (m Mode) String() string { return "prefer-cell" case ModeRandom: return "random" + case ModeSession: + return "session" default: return "invalid" } } func GetAvailableModeNames() []string { - return []string{ModeCell.String(), ModePreferCell.String(), ModeRandom.String()} + return []string{ModeCell.String(), ModePreferCell.String(), ModeRandom.String(), ModeSession.String()} } type TabletBalancer interface { @@ -162,6 +167,8 @@ func NewTabletBalancer(mode Mode, localCell string, vtGateCells []string) (Table return newFlowBalancer(localCell, vtGateCells), nil case ModeRandom: return newRandomBalancer(localCell, vtGateCells), nil + case ModeSession: + return newSessionBalancer(localCell), nil case ModeCell: return nil, errors.New("cell mode should be handled by the gateway, not the balancer factory") default: @@ -352,9 +359,6 @@ func (b *flowBalancer) allocateFlows(allTablets []*discovery.TabletHealth) *targ } } - // fmt.Printf("outflows %v over %v under %v\n", a.Outflows, overAllocated, underAllocated) - - // // For each overallocated cell, proportionally shift flow from targets that are overallocated // to targets that are underallocated. // @@ -376,9 +380,6 @@ func (b *flowBalancer) allocateFlows(allTablets []*discovery.TabletHealth) *targ // to avoid truncating the integer values. shiftFlow := overAllocatedFlow * currentFlow * underAllocatedFlow / a.Inflows[overAllocatedCell] / unbalancedFlow - //fmt.Printf("shift %d %s %s -> %s (over %d current %d in %d under %d unbalanced %d) \n", shiftFlow, vtgateCell, overAllocatedCell, underAllocatedCell, - // overAllocatedFlow, currentFlow, a.Inflows[overAllocatedCell], underAllocatedFlow, unbalancedFlow) - a.Outflows[vtgateCell][overAllocatedCell] -= shiftFlow a.Inflows[overAllocatedCell] -= shiftFlow diff --git a/go/vt/vtgate/balancer/random_balancer.go b/go/vt/vtgate/balancer/random_balancer.go index f8ad723c910..825dfe4dc84 100644 --- a/go/vt/vtgate/balancer/random_balancer.go +++ b/go/vt/vtgate/balancer/random_balancer.go @@ -73,7 +73,7 @@ type randomBalancer struct { // Pick returns a random tablet from the list with uniform probability (1/N). // If vtGateCells is configured, only tablets in those cells are considered. -func (b *randomBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth) *discovery.TabletHealth { +func (b *randomBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ PickOpts) *discovery.TabletHealth { if len(tablets) == 0 { return nil } diff --git a/go/vt/vtgate/balancer/random_balancer_test.go b/go/vt/vtgate/balancer/random_balancer_test.go index fce3306507f..5bd74330405 100644 --- a/go/vt/vtgate/balancer/random_balancer_test.go +++ b/go/vt/vtgate/balancer/random_balancer_test.go @@ -47,7 +47,7 @@ func TestRandomBalancerUniformDistribution(t *testing.T) { pickCounts := make(map[uint32]int) for i := 0; i < numPicks; i++ { - th := b.Pick(target, tablets) + th := b.Pick(target, tablets, PickOpts{}) require.NotNil(t, th, "Pick should not return nil") pickCounts[th.Tablet.Alias.Uid]++ } @@ -66,7 +66,7 @@ func TestRandomBalancerPickEmpty(t *testing.T) { target := &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA} b := newRandomBalancer("cell1", []string{}) - th := b.Pick(target, []*discovery.TabletHealth{}) + th := b.Pick(target, []*discovery.TabletHealth{}, PickOpts{}) assert.Nil(t, th, "Pick should return nil for empty tablet list") } @@ -80,7 +80,7 @@ func TestRandomBalancerPickSingle(t *testing.T) { // Pick multiple times, should always return the same tablet for i := 0; i < 100; i++ { - th := b.Pick(target, tablets) + th := b.Pick(target, tablets, PickOpts{}) require.NotNil(t, th, "Pick should not return nil") assert.Equal(t, tablets[0].Tablet.Alias.Uid, th.Tablet.Alias.Uid, "Pick should return the only available tablet") @@ -133,7 +133,7 @@ func TestRandomBalancerCellFiltering(t *testing.T) { pickCounts := make(map[string]int) for i := 0; i < numPicks; i++ { - th := b.Pick(target, tablets) + th := b.Pick(target, tablets, PickOpts{}) require.NotNil(t, th) pickCounts[th.Tablet.Alias.Cell]++ } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 1f5083c84b7..939824ee89b 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -35,8 +35,8 @@ type SessionBalancer struct { localCell string } -// NewSessionBalancer creates a new session balancer. -func NewSessionBalancer(localCell string) TabletBalancer { +// newSessionBalancer creates a new session balancer. +func newSessionBalancer(localCell string) TabletBalancer { return &SessionBalancer{localCell: localCell} } diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 0152b65759a..e989cae3220 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -28,17 +28,17 @@ import ( "vitess.io/vitess/go/vt/topo/topoproto" ) -func newSessionBalancer(t *testing.T) *SessionBalancer { +func createSessionBalancer(t *testing.T) *SessionBalancer { t.Helper() - b := NewSessionBalancer("local") + b := newSessionBalancer("local") sb := b.(*SessionBalancer) return sb } func TestPickNoTablets(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -53,7 +53,7 @@ func TestPickNoTablets(t *testing.T) { } func TestPickLocalOnly(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -117,7 +117,7 @@ func TestPickLocalOnly(t *testing.T) { } func TestPickPreferLocal(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -190,7 +190,7 @@ func TestPickPreferLocal(t *testing.T) { } func TestPickNoLocal(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -244,7 +244,7 @@ func TestPickNoLocal(t *testing.T) { } func TestPickNoOpts(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", @@ -279,7 +279,7 @@ func TestPickNoOpts(t *testing.T) { } func TestPickInvalidTablets(t *testing.T) { - b := newSessionBalancer(t) + b := createSessionBalancer(t) target := &querypb.Target{ Keyspace: "keyspace", From b30017cd8532e6e33e59da8049c0afd562798308 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 9 Dec 2025 15:32:26 -0500 Subject: [PATCH 41/67] remove some allocations Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 939824ee89b..173252e84f7 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -61,7 +61,7 @@ func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.Tabl continue } - weight := weight(alias, opts.SessionUUID) + weight := tabletWeight(alias, opts.SessionUUID) if b.isLocal(tablet) && ((maxLocalTablet == nil) || (weight > maxLocalWeight)) { maxLocalWeight = weight @@ -85,9 +85,13 @@ func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.Tabl return maxExternalTablet } -// weight computes the weight of a tablet by hashing its alias and the session UUID together. -func weight(alias string, sessionUUID string) uint64 { - return xxhash.Sum64String(alias + "#" + sessionUUID) +// tabletWeight computes the weight of a tablet by hashing its alias and the session UUID together. +func tabletWeight(alias string, sessionUUID string) uint64 { + h := xxhash.New() + _, _ = h.WriteString(alias) + _, _ = h.WriteString("#") + _, _ = h.WriteString(sessionUUID) + return h.Sum64() } // tabletAlias returns the tablet's alias as a string. From 89455ec8a5fccd95d9af318bd9f664d907b17c1f Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 11:48:24 -0500 Subject: [PATCH 42/67] adjust tests to match previous tests Signed-off-by: Mohamed Hamza --- .../vtgate/sessionbalancer/session_test.go | 243 ------------------ .../vtgate/sessionbalancer/uschema.sql | 4 - .../vtgate/tabletbalancer/session_test.go | 223 ++++++++++++++++ 3 files changed, 223 insertions(+), 247 deletions(-) delete mode 100644 go/test/endtoend/vtgate/sessionbalancer/session_test.go delete mode 100644 go/test/endtoend/vtgate/sessionbalancer/uschema.sql create mode 100644 go/test/endtoend/vtgate/tabletbalancer/session_test.go diff --git a/go/test/endtoend/vtgate/sessionbalancer/session_test.go b/go/test/endtoend/vtgate/sessionbalancer/session_test.go deleted file mode 100644 index 281dfa963af..00000000000 --- a/go/test/endtoend/vtgate/sessionbalancer/session_test.go +++ /dev/null @@ -1,243 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sessionbalancer - -import ( - "context" - _ "embed" - "fmt" - "slices" - "testing" - "time" - - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/mysql" - "vitess.io/vitess/go/test/endtoend/cluster" -) - -const ( - cell = "test_misc" - keyspace = "uks" -) - -//go:embed uschema.sql -var uschemaSQL string - -func createCluster(t *testing.T, replicaCount int) (*cluster.LocalProcessCluster, *mysql.ConnParams) { - t.Helper() - - // Create a new clusterInstance - clusterInstance := cluster.NewCluster(cell, "localhost") - - // Start topo server - err := clusterInstance.StartTopo() - require.NoError(t, err, "Failed to start topo server") - - // Enable session balancer in vtgate - clusterInstance.VtGateExtraArgs = append(clusterInstance.VtGateExtraArgs, - "--enable-balancer", - "--balancer-vtgate-cells", clusterInstance.Cell, - "--balancer-type", "session") - - ks := cluster.Keyspace{ - Name: keyspace, - SchemaSQL: uschemaSQL, - } - err = clusterInstance.StartUnshardedKeyspace(ks, replicaCount, false) - require.NoError(t, err, "Failed to start keyspace") - - err = clusterInstance.StartVtgate() - require.NoError(t, err, "Failed to start vtgate") - - vtParams := clusterInstance.GetVTParams(keyspace) - - return clusterInstance, &vtParams -} - -// TestSessionBalancer validates that the session balancer consistently routes -// queries for the same session to the same tablet. -func TestSessionBalancer(t *testing.T) { - cluster, vtParams := createCluster(t, 2) - defer cluster.Teardown() - - // Get connections that route to different tablets - conns, ids := connections(t, vtParams, "replica", 2) - - for _, conn := range conns { - defer conn.Close() - } - - // Validate that each connection consistently returns the same server ID - for range 20 { - for i, conn := range conns { - id := serverID(t, conn) - require.Equal(t, ids[i], id) - } - } -} - -// TestSessionBalancerRemoveTablet validates that when a tablet is killed, -// connections that were using that tablet get rerouted to remaining tablets. -func TestSessionBalancerRemoveTablet(t *testing.T) { - cluster, vtParams := createCluster(t, 2) - defer cluster.Teardown() - - // Get connections that route to different tablets - conns, _ := connections(t, vtParams, "replica", 2) - conn1, conn2 := conns[0], conns[1] - - defer conn1.Close() - defer conn2.Close() - - tablets := tablets(t, cluster, "replica") - require.NotNil(t, tablets) - require.Len(t, tablets, 2) - - err := tablets[0].VttabletProcess.TearDown() - require.NoError(t, err) - - time.Sleep(10 * time.Millisecond) - - for range 20 { - newID1 := serverID(t, conn1) - newID2 := serverID(t, conn2) - - require.Equal(t, newID1, newID2) - } -} - -// TestSessionBalancerAddTablet validates that when a new tablet is added, -// new connections get routed to the new tablet. -func TestSessionBalancerAddTablet(t *testing.T) { - cluster, vtParams := createCluster(t, 3) - defer cluster.Teardown() - - // Get 3 connections that route to different tablets - conns, ids := connections(t, vtParams, "replica", 3) - for _, conn := range conns { - defer conn.Close() - } - - tablets := tablets(t, cluster, "replica") - require.NotNil(t, tablets) - require.Len(t, tablets, 3) - - // Start with only 2 tablets serving - tablet := tablets[2] - err := tablet.VttabletProcess.TearDown() - require.NoError(t, err) - - time.Sleep(10 * time.Millisecond) - - // Find the connection that moved - var conn *mysql.Conn - for i, c := range conns { - newID := serverID(t, c) - if newID != ids[i] { - conn = c - break - } - } - - require.NotNil(t, conn, "One connection should've moved tablets") - - // Start up the tablet again - err = tablet.RestartOnlyTablet() - require.NoError(t, err) - - time.Sleep(10 * time.Millisecond) - - // All connections should route to the same IDs again - for range 20 { - for i, conn := range conns { - id := ids[i] - newID := serverID(t, conn) - require.Equal(t, id, newID) - } - } -} - -// connections returns the specified number of connections that should route to different tablets. -func connections(t *testing.T, vtParams *mysql.ConnParams, tabletType string, numConnections int) ([]*mysql.Conn, []string) { - t.Helper() - - vtParams.DbName = fmt.Sprintf("%s@%s", keyspace, tabletType) - - conns := make([]*mysql.Conn, 0, numConnections) - ids := make([]string, 0, numConnections) - - // Keep creating connections until we have the required number with different server IDs - for range 20 { - conn, err := mysql.Connect(context.Background(), vtParams) - require.NoError(t, err) - - id := serverID(t, conn) - newID := !slices.Contains(ids, id) - - // If we found a new tablet, add it to the list of connections - if newID { - conns = append(conns, conn) - ids = append(ids, id) - - if len(conns) == numConnections { - return conns, ids - } - - continue - } - - conn.Close() - } - - t.Fatalf("could not create %d connections with different tablet connections", numConnections) - return nil, nil -} - -// serverID runs a `SELECT @@server_id` on the given connection and returns the server's ID. -func serverID(t *testing.T, conn *mysql.Conn) string { - t.Helper() - - result1, err := conn.ExecuteFetch("SELECT @@server_id", 1, false) - require.NoError(t, err) - - tablet1Bytes, err := result1.Rows[0][0].ToBytes() - require.NoError(t, err) - - return string(tablet1Bytes) -} - -func tablets(t *testing.T, clusterInstance *cluster.LocalProcessCluster, tabletType string) []*cluster.Vttablet { - t.Helper() - - if len(clusterInstance.Keyspaces) == 0 { - return nil - } - - if len(clusterInstance.Keyspaces[0].Shards) == 0 { - return nil - } - - tablets := make([]*cluster.Vttablet, 0, 2) - for _, tablet := range clusterInstance.Keyspaces[0].Shards[0].Vttablets { - if tablet.Type == tabletType { - tablets = append(tablets, tablet) - } - } - - return tablets -} diff --git a/go/test/endtoend/vtgate/sessionbalancer/uschema.sql b/go/test/endtoend/vtgate/sessionbalancer/uschema.sql deleted file mode 100644 index 61b680a618f..00000000000 --- a/go/test/endtoend/vtgate/sessionbalancer/uschema.sql +++ /dev/null @@ -1,4 +0,0 @@ -CREATE TABLE t1 ( - id int PRIMARY KEY, - name varchar(255) -); diff --git a/go/test/endtoend/vtgate/tabletbalancer/session_test.go b/go/test/endtoend/vtgate/tabletbalancer/session_test.go new file mode 100644 index 00000000000..d4ed13c0e53 --- /dev/null +++ b/go/test/endtoend/vtgate/tabletbalancer/session_test.go @@ -0,0 +1,223 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tabletbalancer + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" +) + +// TestSessionModeBalancer tests the "session" mode routes each session consistently to the same tablet. +func TestSessionModeBalancer(t *testing.T) { + vtgateProcess, vtParams, _, _ := setupCluster(t) + defer vtgateProcess.TearDown() + + // Create 2 session connections that route to different tablets + conns := createSessionConnections(t, &vtParams, 2) + for conn := range conns { + defer conn.Close() + } + + verifyStickiness(t, conns, 20) +} + +// TestSessionModeRemoveTablet tests that when a tablet is killed, connections switch to remaining tablets +func TestSessionModeRemoveTablet(t *testing.T) { + vtgateProcess, vtParams, replicaTablets, aliases := setupCluster(t) + defer vtgateProcess.TearDown() + + // Create 2 connections to different tablets + conns := createSessionConnections(t, &vtParams, 2) + for conn := range conns { + defer conn.Close() + } + + // Find the first replica tablet that one of our connections is using + var tabletToKill *cluster.Vttablet + var affectedConn *mysql.Conn + var killedServerID int64 + + for _, tablet := range replicaTablets { + tabletServerID := aliases[tablet.Alias] + + // Check if any connection is using this tablet + for conn, connServerID := range conns { + if connServerID != tabletServerID { + continue + } + + // We found a connection that's using this tablet, let's kill this tablet + tabletToKill = tablet + affectedConn = conn + killedServerID = tabletServerID + break + } + + // We found a tablet, no need to check other tablets + if tabletToKill != nil { + break + } + } + + require.NotNil(t, tabletToKill, "Should find a tablet to kill") + + // Kill the tablet immediately + err := tabletToKill.VttabletProcess.Kill() + require.Error(t, err) + + // Wait for the connection to switch to a new tablet and update the map + require.Eventually(t, func() bool { + newServerID := getServerID(t, affectedConn) + if newServerID != killedServerID { + conns[affectedConn] = newServerID + return true + } + + return false + }, 10*time.Millisecond, 1*time.Millisecond, "Connection should switch to a different tablet") + + verifyStickiness(t, conns, 20) +} + +// setupCluster sets up a cluster with a vtgate using the session balancer. +func setupCluster(t *testing.T) (*cluster.VtgateProcess, mysql.ConnParams, []*cluster.Vttablet, map[string]int64) { + t.Helper() + + // Start vtgate in cell1 with session mode + vtgateProcess := cluster.VtgateProcessInstance( + clusterInstance.GetAndReservePort(), + clusterInstance.GetAndReservePort(), + clusterInstance.GetAndReservePort(), + cell1, + fmt.Sprintf("%s,%s", cell1, cell2), + clusterInstance.Hostname, + replicaStr, + clusterInstance.TopoProcess.Port, + clusterInstance.TmpDirectory, + []string{ + "--vtgate-balancer-mode", "session", + }, + plancontext.PlannerVersion(0), + ) + require.NoError(t, vtgateProcess.Setup()) + require.True(t, vtgateProcess.WaitForStatus()) + + vtParams := mysql.ConnParams{ + Host: clusterInstance.Hostname, + Port: vtgateProcess.MySQLServerPort, + } + + allTablets := clusterInstance.Keyspaces[0].Shards[0].Vttablets + shardName := clusterInstance.Keyspaces[0].Shards[0].Name + replicaTablets := replicaTablets(allTablets) + + conn, err := mysql.Connect(context.Background(), &vtParams) + require.NoError(t, err) + defer conn.Close() + + // Wait for tablets to be discovered + err = vtgateProcess.WaitForStatusOfTabletInShard(fmt.Sprintf("%s.%s.primary", keyspaceName, shardName), 1, 30*time.Second) + require.NoError(t, err) + + err = vtgateProcess.WaitForStatusOfTabletInShard(fmt.Sprintf("%s.%s.replica", keyspaceName, shardName), len(replicaTablets), 30*time.Second) + require.NoError(t, err) + + aliases := mapTabletAliasToMySQLServerID(t, allTablets) + + // Insert test data + testValue := fmt.Sprintf("session_test_%d", time.Now().UnixNano()) + _, err = conn.ExecuteFetch(fmt.Sprintf("INSERT INTO balancer_test (value) VALUES ('%s')", testValue), 1, false) + require.NoError(t, err) + waitForReplication(t, replicaTablets, testValue) + + return vtgateProcess, vtParams, replicaTablets, aliases +} + +// getServerID returns the server ID that the connection is currently routing to. +func getServerID(t *testing.T, conn *mysql.Conn) int64 { + t.Helper() + + res, err := conn.ExecuteFetch("SELECT @@server_id", 1, false) + require.NoError(t, err) + require.Equal(t, 1, len(res.Rows), "expected one row from server_id query") + + serverID, err := res.Rows[0][0].ToInt64() + require.NoError(t, err) + + return serverID +} + +// createSessionConnections creates `n` connections that route to different tablets. +// Returns a map of mysql.Conn -> serverID. +func createSessionConnections(t *testing.T, vtParams *mysql.ConnParams, numConnections int) map[*mysql.Conn]int64 { + t.Helper() + + conns := make(map[*mysql.Conn]int64) + seenServerIDs := make(map[int64]bool) + + // Try up to 50 times to get numConnections with different server IDs + for range 50 { + conn, err := mysql.Connect(context.Background(), vtParams) + require.NoError(t, err) + + _, err = conn.ExecuteFetch("USE @replica", 1, false) + require.NoError(t, err) + + // Get the server ID this connection routes to + serverID := getServerID(t, conn) + + // If this is a new tablet, keep the connection + if !seenServerIDs[serverID] { + seenServerIDs[serverID] = true + conns[conn] = serverID + + // If we have enough connections, return + if len(conns) == numConnections { + return conns + } + + continue + } + + // Already seen this tablet, close and try again + conn.Close() + } + + t.Fatalf("could not create %d connections with different tablets after 50 attempts, only got %d", numConnections, len(conns)) + return nil +} + +// verifyStickiness validates whether the given connections remain connected to the same +// server `n` times in a row. +func verifyStickiness(t *testing.T, conns map[*mysql.Conn]int64, n uint) { + t.Helper() + + for conn, expectedServerID := range conns { + for range n { + currentServerID := getServerID(t, conn) + require.Equal(t, expectedServerID, currentServerID, "Connection should stick to tablet %d, got %d", expectedServerID, currentServerID) + } + } +} From 27a48d81802fdc467a0501a662f692f8e6e4277f Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 12:27:21 -0500 Subject: [PATCH 43/67] add benchmarks comparing all balancers Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/benchmark_test.go | 91 +++++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 go/vt/vtgate/balancer/benchmark_test.go diff --git a/go/vt/vtgate/balancer/benchmark_test.go b/go/vt/vtgate/balancer/benchmark_test.go new file mode 100644 index 00000000000..5cb49adbfc1 --- /dev/null +++ b/go/vt/vtgate/balancer/benchmark_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +import ( + "testing" + + "vitess.io/vitess/go/vt/discovery" + querypb "vitess.io/vitess/go/vt/proto/query" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" +) + +func BenchmarkBalancers(b *testing.B) { + target := &querypb.Target{ + Keyspace: "keyspace", + Shard: "0", + TabletType: topodatapb.TabletType_REPLICA, + Cell: "local", + } + + tablets := []*discovery.TabletHealth{ + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 100, + }, + Keyspace: "keyspace", + Shard: "0", + }, + }, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "local", + Uid: 101, + }, + Keyspace: "keyspace", + Shard: "0", + }, + }, + { + Tablet: &topodatapb.Tablet{ + Alias: &topodatapb.TabletAlias{ + Cell: "external", + Uid: 200, + }, + Keyspace: "keyspace", + Shard: "0", + }, + }, + } + + opts := PickOpts{SessionUUID: "session-uuid-123"} + + benchmarks := []struct { + name string + mode Mode + }{ + {"Session", ModeSession}, + {"Random", ModeRandom}, + {"PreferCell", ModePreferCell}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + balancer, err := NewTabletBalancer(bm.mode, "local", []string{"local", "external"}) + if err != nil { + b.Fatalf("failed to create balancer: %v", err) + } + + for b.Loop() { + balancer.Pick(target, tablets, opts) + } + }) + } +} From 8236141a7e499cb20143cd2ff28664a3587696d6 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 12:43:04 -0500 Subject: [PATCH 44/67] add session mode to help text Signed-off-by: Mohamed Hamza --- go/flags/endtoend/vtgate.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index b1f4ef4b0b0..e6ae18ee32a 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -246,7 +246,7 @@ Flags: -v, --version print binary version --vmodule vModuleFlag comma-separated list of pattern=N settings for file-filtered logging --vschema-ddl-authorized-users string List of users authorized to execute vschema ddl operations, or '%' to allow all users. - --vtgate-balancer-mode string Tablet balancer mode (options: cell, prefer-cell, random). Defaults to 'cell' which shuffles tablets in the local cell. + --vtgate-balancer-mode string Tablet balancer mode (options: cell, prefer-cell, random, session). Defaults to 'cell' which shuffles tablets in the local cell. --vtgate-config-terse-errors prevent bind vars from escaping in returned errors --warming-reads-concurrency int Number of concurrent warming reads allowed (default 500) --warming-reads-percent int Percentage of reads on the primary to forward to replicas. Useful for keeping buffer pools warm From f45a4004ac1af9098365126e59479e4cb34362c3 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 15:48:02 -0500 Subject: [PATCH 45/67] remove benchmark Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/benchmark_test.go | 91 ------------------------- 1 file changed, 91 deletions(-) delete mode 100644 go/vt/vtgate/balancer/benchmark_test.go diff --git a/go/vt/vtgate/balancer/benchmark_test.go b/go/vt/vtgate/balancer/benchmark_test.go deleted file mode 100644 index 5cb49adbfc1..00000000000 --- a/go/vt/vtgate/balancer/benchmark_test.go +++ /dev/null @@ -1,91 +0,0 @@ -/* -Copyright 2025 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package balancer - -import ( - "testing" - - "vitess.io/vitess/go/vt/discovery" - querypb "vitess.io/vitess/go/vt/proto/query" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" -) - -func BenchmarkBalancers(b *testing.B) { - target := &querypb.Target{ - Keyspace: "keyspace", - Shard: "0", - TabletType: topodatapb.TabletType_REPLICA, - Cell: "local", - } - - tablets := []*discovery.TabletHealth{ - { - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 100, - }, - Keyspace: "keyspace", - Shard: "0", - }, - }, - { - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "local", - Uid: 101, - }, - Keyspace: "keyspace", - Shard: "0", - }, - }, - { - Tablet: &topodatapb.Tablet{ - Alias: &topodatapb.TabletAlias{ - Cell: "external", - Uid: 200, - }, - Keyspace: "keyspace", - Shard: "0", - }, - }, - } - - opts := PickOpts{SessionUUID: "session-uuid-123"} - - benchmarks := []struct { - name string - mode Mode - }{ - {"Session", ModeSession}, - {"Random", ModeRandom}, - {"PreferCell", ModePreferCell}, - } - - for _, bm := range benchmarks { - b.Run(bm.name, func(b *testing.B) { - balancer, err := NewTabletBalancer(bm.mode, "local", []string{"local", "external"}) - if err != nil { - b.Fatalf("failed to create balancer: %v", err) - } - - for b.Loop() { - balancer.Pick(target, tablets, opts) - } - }) - } -} From 3f4c17e8c892e40bb6bf448200054fb83c621072 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 16:19:21 -0500 Subject: [PATCH 46/67] remove PickOpts.InvalidTablets Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 3 -- go/vt/vtgate/balancer/session.go | 6 --- go/vt/vtgate/balancer/session_test.go | 15 +++++--- go/vt/vtgate/tabletgateway.go | 54 +++++++++++++-------------- 4 files changed, 37 insertions(+), 41 deletions(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 24b9ecc9b3f..ecdc06256da 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -144,9 +144,6 @@ type TabletBalancer interface { // PickOpts are balancer options that are passed into Pick. type PickOpts struct { - // InvalidTablets is a set of tablets that should not be picked. - InvalidTablets map[string]bool - // SessionUUID is the the current session UUID. SessionUUID string } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index 173252e84f7..ab4c0f9312c 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -55,12 +55,6 @@ func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.Tabl for _, tablet := range tablets { alias := tabletAlias(tablet) - - // Ignore invalid tablets - if _, invalid := opts.InvalidTablets[alias]; invalid { - continue - } - weight := tabletWeight(alias, opts.SessionUUID) if b.isLocal(tablet) && ((maxLocalTablet == nil) || (weight > maxLocalWeight)) { diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index e989cae3220..4d78cb64154 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -18,6 +18,7 @@ package balancer import ( "fmt" + "slices" "testing" "github.com/stretchr/testify/require" @@ -331,14 +332,18 @@ func TestPickInvalidTablets(t *testing.T) { tablet := b.Pick(target, tablets, opts) require.NotNil(t, tablet) - // Mark returned tablet as invalid, should return other tablet - opts.InvalidTablets = map[string]bool{topoproto.TabletAliasString(tablet.Tablet.Alias): true} + // Filter out the returned tablet as invalid + tablets = slices.DeleteFunc(tablets, func(t *discovery.TabletHealth) bool { + return topoproto.TabletAliasString(t.Tablet.Alias) == topoproto.TabletAliasString(tablet.Tablet.Alias) + }) + + // Pick should now return a different tablet tablet2 := b.Pick(target, tablets, opts) + require.NotNil(t, tablet2) require.NotEqual(t, tablet, tablet2) - // Mark both as invalid, should return nil - opts.InvalidTablets[topoproto.TabletAliasString(tablet2.Tablet.Alias)] = true - tablet3 := b.Pick(target, tablets, opts) + // Filter out the last tablet, Pick should return nothing + tablet3 := b.Pick(target, []*discovery.TabletHealth{}, opts) require.Nil(t, tablet3) } diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 4178022c6ca..077fd28c17e 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -322,7 +322,8 @@ func (gw *TabletGateway) DebugBalancerHandler(w http.ResponseWriter, r *http.Req // withRetry also adds shard information to errors returned from the inner QueryService, so // withShardError should not be combined with withRetry. func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, _ queryservice.QueryService, - _ string, opts queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error)) error { + _ string, opts queryservice.WrapOpts, inner func(ctx context.Context, target *querypb.Target, conn queryservice.QueryService) (bool, error), +) error { // for transactions, we connect to a specific tablet instead of letting gateway choose one if opts.InTransaction && target.TabletType != topodatapb.TabletType_PRIMARY { return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "tabletGateway's query service can only be used for non-transactional queries on replicas") @@ -436,7 +437,23 @@ func (gw *TabletGateway) withRetry(ctx context.Context, target *querypb.Target, // getBalancerTablet selects a tablet for the given query target, using the configured balancer if enabled. Otherwise, it will // select a random tablet, with preference to the local cell. func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*discovery.TabletHealth, invalidTablets map[string]bool, opts queryservice.WrapOpts) *discovery.TabletHealth { - var tablet *discovery.TabletHealth + // Return early if no tablets are available + if len(tablets) == 0 { + return nil + } + + // Filter out the tablets that we've tried before (if any) + if len(invalidTablets) > 0 { + tablets = slices.DeleteFunc(tablets, func(t *discovery.TabletHealth) bool { + _, isInvalid := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)] + return isInvalid + }) + + // If all tablets are invalid, let's return early + if len(tablets) == 0 { + return nil + } + } // Determine if we should use the balancer for this target useBalancer := gw.balancer != nil @@ -446,39 +463,22 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // Get the tablet from the balancer if enabled if useBalancer { - // filter out the tablets that we've tried before (if any), then pick the best one - if len(invalidTablets) > 0 { - tablets = slices.DeleteFunc(tablets, func(t *discovery.TabletHealth) bool { - _, isInvalid := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)] - return isInvalid - }) - } - var sessionUUID string if opts.Options != nil { sessionUUID = opts.Options.SessionUUID } - pickOpts := balancer.PickOpts{SessionUUID: sessionUUID, InvalidTablets: invalidTablets} - tablet = gw.balancer.Pick(target, tablets, pickOpts) - } - - if tablet != nil { - return tablet - } - - // If the balancer isn't enabled, or it didn't return a tablet, randomly select a - // tablet, with preference to the local cell - gw.shuffleTablets(gw.localCell, tablets) - - // skip tablets we tried before - for _, t := range tablets { - if _, ok := invalidTablets[topoproto.TabletAliasString(t.Tablet.Alias)]; !ok { - return t + tablet := gw.balancer.Pick(target, tablets, balancer.PickOpts{SessionUUID: sessionUUID}) + if tablet != nil { + return tablet } } - return nil + // If the balancer isn't enabled, or it didn't return a tablet, shuffle the tablets + // and return the first one. (This will always contain at least one tablet due to the + // check above). + gw.shuffleTablets(gw.localCell, tablets) + return tablets[0] } // withShardError adds shard information to errors returned from the inner QueryService. From 6abc3f780d8d5a3cf98575e431623936a5a899df Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 10 Dec 2025 19:49:03 -0500 Subject: [PATCH 47/67] add promptless changelog suggestion Signed-off-by: Mohamed Hamza --- changelog/24.0/24.0.0/changelog.md | 58 ++++++++++++++++++++++++++++++ changelog/24.0/24.0.0/summary.md | 11 ++++++ 2 files changed, 69 insertions(+) create mode 100644 changelog/24.0/24.0.0/changelog.md diff --git a/changelog/24.0/24.0.0/changelog.md b/changelog/24.0/24.0.0/changelog.md new file mode 100644 index 00000000000..258e41e6366 --- /dev/null +++ b/changelog/24.0/24.0.0/changelog.md @@ -0,0 +1,58 @@ +# Release of Vitess v24.0.0 +## Summary + +### Table of Contents + +- **[Major Changes](#major-changes)** + - **[New Support](#new-support)** + - [Window function pushdown for sharded keyspaces](#window-function-pushdown) +- **[Minor Changes](#minor-changes)** + - **[VTGate](#minor-changes-vtgate)** + - [New default for `--legacy-replication-lag-algorithm` flag](#vtgate-new-default-legacy-replication-lag-algorithm) + - [New "session" mode for `--vtgate-balancer-mode` flag](#vtgate-session-balancer-mode) + - **[VTTablet](#minor-changes-vttablet)** + - [New Experimental flag `--init-tablet-type-lookup`](#vttablet-init-tablet-type-lookup) + +## Major Changes + +### New Support + +#### Window function pushdown for sharded keyspaces + +This release introduces an optimization that allows window functions to be pushed down to individual shards when they are partitioned by a column that matches a unique vindex. + +Previously, all window function queries required single-shard routing, which limited their applicability on sharded tables. With this change, queries where the `PARTITION BY` clause aligns with a unique vindex can now be pushed down and executed on each shard. + +For examples and more details, see the [documentation](https://vitess.io/docs/24.0/reference/compatibility/mysql-compatibility/#window-functions). + +## Minor Changes + +### VTGate + +#### New default for `--legacy-replication-lag-algorithm` flag + +The VTGate flag `--legacy-replication-lag-algorithm` now defaults to `false`, disabling the legacy approach to handling replication lag by default. + +Instead, a simpler algorithm purely based on low lag, high lag and minimum number of tablets is used, which has proven to be more stable in many production environments. A detailed explanation of the two approaches [is explained in this code comment](https://github.com/vitessio/vitess/blob/main/go/vt/discovery/replicationlag.go#L125-L149). + +In v25 this flag will become deprecated and in the following release it will be removed. In the meantime, the legacy behaviour can be used by setting `--legacy-replication-lag-algorithm=true`. This deprecation is tracked in https://github.com/vitessio/vitess/issues/18914. + +#### New "session" mode for `--vtgate-balancer-mode` flag + +The VTGate flag `--vtgate-balancer-mode` now supports a new "session" mode in addition to the existing "cell", "prefer-cell", and "random" modes. Session mode routes each session consistently to the same tablet for the session's duration. + +To enable session mode, set the flag when starting VTGate: + +``` +--vtgate-balancer-mode=session +``` + +### VTTablet + +#### New Experimental flag `--init-tablet-type-lookup` + +The new experimental flag `--init-tablet-type-lookup` for VTTablet allows tablets to automatically restore their previous tablet type on restart by looking up the existing topology record, rather than always using the static `--init-tablet-type` value. + +When enabled, the tablet uses its alias to look up the tablet type from the existing topology record on restart. This allows tablets to maintain their changed roles (e.g., RDONLY/DRAINED) across restarts without manual reconfiguration. If disabled or if no topology record exists, the standard `--init-tablet-type` value will be used instead. + +**Note**: Vitess Operator–managed deployments generally do not keep tablet records in the topo between restarts, so this feature will not take effect in those environments. \ No newline at end of file diff --git a/changelog/24.0/24.0.0/summary.md b/changelog/24.0/24.0.0/summary.md index 696da1ef99c..169cb82f4da 100644 --- a/changelog/24.0/24.0.0/summary.md +++ b/changelog/24.0/24.0.0/summary.md @@ -9,6 +9,7 @@ - **[Minor Changes](#minor-changes)** - **[VTGate](#minor-changes-vtgate)** - [New default for `--legacy-replication-lag-algorithm` flag](#vtgate-new-default-legacy-replication-lag-algorithm) + - [New "session" mode for `--vtgate-balancer-mode` flag](#vtgate-session-balancer-mode) - **[VTTablet](#minor-changes-vttablet)** - [New Experimental flag `--init-tablet-type-lookup`](#vttablet-init-tablet-type-lookup) @@ -36,6 +37,16 @@ Instead, a simpler algorithm purely based on low lag, high lag and minimum numbe In v25 this flag will become deprecated and in the following release it will be removed. In the meantime, the legacy behaviour can be used by setting `--legacy-replication-lag-algorithm=true`. This deprecation is tracked in https://github.com/vitessio/vitess/issues/18914. +#### New "session" mode for `--vtgate-balancer-mode` flag + +The VTGate flag `--vtgate-balancer-mode` now supports a new "session" mode in addition to the existing "cell", "prefer-cell", and "random" modes. Session mode routes each session consistently to the same tablet for the session's duration. + +To enable session mode, set the flag when starting VTGate: + +``` +--vtgate-balancer-mode=session +``` + ### VTTablet #### New Experimental flag `--init-tablet-type-lookup` From 1b9b7d9f893b9c3415df32a9efc3b2985f491c83 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 11 Dec 2025 11:14:08 -0500 Subject: [PATCH 48/67] temp: add session to queryservice Signed-off-by: Mohamed Hamza --- go/vt/vtgate/scatter_conn.go | 6 ++++-- go/vt/vttablet/queryservice/queryservice.go | 12 +++++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 96fc5f066e4..01f0922bbde 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -221,9 +221,11 @@ func (stc *ScatterConn) ExecuteMultiShard( } } + qsSession := queryservice.Session{SessionUUID: session.GetSessionUUID()} + switch info.actionNeeded { case nothing: - innerqr, err = qs.Execute(ctx, rs.Target, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts) + innerqr, err = qs.Execute(ctx, rs.Target, qsSession, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts) if err != nil { retryRequest(func() { // we seem to have lost our connection. it was a reserved connection, let's try to recreate it @@ -877,7 +879,7 @@ func actionInfo(ctx context.Context, target *querypb.Target, session *econtext.S shouldReserve := session.InReservedConn() && (shardSession == nil || shardSession.ReservedId == 0) shouldBegin := session.InTransaction() && (shardSession == nil || shardSession.TransactionId == 0) && !autocommit - var act = nothing + act := nothing switch { case shouldBegin && shouldReserve: act = reserveBegin diff --git a/go/vt/vttablet/queryservice/queryservice.go b/go/vt/vttablet/queryservice/queryservice.go index d6972bfb6a3..ff9d93d6640 100644 --- a/go/vt/vttablet/queryservice/queryservice.go +++ b/go/vt/vttablet/queryservice/queryservice.go @@ -19,10 +19,10 @@ limitations under the License. package queryservice import ( - topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "context" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/sqltypes" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" @@ -80,7 +80,7 @@ type QueryService interface { UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) ([]*querypb.TransactionMetadata, error) // Execute for query execution - Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) + Execute(ctx context.Context, target *querypb.Target, session Session, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) // StreamExecute for query execution with streaming StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error @@ -148,3 +148,9 @@ type ReservedTransactionState struct { TabletAlias *topodatapb.TabletAlias SessionStateChanges string } + +// Session represents the current VTGate session. +type Session struct { + // SessionUUID is the UUID of the current session. + SessionUUID string +} From 5f497bf1efbf82a480839483f9b6d72d1b47053f Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 11 Dec 2025 11:26:22 -0500 Subject: [PATCH 49/67] Revert "temp: add session to queryservice" This reverts commit 7fc0d991b06273ef4d5118ea9aaa566f79794eeb. Signed-off-by: Mohamed Hamza --- go/vt/vtgate/scatter_conn.go | 6 ++---- go/vt/vttablet/queryservice/queryservice.go | 12 +++--------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 01f0922bbde..96fc5f066e4 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -221,11 +221,9 @@ func (stc *ScatterConn) ExecuteMultiShard( } } - qsSession := queryservice.Session{SessionUUID: session.GetSessionUUID()} - switch info.actionNeeded { case nothing: - innerqr, err = qs.Execute(ctx, rs.Target, qsSession, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts) + innerqr, err = qs.Execute(ctx, rs.Target, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts) if err != nil { retryRequest(func() { // we seem to have lost our connection. it was a reserved connection, let's try to recreate it @@ -879,7 +877,7 @@ func actionInfo(ctx context.Context, target *querypb.Target, session *econtext.S shouldReserve := session.InReservedConn() && (shardSession == nil || shardSession.ReservedId == 0) shouldBegin := session.InTransaction() && (shardSession == nil || shardSession.TransactionId == 0) && !autocommit - act := nothing + var act = nothing switch { case shouldBegin && shouldReserve: act = reserveBegin diff --git a/go/vt/vttablet/queryservice/queryservice.go b/go/vt/vttablet/queryservice/queryservice.go index ff9d93d6640..d6972bfb6a3 100644 --- a/go/vt/vttablet/queryservice/queryservice.go +++ b/go/vt/vttablet/queryservice/queryservice.go @@ -19,10 +19,10 @@ limitations under the License. package queryservice import ( - "context" - topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "context" + "vitess.io/vitess/go/sqltypes" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" @@ -80,7 +80,7 @@ type QueryService interface { UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) ([]*querypb.TransactionMetadata, error) // Execute for query execution - Execute(ctx context.Context, target *querypb.Target, session Session, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) + Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) // StreamExecute for query execution with streaming StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error @@ -148,9 +148,3 @@ type ReservedTransactionState struct { TabletAlias *topodatapb.TabletAlias SessionStateChanges string } - -// Session represents the current VTGate session. -type Session struct { - // SessionUUID is the UUID of the current session. - SessionUUID string -} From dc61e95c1bf99579bf23b63833c2187b2841724d Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Thu, 11 Dec 2025 12:03:09 -0500 Subject: [PATCH 50/67] remove random changelog file Signed-off-by: Mohamed Hamza --- changelog/24.0/24.0.0/changelog.md | 58 ------------------------------ 1 file changed, 58 deletions(-) delete mode 100644 changelog/24.0/24.0.0/changelog.md diff --git a/changelog/24.0/24.0.0/changelog.md b/changelog/24.0/24.0.0/changelog.md deleted file mode 100644 index 258e41e6366..00000000000 --- a/changelog/24.0/24.0.0/changelog.md +++ /dev/null @@ -1,58 +0,0 @@ -# Release of Vitess v24.0.0 -## Summary - -### Table of Contents - -- **[Major Changes](#major-changes)** - - **[New Support](#new-support)** - - [Window function pushdown for sharded keyspaces](#window-function-pushdown) -- **[Minor Changes](#minor-changes)** - - **[VTGate](#minor-changes-vtgate)** - - [New default for `--legacy-replication-lag-algorithm` flag](#vtgate-new-default-legacy-replication-lag-algorithm) - - [New "session" mode for `--vtgate-balancer-mode` flag](#vtgate-session-balancer-mode) - - **[VTTablet](#minor-changes-vttablet)** - - [New Experimental flag `--init-tablet-type-lookup`](#vttablet-init-tablet-type-lookup) - -## Major Changes - -### New Support - -#### Window function pushdown for sharded keyspaces - -This release introduces an optimization that allows window functions to be pushed down to individual shards when they are partitioned by a column that matches a unique vindex. - -Previously, all window function queries required single-shard routing, which limited their applicability on sharded tables. With this change, queries where the `PARTITION BY` clause aligns with a unique vindex can now be pushed down and executed on each shard. - -For examples and more details, see the [documentation](https://vitess.io/docs/24.0/reference/compatibility/mysql-compatibility/#window-functions). - -## Minor Changes - -### VTGate - -#### New default for `--legacy-replication-lag-algorithm` flag - -The VTGate flag `--legacy-replication-lag-algorithm` now defaults to `false`, disabling the legacy approach to handling replication lag by default. - -Instead, a simpler algorithm purely based on low lag, high lag and minimum number of tablets is used, which has proven to be more stable in many production environments. A detailed explanation of the two approaches [is explained in this code comment](https://github.com/vitessio/vitess/blob/main/go/vt/discovery/replicationlag.go#L125-L149). - -In v25 this flag will become deprecated and in the following release it will be removed. In the meantime, the legacy behaviour can be used by setting `--legacy-replication-lag-algorithm=true`. This deprecation is tracked in https://github.com/vitessio/vitess/issues/18914. - -#### New "session" mode for `--vtgate-balancer-mode` flag - -The VTGate flag `--vtgate-balancer-mode` now supports a new "session" mode in addition to the existing "cell", "prefer-cell", and "random" modes. Session mode routes each session consistently to the same tablet for the session's duration. - -To enable session mode, set the flag when starting VTGate: - -``` ---vtgate-balancer-mode=session -``` - -### VTTablet - -#### New Experimental flag `--init-tablet-type-lookup` - -The new experimental flag `--init-tablet-type-lookup` for VTTablet allows tablets to automatically restore their previous tablet type on restart by looking up the existing topology record, rather than always using the static `--init-tablet-type` value. - -When enabled, the tablet uses its alias to look up the tablet type from the existing topology record on restart. This allows tablets to maintain their changed roles (e.g., RDONLY/DRAINED) across restarts without manual reconfiguration. If disabled or if no topology record exists, the standard `--init-tablet-type` value will be used instead. - -**Note**: Vitess Operator–managed deployments generally do not keep tablet records in the topo between restarts, so this feature will not take effect in those environments. \ No newline at end of file From ee8f2b690ecc29e8d6f958c7e293040488ea548a Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 12:48:54 -0500 Subject: [PATCH 51/67] Update `QueryService` to include session Updates the `QueryService` interface to include a session, and removes the `options` parameter as it can be accessed through the session. Signed-off-by: Mohamed Hamza --- go/test/endtoend/cluster/cluster_process.go | 15 +- go/vt/proto/query/query.pb.go | 18 +- go/vt/proto/query/query_vtproto.pb.go | 46 -- go/vt/vtcombo/tablet_map.go | 37 +- go/vt/vtexplain/vtexplain_vttablet.go | 12 +- go/vt/vtgate/balancer/balancer.go | 1 + go/vt/vtgate/executor.go | 20 +- go/vt/vtgate/plugin_mysql_server.go | 1 - go/vt/vtgate/scatter_conn.go | 52 +- go/vt/vtgate/tabletgateway.go | 4 +- go/vt/vtgate/tabletgateway_flaky_test.go | 11 +- go/vt/vtgate/tabletgateway_test.go | 13 +- go/vt/vttablet/grpcqueryservice/server.go | 47 +- go/vt/vttablet/grpctabletconn/conn.go | 36 +- go/vt/vttablet/grpctabletconn/conn_test.go | 9 +- .../fakes/stream_health_query_service.go | 4 +- go/vt/vttablet/queryservice/queryservice.go | 31 +- go/vt/vttablet/queryservice/wrapped.go | 56 +-- go/vt/vttablet/sandboxconn/sandboxconn.go | 56 +-- .../tabletconntest/fakequeryservice.go | 53 +- .../vttablet/tabletconntest/tabletconntest.go | 45 +- .../vttablet/tabletmanager/framework_test.go | 18 +- go/vt/vttablet/tabletmanager/rpc_query.go | 3 +- go/vt/vttablet/tabletserver/bench_test.go | 11 +- .../vttablet/tabletserver/dt_executor_test.go | 3 +- .../tabletserver/query_executor_test.go | 458 +++++++++--------- go/vt/vttablet/tabletserver/tabletserver.go | 53 +- .../tabletserver/tabletserver_test.go | 147 +++--- go/vtbench/client.go | 3 +- proto/query.proto | 3 - web/vtadmin/src/proto/vtadmin.d.ts | 6 - web/vtadmin/src/proto/vtadmin.js | 23 - 32 files changed, 634 insertions(+), 661 deletions(-) diff --git a/go/test/endtoend/cluster/cluster_process.go b/go/test/endtoend/cluster/cluster_process.go index 3d95b39915e..5f21c749ec0 100644 --- a/go/test/endtoend/cluster/cluster_process.go +++ b/go/test/endtoend/cluster/cluster_process.go @@ -45,12 +45,14 @@ import ( "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/vtgateconn" "vitess.io/vitess/go/vt/vttablet/tabletconn" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" // Ensure dialers are registered (needed by ExecOnTablet and ExecOnVTGate). @@ -807,7 +809,7 @@ func NewBareCluster(cell string, hostname string) *LocalProcessCluster { // path/to/whatever exists cluster.ReusingVTDATAROOT = true } else { - err = createDirectory(cluster.CurrentVTDATAROOT, 0700) + err = createDirectory(cluster.CurrentVTDATAROOT, 0o700) if err != nil { log.Fatal(err) } @@ -948,11 +950,12 @@ func (cluster *LocalProcessCluster) ExecOnTablet(ctx context.Context, vttablet * txID, reservedID := 0, 0 - return conn.Execute(ctx, &querypb.Target{ + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: opts}) + return conn.Execute(ctx, session, &querypb.Target{ Keyspace: tablet.Keyspace, Shard: tablet.Shard, TabletType: tablet.Type, - }, sql, bindvars, int64(txID), int64(reservedID), opts) + }, sql, bindvars, int64(txID), int64(reservedID)) } // ExecOnVTGate executes a query on a local cluster VTGate with the provided @@ -1165,7 +1168,8 @@ func (cluster *LocalProcessCluster) waitForMySQLProcessToExit(mysqlctlProcessLis // StartVtbackup starts a vtbackup func (cluster *LocalProcessCluster) StartVtbackup(newInitDBFile string, initialBackup bool, - keyspace string, shard string, cell string, extraArgs ...string) error { + keyspace string, shard string, cell string, extraArgs ...string, +) error { log.Info("Starting vtbackup") cluster.VtbackupProcess = *VtbackupProcessInstance( cluster.GetAndReserveTabletUID(), @@ -1195,7 +1199,6 @@ func (cluster *LocalProcessCluster) GetAndReservePort() int { cluster.nextPortForProcess = cluster.nextPortForProcess + 1 log.Infof("Attempting to reserve port: %v", cluster.nextPortForProcess) ln, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(cluster.nextPortForProcess))) - if err != nil { log.Errorf("Can't listen on port %v: %s, trying next port", cluster.nextPortForProcess, err) continue @@ -1218,7 +1221,7 @@ const portFileTimeout = 1 * time.Hour // If yes, then return that port, and save port + 200 in the same file // here, assumptions is 200 ports might be consumed for all tests in a package func getPort() int { - portFile, err := os.OpenFile(path.Join(os.TempDir(), "endtoend.port"), os.O_CREATE|os.O_RDWR, 0644) + portFile, err := os.OpenFile(path.Join(os.TempDir(), "endtoend.port"), os.O_CREATE|os.O_RDWR, 0o644) if err != nil { panic(err) } diff --git a/go/vt/proto/query/query.pb.go b/go/vt/proto/query/query.pb.go index c82b5486922..accc9b305dd 100644 --- a/go/vt/proto/query/query.pb.go +++ b/go/vt/proto/query/query.pb.go @@ -1404,10 +1404,8 @@ type ExecuteOptions struct { InDmlExecution bool `protobuf:"varint,19,opt,name=in_dml_execution,json=inDmlExecution,proto3" json:"in_dml_execution,omitempty"` // transaction_timeout specifies the transaction timeout in milliseconds. If not set, the default timeout is used. TransactionTimeout *int64 `protobuf:"varint,20,opt,name=transaction_timeout,json=transactionTimeout,proto3,oneof" json:"transaction_timeout,omitempty"` - // SessionUUID is the UUID of the current session. - SessionUUID string `protobuf:"bytes,21,opt,name=SessionUUID,proto3" json:"SessionUUID,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExecuteOptions) Reset() { @@ -1561,13 +1559,6 @@ func (x *ExecuteOptions) GetTransactionTimeout() int64 { return 0 } -func (x *ExecuteOptions) GetSessionUUID() string { - if x != nil { - return x.SessionUUID - } - return "" -} - type isExecuteOptions_Timeout interface { isExecuteOptions_Timeout() } @@ -5835,7 +5826,7 @@ const file_query_proto_rawDesc = "" + "\x0ebind_variables\x18\x02 \x03(\v2$.query.BoundQuery.BindVariablesEntryR\rbindVariables\x1aU\n" + "\x12BindVariablesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12)\n" + - "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\xa5\r\n" + + "\x05value\x18\x02 \x01(\v2\x13.query.BindVariableR\x05value:\x028\x01\"\x83\r\n" + "\x0eExecuteOptions\x12M\n" + "\x0fincluded_fields\x18\x04 \x01(\x0e2$.query.ExecuteOptions.IncludedFieldsR\x0eincludedFields\x12*\n" + "\x11client_found_rows\x18\x05 \x01(\bR\x0fclientFoundRows\x12:\n" + @@ -5853,8 +5844,7 @@ const file_query_proto_rawDesc = "" + "\x15authoritative_timeout\x18\x11 \x01(\x03H\x00R\x14authoritativeTimeout\x12/\n" + "\x14fetch_last_insert_id\x18\x12 \x01(\bR\x11fetchLastInsertId\x12(\n" + "\x10in_dml_execution\x18\x13 \x01(\bR\x0einDmlExecution\x124\n" + - "\x13transaction_timeout\x18\x14 \x01(\x03H\x01R\x12transactionTimeout\x88\x01\x01\x12 \n" + - "\vSessionUUID\x18\x15 \x01(\tR\vSessionUUID\";\n" + + "\x13transaction_timeout\x18\x14 \x01(\x03H\x01R\x12transactionTimeout\x88\x01\x01\";\n" + "\x0eIncludedFields\x12\x11\n" + "\rTYPE_AND_NAME\x10\x00\x12\r\n" + "\tTYPE_ONLY\x10\x01\x12\a\n" + diff --git a/go/vt/proto/query/query_vtproto.pb.go b/go/vt/proto/query/query_vtproto.pb.go index 4e5589ee61a..3b32b52bf49 100644 --- a/go/vt/proto/query/query_vtproto.pb.go +++ b/go/vt/proto/query/query_vtproto.pb.go @@ -178,7 +178,6 @@ func (m *ExecuteOptions) CloneVT() *ExecuteOptions { r.Priority = m.Priority r.FetchLastInsertId = m.FetchLastInsertId r.InDmlExecution = m.InDmlExecution - r.SessionUUID = m.SessionUUID if rhs := m.TransactionAccessMode; rhs != nil { tmpContainer := make([]ExecuteOptions_TransactionAccessMode, len(rhs)) copy(tmpContainer, rhs) @@ -1901,15 +1900,6 @@ func (m *ExecuteOptions) MarshalToSizedBufferVT(dAtA []byte) (int, error) { } i -= size } - if len(m.SessionUUID) > 0 { - i -= len(m.SessionUUID) - copy(dAtA[i:], m.SessionUUID) - i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.SessionUUID))) - i-- - dAtA[i] = 0x1 - i-- - dAtA[i] = 0xaa - } if m.TransactionTimeout != nil { i = protohelpers.EncodeVarint(dAtA, i, uint64(*m.TransactionTimeout)) i-- @@ -6236,10 +6226,6 @@ func (m *ExecuteOptions) SizeVT() (n int) { if m.TransactionTimeout != nil { n += 2 + protohelpers.SizeOfVarint(uint64(*m.TransactionTimeout)) } - l = len(m.SessionUUID) - if l > 0 { - n += 2 + l + protohelpers.SizeOfVarint(uint64(l)) - } n += len(m.unknownFields) return n } @@ -9048,38 +9034,6 @@ func (m *ExecuteOptions) UnmarshalVT(dAtA []byte) error { } } m.TransactionTimeout = &v - case 21: - if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field SessionUUID", wireType) - } - var stringLen uint64 - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return protohelpers.ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - stringLen |= uint64(b&0x7F) << shift - if b < 0x80 { - break - } - } - intStringLen := int(stringLen) - if intStringLen < 0 { - return protohelpers.ErrInvalidLength - } - postIndex := iNdEx + intStringLen - if postIndex < 0 { - return protohelpers.ErrInvalidLength - } - if postIndex > l { - return io.ErrUnexpectedEOF - } - m.SessionUUID = string(dAtA[iNdEx:postIndex]) - iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) diff --git a/go/vt/vtcombo/tablet_map.go b/go/vt/vtcombo/tablet_map.go index 79b9cfb4e7f..45166d04665 100644 --- a/go/vt/vtcombo/tablet_map.go +++ b/go/vt/vtcombo/tablet_map.go @@ -460,14 +460,14 @@ var _ queryservice.QueryService = (*internalTabletConn)(nil) // We need to copy the bind variables as tablet server will change them. func (itc *internalTabletConn) Execute( ctx context.Context, + session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, - options *querypb.ExecuteOptions, ) (*sqltypes.Result, error) { bindVars = sqltypes.CopyBindVariables(bindVars) - reply, err := itc.tablet.qsc.QueryService().Execute(ctx, target, query, bindVars, transactionID, reservedID, options) + reply, err := itc.tablet.qsc.QueryService().Execute(ctx, session, target, query, bindVars, transactionID, reservedID) if err != nil { return nil, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } @@ -478,26 +478,26 @@ func (itc *internalTabletConn) Execute( // We need to copy the bind variables as tablet server will change them. func (itc *internalTabletConn) StreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) error { bindVars = sqltypes.CopyBindVariables(bindVars) - err := itc.tablet.qsc.QueryService().StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, callback) + err := itc.tablet.qsc.QueryService().StreamExecute(ctx, session, target, query, bindVars, transactionID, reservedID, callback) return tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } // Begin is part of queryservice.QueryService func (itc *internalTabletConn) Begin( ctx context.Context, + session queryservice.Session, target *querypb.Target, - options *querypb.ExecuteOptions, ) (queryservice.TransactionState, error) { - state, err := itc.tablet.qsc.QueryService().Begin(ctx, target, options) + state, err := itc.tablet.qsc.QueryService().Begin(ctx, session, target) return state, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } @@ -575,31 +575,31 @@ func (itc *internalTabletConn) UnresolvedTransactions(ctx context.Context, targe // BeginExecute is part of queryservice.QueryService func (itc *internalTabletConn) BeginExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reserveID int64, - options *querypb.ExecuteOptions, ) (queryservice.TransactionState, *sqltypes.Result, error) { bindVars = sqltypes.CopyBindVariables(bindVars) - state, result, err := itc.tablet.qsc.QueryService().BeginExecute(ctx, target, preQueries, query, bindVars, reserveID, options) + state, result, err := itc.tablet.qsc.QueryService().BeginExecute(ctx, session, target, preQueries, query, bindVars, reserveID) return state, result, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } // BeginStreamExecute is part of queryservice.QueryService func (itc *internalTabletConn) BeginStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (queryservice.TransactionState, error) { bindVars = sqltypes.CopyBindVariables(bindVars) - state, err := itc.tablet.qsc.QueryService().BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) + state, err := itc.tablet.qsc.QueryService().BeginStreamExecute(ctx, session, target, preQueries, query, bindVars, reservedID, callback) return state, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } @@ -622,62 +622,62 @@ func (itc *internalTabletConn) HandlePanic(err *error) { // ReserveBeginExecute is part of the QueryService interface. func (itc *internalTabletConn) ReserveBeginExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, - options *querypb.ExecuteOptions, ) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { bindVariables = sqltypes.CopyBindVariables(bindVariables) - state, result, err := itc.tablet.qsc.QueryService().ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) + state, result, err := itc.tablet.qsc.QueryService().ReserveBeginExecute(ctx, session, target, preQueries, postBeginQueries, sql, bindVariables) return state, result, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } // ReserveBeginStreamExecute is part of the QueryService interface. func (itc *internalTabletConn) ReserveBeginStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (queryservice.ReservedTransactionState, error) { bindVariables = sqltypes.CopyBindVariables(bindVariables) - state, err := itc.tablet.qsc.QueryService().ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) + state, err := itc.tablet.qsc.QueryService().ReserveBeginStreamExecute(ctx, session, target, preQueries, postBeginQueries, sql, bindVariables, callback) return state, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } // ReserveExecute is part of the QueryService interface. func (itc *internalTabletConn) ReserveExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, - options *querypb.ExecuteOptions, ) (queryservice.ReservedState, *sqltypes.Result, error) { bindVariables = sqltypes.CopyBindVariables(bindVariables) - state, result, err := itc.tablet.qsc.QueryService().ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) + state, result, err := itc.tablet.qsc.QueryService().ReserveExecute(ctx, session, target, preQueries, sql, bindVariables, transactionID) return state, result, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } // ReserveStreamExecute is part of the QueryService interface. func (itc *internalTabletConn) ReserveStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (queryservice.ReservedState, error) { bindVariables = sqltypes.CopyBindVariables(bindVariables) - state, err := itc.tablet.qsc.QueryService().ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) + state, err := itc.tablet.qsc.QueryService().ReserveStreamExecute(ctx, session, target, preQueries, sql, bindVariables, transactionID, callback) return state, tabletconn.ErrorFromGRPC(vterrors.ToGRPC(err)) } @@ -1103,6 +1103,7 @@ func (itmc *internalTabletManagerClient) ResetReplicationParameters(context.Cont func (itmc *internalTabletManagerClient) ReplicaWasRestarted(context.Context, *topodatapb.Tablet, *topodatapb.TabletAlias) error { return errors.New("not implemented in vtcombo") } + func (itmc *internalTabletManagerClient) ResetSequences(ctx context.Context, tablet *topodatapb.Tablet, tables []string) error { return errors.New("not implemented in vtcombo") } diff --git a/go/vt/vtexplain/vtexplain_vttablet.go b/go/vt/vtexplain/vtexplain_vttablet.go index 19bcd2d7a0b..6dd2fc250a9 100644 --- a/go/vt/vtexplain/vtexplain_vttablet.go +++ b/go/vt/vtexplain/vtexplain_vttablet.go @@ -155,7 +155,7 @@ func (vte *VTExplain) newTablet(ctx context.Context, env *vtenv.Environment, opt var _ queryservice.QueryService = (*explainTablet)(nil) // compile-time interface check // Begin is part of the QueryService interface. -func (t *explainTablet) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (t *explainTablet) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (queryservice.TransactionState, error) { t.mu.Lock() t.currentTime = t.vte.batchTime.Wait() t.tabletQueries = append(t.tabletQueries, &TabletQuery{ @@ -165,7 +165,7 @@ func (t *explainTablet) Begin(ctx context.Context, target *querypb.Target, optio t.mu.Unlock() - return t.tsv.Begin(ctx, target, options) + return t.tsv.Begin(ctx, session, target) } // Commit is part of the QueryService interface. @@ -190,7 +190,7 @@ func (t *explainTablet) Rollback(ctx context.Context, target *querypb.Target, tr } // Execute is part of the QueryService interface. -func (t *explainTablet) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (t *explainTablet) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { t.mu.Lock() t.currentTime = t.vte.batchTime.Wait() @@ -204,7 +204,7 @@ func (t *explainTablet) Execute(ctx context.Context, target *querypb.Target, sql }) t.mu.Unlock() - return t.tsv.Execute(ctx, target, sql, bindVariables, transactionID, reservedID, options) + return t.tsv.Execute(ctx, session, target, sql, bindVariables, transactionID, reservedID) } // Prepare is part of the QueryService interface. @@ -264,7 +264,7 @@ func (t *explainTablet) ReadTransaction(ctx context.Context, target *querypb.Tar } // BeginExecute is part of the QueryService interface. -func (t *explainTablet) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { +func (t *explainTablet) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { t.mu.Lock() t.currentTime = t.vte.batchTime.Wait() bindVariables = sqltypes.CopyBindVariables(bindVariables) @@ -275,7 +275,7 @@ func (t *explainTablet) BeginExecute(ctx context.Context, target *querypb.Target }) t.mu.Unlock() - return t.tsv.BeginExecute(ctx, target, preQueries, sql, bindVariables, reservedID, options) + return t.tsv.BeginExecute(ctx, session, target, preQueries, sql, bindVariables, reservedID) } // Close is part of the QueryService interface. diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index ecdc06256da..a785c6f0382 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ +// Package balancer contains a number of different tablet balancing algorithms and their implementations. package balancer import ( diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 4c1fbedb447..e3ed264f44b 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -155,12 +155,16 @@ type ( var executorOnce sync.Once -const pathQueryPlans = "/debug/query_plans" -const pathScatterStats = "/debug/scatter_stats" -const pathVSchema = "/debug/vschema" +const ( + pathQueryPlans = "/debug/query_plans" + pathScatterStats = "/debug/scatter_stats" + pathVSchema = "/debug/vschema" +) -type PlanCacheKey = theine.HashKey256 -type PlanCache = theine.Store[PlanCacheKey, *engine.Plan] +type ( + PlanCacheKey = theine.HashKey256 + PlanCache = theine.Store[PlanCacheKey, *engine.Plan] +) func DefaultPlanCache() *PlanCache { // when being endtoend tested, disable the doorkeeper to ensure reproducible results @@ -365,7 +369,6 @@ func (e *Executor) StreamExecute( err := vc.StreamExecutePrimitive(ctx, plan.Instructions, bindVars, true, func(qr *sqltypes.Result) error { return srr.storeResultStats(plan.QueryType, qr) }) - // Check if there was partial DML execution. If so, rollback the effect of the partially executed query. if err != nil { if safeSession.InTransaction() && e.rollbackOnFatalTxError(ctx, safeSession, err) { @@ -1009,7 +1012,7 @@ func (e *Executor) ShowVitessReplicationStatus(ctx context.Context, filter *sqlp replicaSQLRunningField = "Slave_SQL_Running" secondsBehindSourceField = "Seconds_Behind_Master" } - results, err := e.txConn.tabletGateway.Execute(ctx, ts.Target, sql, nil, 0, 0, nil) + results, err := e.txConn.tabletGateway.Execute(ctx, econtext.NewSafeSession(nil), ts.Target, sql, nil, 0, 0) if err != nil || results == nil { log.Warningf("Could not get replication status from %s: %v", tabletHostPort, err) } else if row := results.Named().Row(); row != nil { @@ -1119,7 +1122,8 @@ func (e *Executor) fetchOrCreatePlan( logStats *logstats.LogStats, isExecutePath bool, // this means we are trying to execute the query - this is not a PREPARE call ) ( - plan *engine.Plan, vcursor *econtext.VCursorImpl, stmt sqlparser.Statement, err error) { + plan *engine.Plan, vcursor *econtext.VCursorImpl, stmt sqlparser.Statement, err error, +) { if e.VSchema() == nil { return nil, nil, nil, vterrors.VT13001("vschema not initialized") } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index a5f9c07edf8..cfeb014d588 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -522,7 +522,6 @@ func (vh *vtgateHandler) session(c *mysql.Conn) *vtgatepb.Session { Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, Workload: querypb.ExecuteOptions_Workload(mysqlDefaultWorkload), - SessionUUID: u.String(), // The collation field of ExecuteOption is set right before an execution. }, diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 96fc5f066e4..7fde259f1fd 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -180,19 +180,15 @@ func (stc *ScatterConn) ExecuteMultiShard( var ( innerqr *sqltypes.Result err error - opts *querypb.ExecuteOptions alias *topodatapb.TabletAlias qs queryservice.QueryService ) transactionID := info.transactionID reservedID := info.reservedID - if session != nil && session.Session != nil { - opts = session.Session.Options - } - - if opts == nil && fetchLastInsertID { - opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} + if session.GetOptions() == nil && fetchLastInsertID { + session = econtext.NewSafeSession(session.Session) + session.SetOptions(&querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID}) } if autocommit { @@ -223,21 +219,21 @@ func (stc *ScatterConn) ExecuteMultiShard( switch info.actionNeeded { case nothing: - innerqr, err = qs.Execute(ctx, rs.Target, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts) + innerqr, err = qs.Execute(ctx, session, rs.Target, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID) if err != nil { retryRequest(func() { // we seem to have lost our connection. it was a reserved connection, let's try to recreate it info.actionNeeded = reserve info.ignoreOldSession = true var state queryservice.ReservedState - state, innerqr, err = qs.ReserveExecute(ctx, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, 0 /*transactionId*/, opts) + state, innerqr, err = qs.ReserveExecute(ctx, session, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, 0 /*transactionId*/) reservedID = state.ReservedID alias = state.TabletAlias }) } case begin: var state queryservice.TransactionState - state, innerqr, err = qs.BeginExecute(ctx, rs.Target, session.SavePoints(), queries[i].Sql, queries[i].BindVariables, reservedID, opts) + state, innerqr, err = qs.BeginExecute(ctx, session, rs.Target, session.SavePoints(), queries[i].Sql, queries[i].BindVariables, reservedID) transactionID = state.TransactionID alias = state.TabletAlias if err != nil { @@ -246,7 +242,7 @@ func (stc *ScatterConn) ExecuteMultiShard( info.actionNeeded = reserveBegin info.ignoreOldSession = true var state queryservice.ReservedTransactionState - state, innerqr, err = qs.ReserveBeginExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), queries[i].Sql, queries[i].BindVariables, opts) + state, innerqr, err = qs.ReserveBeginExecute(ctx, session, rs.Target, session.SetPreQueries(), session.SavePoints(), queries[i].Sql, queries[i].BindVariables) transactionID = state.TransactionID reservedID = state.ReservedID alias = state.TabletAlias @@ -254,12 +250,12 @@ func (stc *ScatterConn) ExecuteMultiShard( } case reserve: var state queryservice.ReservedState - state, innerqr, err = qs.ReserveExecute(ctx, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, transactionID, opts) + state, innerqr, err = qs.ReserveExecute(ctx, session, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, transactionID) reservedID = state.ReservedID alias = state.TabletAlias case reserveBegin: var state queryservice.ReservedTransactionState - state, innerqr, err = qs.ReserveBeginExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), queries[i].Sql, queries[i].BindVariables, opts) + state, innerqr, err = qs.ReserveBeginExecute(ctx, session, rs.Target, session.SetPreQueries(), session.SavePoints(), queries[i].Sql, queries[i].BindVariables) transactionID = state.TransactionID reservedID = state.ReservedID alias = state.TabletAlias @@ -411,19 +407,15 @@ func (stc *ScatterConn) StreamExecuteMulti( func(rs *srvtopo.ResolvedShard, i int, info *shardActionInfo) (*shardActionInfo, error) { var ( err error - opts *querypb.ExecuteOptions alias *topodatapb.TabletAlias qs queryservice.QueryService ) transactionID := info.transactionID reservedID := info.reservedID - if session != nil && session.Session != nil { - opts = session.Session.Options - } - - if opts == nil && fetchLastInsertID { - opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} + if session.GetOptions() == nil && fetchLastInsertID { + session = econtext.NewSafeSession(session.Session) + session.SetOptions(&querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID}) } if autocommit { @@ -454,20 +446,20 @@ func (stc *ScatterConn) StreamExecuteMulti( switch info.actionNeeded { case nothing: - err = qs.StreamExecute(ctx, rs.Target, query, bindVars[i], transactionID, reservedID, opts, observedCallback) + err = qs.StreamExecute(ctx, session, rs.Target, query, bindVars[i], transactionID, reservedID, observedCallback) if err != nil { retryRequest(func() { // we seem to have lost our connection. it was a reserved connection, let's try to recreate it info.actionNeeded = reserve var state queryservice.ReservedState - state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, opts, observedCallback) + state, err = qs.ReserveStreamExecute(ctx, session, rs.Target, session.SetPreQueries(), query, bindVars[i], 0 /*transactionId*/, observedCallback) reservedID = state.ReservedID alias = state.TabletAlias }) } case begin: var state queryservice.TransactionState - state, err = qs.BeginStreamExecute(ctx, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, opts, observedCallback) + state, err = qs.BeginStreamExecute(ctx, session, rs.Target, session.SavePoints(), query, bindVars[i], reservedID, observedCallback) transactionID = state.TransactionID alias = state.TabletAlias if err != nil { @@ -475,7 +467,7 @@ func (stc *ScatterConn) StreamExecuteMulti( // we seem to have lost our connection. it was a reserved connection, let's try to recreate it info.actionNeeded = reserveBegin var state queryservice.ReservedTransactionState - state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback) + state, err = qs.ReserveBeginStreamExecute(ctx, session, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], observedCallback) transactionID = state.TransactionID reservedID = state.ReservedID alias = state.TabletAlias @@ -483,12 +475,12 @@ func (stc *ScatterConn) StreamExecuteMulti( } case reserve: var state queryservice.ReservedState - state, err = qs.ReserveStreamExecute(ctx, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, opts, observedCallback) + state, err = qs.ReserveStreamExecute(ctx, session, rs.Target, session.SetPreQueries(), query, bindVars[i], transactionID, observedCallback) reservedID = state.ReservedID alias = state.TabletAlias case reserveBegin: var state queryservice.ReservedTransactionState - state, err = qs.ReserveBeginStreamExecute(ctx, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], opts, observedCallback) + state, err = qs.ReserveBeginStreamExecute(ctx, session, rs.Target, session.SetPreQueries(), session.SavePoints(), query, bindVars[i], observedCallback) transactionID = state.TransactionID reservedID = state.ReservedID alias = state.TabletAlias @@ -759,7 +751,6 @@ func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedSha var ( qr *sqltypes.Result err error - opts *querypb.ExecuteOptions alias *topodatapb.TabletAlias ) allErrors := new(concurrency.AllErrorRecorder) @@ -770,7 +761,6 @@ func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedSha return nil, vterrors.VT13001("session cannot be nil") } - opts = session.Session.Options info, err := lockInfo(rs.Target, session, lockFuncType) // Lock session is created on alphabetic sorted keyspace. // This error will occur if the existing session target does not match the current target. @@ -788,7 +778,7 @@ func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedSha switch info.actionNeeded { case nothing: - qr, err = qs.Execute(ctx, rs.Target, query.Sql, query.BindVariables, 0 /* transactionID */, reservedID, opts) + qr, err = qs.Execute(ctx, session, rs.Target, query.Sql, query.BindVariables, 0 /* transactionID */, reservedID) if err != nil && wasConnectionClosed(err) { // TODO: try to acquire lock again. session.ResetLock() @@ -799,7 +789,7 @@ func (stc *ScatterConn) ExecuteLock(ctx context.Context, rs *srvtopo.ResolvedSha } case reserve: var state queryservice.ReservedState - state, qr, err = qs.ReserveExecute(ctx, rs.Target, session.SetPreQueries(), query.Sql, query.BindVariables, 0 /* transactionID */, opts) + state, qr, err = qs.ReserveExecute(ctx, session, rs.Target, session.SetPreQueries(), query.Sql, query.BindVariables, 0 /* transactionID */) reservedID = state.ReservedID alias = state.TabletAlias if err != nil && reservedID != 0 { @@ -877,7 +867,7 @@ func actionInfo(ctx context.Context, target *querypb.Target, session *econtext.S shouldReserve := session.InReservedConn() && (shardSession == nil || shardSession.ReservedId == 0) shouldBegin := session.InTransaction() && (shardSession == nil || shardSession.TransactionId == 0) && !autocommit - var act = nothing + act := nothing switch { case shouldBegin && shouldReserve: act = reserveBegin diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index 077fd28c17e..c684b2d0d41 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -464,8 +464,8 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // Get the tablet from the balancer if enabled if useBalancer { var sessionUUID string - if opts.Options != nil { - sessionUUID = opts.Options.SessionUUID + if opts.Session != nil { + sessionUUID = opts.Session.GetSessionUUID() } tablet := gw.balancer.Pick(target, tablets, balancer.PickOpts{SessionUUID: sessionUUID}) diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index 124997bea9e..da0a7c7cd9f 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -32,6 +32,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) // TestGatewayBufferingWhenPrimarySwitchesServingState is used to test that the buffering mechanism buffers the queries when a primary goes to a non serving state and @@ -94,7 +95,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { sbc.SetResults([]*sqltypes.Result{sqlResult1}) // run a query that we indeed get the result added to the sandbox connection back - res, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil) + res, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) require.NoError(t, err) require.Equal(t, res, sqlResult1) @@ -114,7 +115,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { // execute the query in a go routine since it should be buffered, and check that it eventually succeed queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, target, "query", nil, 0, 0, nil) + res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) queryChan <- struct{}{} }() @@ -186,7 +187,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // run a query that we indeed get the result added to the sandbox connection back // this also checks that the query reaches the primary tablet and not the replica - res, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil) + res, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) require.NoError(t, err) require.Equal(t, res, sqlResult1) @@ -224,7 +225,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // execute the query in a go routine since it should be buffered, and check that it eventually succeed queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, target, "query", nil, 0, 0, nil) + res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) queryChan <- struct{}{} }() @@ -332,7 +333,7 @@ func TestInconsistentStateDetectedBuffering(t *testing.T) { var err error queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, target, "query", nil, 0, 0, nil) + res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) queryChan <- struct{}{} }() diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index ace5e4a6081..6b92febf83a 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -34,6 +34,7 @@ import ( "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vterrors" @@ -43,14 +44,14 @@ import ( func TestTabletGatewayExecute(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil) + _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { assert.Equal(t, want, sc.ExecCount.Load()) }) testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil) + _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 1, 0) return err }) } @@ -58,7 +59,7 @@ func TestTabletGatewayExecute(t *testing.T) { func TestTabletGatewayExecuteStream(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - err := tg.StreamExecute(ctx, target, "query", nil, 0, 0, nil, func(qr *sqltypes.Result) error { + err := tg.StreamExecute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0, func(qr *sqltypes.Result) error { return nil }) return err @@ -71,7 +72,7 @@ func TestTabletGatewayExecuteStream(t *testing.T) { func TestTabletGatewayBegin(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Begin(ctx, target, nil) + _, err := tg.Begin(ctx, &vtgatepb.Session{}, target) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { @@ -98,7 +99,7 @@ func TestTabletGatewayRollback(t *testing.T) { func TestTabletGatewayBeginExecute(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, _, err := tg.BeginExecute(ctx, target, nil, "query", nil, 0, nil) + _, _, err := tg.BeginExecute(ctx, &vtgatepb.Session{}, target, nil, "query", nil, 0) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { @@ -190,7 +191,7 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) { defer tg.Close(ctx) _ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) - _, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil) + _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 1, 0) verifyContainsError(t, err, "query service can only be used for non-transactional queries on replicas", vtrpcpb.Code_INTERNAL) } diff --git a/go/vt/vttablet/grpcqueryservice/server.go b/go/vt/vttablet/grpcqueryservice/server.go index e3c179ce856..ced6dcea0bd 100644 --- a/go/vt/vttablet/grpcqueryservice/server.go +++ b/go/vt/vttablet/grpcqueryservice/server.go @@ -25,11 +25,13 @@ import ( "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/callinfo" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/queryservice" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) // query is the gRPC query service implementation. @@ -48,7 +50,10 @@ func (q *query) Execute(ctx context.Context, request *querypb.ExecuteRequest) (r request.EffectiveCallerId, request.ImmediateCallerId, ) - result, err := q.server.Execute(ctx, request.Target, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.ReservedId, request.Options) + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + result, err := q.server.Execute(ctx, session, request.Target, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.ReservedId) if err != nil { return nil, vterrors.ToGRPC(err) } @@ -64,7 +69,10 @@ func (q *query) StreamExecute(request *querypb.StreamExecuteRequest, stream quer request.EffectiveCallerId, request.ImmediateCallerId, ) - err = q.server.StreamExecute(ctx, request.Target, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.ReservedId, request.Options, func(reply *sqltypes.Result) error { + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + err = q.server.StreamExecute(ctx, session, request.Target, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.ReservedId, func(reply *sqltypes.Result) error { return stream.Send(&querypb.StreamExecuteResponse{ Result: sqltypes.ResultToProto3(reply), }) @@ -79,7 +87,10 @@ func (q *query) Begin(ctx context.Context, request *querypb.BeginRequest) (respo request.EffectiveCallerId, request.ImmediateCallerId, ) - state, err := q.server.Begin(ctx, request.Target, request.Options) + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, err := q.server.Begin(ctx, session, request.Target) if err != nil { return nil, vterrors.ToGRPC(err) } @@ -252,7 +263,10 @@ func (q *query) BeginExecute(ctx context.Context, request *querypb.BeginExecuteR request.EffectiveCallerId, request.ImmediateCallerId, ) - state, result, err := q.server.BeginExecute(ctx, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.ReservedId, request.Options) + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, result, err := q.server.BeginExecute(ctx, session, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.ReservedId) if err != nil { // if we have a valid transactionID, return the error in-band if state.TransactionID != 0 { @@ -279,7 +293,10 @@ func (q *query) BeginStreamExecute(request *querypb.BeginStreamExecuteRequest, s request.EffectiveCallerId, request.ImmediateCallerId, ) - state, err := q.server.BeginStreamExecute(ctx, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.ReservedId, request.Options, func(reply *sqltypes.Result) error { + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, err := q.server.BeginStreamExecute(ctx, session, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.ReservedId, func(reply *sqltypes.Result) error { return stream.Send(&querypb.BeginStreamExecuteResponse{ Result: sqltypes.ResultToProto3(reply), }) @@ -392,7 +409,10 @@ func (q *query) ReserveExecute(ctx context.Context, request *querypb.ReserveExec request.EffectiveCallerId, request.ImmediateCallerId, ) - state, result, err := q.server.ReserveExecute(ctx, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.Options) + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, result, err := q.server.ReserveExecute(ctx, session, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.TransactionId) if err != nil { // if we have a valid reservedID, return the error in-band if state.ReservedID != 0 { @@ -418,7 +438,10 @@ func (q *query) ReserveStreamExecute(request *querypb.ReserveStreamExecuteReques request.EffectiveCallerId, request.ImmediateCallerId, ) - state, err := q.server.ReserveStreamExecute(ctx, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.TransactionId, request.Options, func(reply *sqltypes.Result) error { + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, err := q.server.ReserveStreamExecute(ctx, session, request.Target, request.PreQueries, request.Query.Sql, request.Query.BindVariables, request.TransactionId, func(reply *sqltypes.Result) error { return stream.Send(&querypb.ReserveStreamExecuteResponse{ Result: sqltypes.ResultToProto3(reply), }) @@ -442,7 +465,10 @@ func (q *query) ReserveBeginExecute(ctx context.Context, request *querypb.Reserv request.EffectiveCallerId, request.ImmediateCallerId, ) - state, result, err := q.server.ReserveBeginExecute(ctx, request.Target, request.PreQueries, request.PostBeginQueries, request.Query.Sql, request.Query.BindVariables, request.Options) + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, result, err := q.server.ReserveBeginExecute(ctx, session, request.Target, request.PreQueries, request.PostBeginQueries, request.Query.Sql, request.Query.BindVariables) if err != nil { // if we have a valid reservedID or transactionID, return the error in-band if state.TransactionID != 0 || state.ReservedID != 0 { @@ -472,7 +498,10 @@ func (q *query) ReserveBeginStreamExecute(request *querypb.ReserveBeginStreamExe request.EffectiveCallerId, request.ImmediateCallerId, ) - state, err := q.server.ReserveBeginStreamExecute(ctx, request.Target, request.PreQueries, request.PostBeginQueries, request.Query.Sql, request.Query.BindVariables, request.Options, func(reply *sqltypes.Result) error { + + session := executorcontext.NewSafeSession(&vtgatepb.Session{Options: request.Options}) + + state, err := q.server.ReserveBeginStreamExecute(ctx, session, request.Target, request.PreQueries, request.PostBeginQueries, request.Query.Sql, request.Query.BindVariables, func(reply *sqltypes.Result) error { return stream.Send(&querypb.ReserveBeginStreamExecuteResponse{ Result: sqltypes.ResultToProto3(reply), }) diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index 0ef94031cf3..0bf73bff7a0 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -112,7 +112,7 @@ func DialTablet(ctx context.Context, tablet *topodatapb.Tablet, failFast grpccli } // Execute sends the query to VTTablet. -func (conn *gRPCQueryClient) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (conn *gRPCQueryClient) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -128,7 +128,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, target *querypb.Target BindVariables: bindVars, }, TransactionId: transactionID, - Options: options, + Options: session.GetOptions(), ReservedId: reservedID, } er, err := conn.c.Execute(ctx, req) @@ -139,7 +139,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, target *querypb.Target } // StreamExecute executes the query and streams results back through callback. -func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { +func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error { // All streaming clients should follow the code pattern below. // The first part of the function starts the stream while holding // a lock on conn.mu. The second part receives the data and calls @@ -166,7 +166,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, target *querypb. Sql: query, BindVariables: bindVars, }, - Options: options, + Options: session.GetOptions(), TransactionId: transactionID, ReservedId: reservedID, } @@ -198,7 +198,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, target *querypb. } // Begin starts a transaction. -func (conn *gRPCQueryClient) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state queryservice.TransactionState, err error) { +func (conn *gRPCQueryClient) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (state queryservice.TransactionState, err error) { conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -209,7 +209,7 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context, target *querypb.Target, Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: options, + Options: session.GetOptions(), } br, err := conn.c.Begin(ctx, req) if err != nil { @@ -463,7 +463,7 @@ func (conn *gRPCQueryClient) UnresolvedTransactions(ctx context.Context, target } // BeginExecute starts a transaction and runs an Execute. -func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state queryservice.TransactionState, result *sqltypes.Result, err error) { +func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64) (state queryservice.TransactionState, result *sqltypes.Result, err error) { conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -480,7 +480,7 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T BindVariables: bindVars, }, ReservedId: reservedID, - Options: options, + Options: session.GetOptions(), } reply, err := conn.c.BeginExecute(ctx, req) if err != nil { @@ -496,7 +496,7 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, target *querypb.T } // BeginStreamExecute starts a transaction and runs an Execute. -func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) { +func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (state queryservice.TransactionState, err error) { // Please see comments in StreamExecute to see how this works. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -524,7 +524,7 @@ func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, target *que BindVariables: bindVars, }, ReservedId: reservedID, - Options: options, + Options: session.GetOptions(), } stream, err := conn.c.BeginStreamExecute(ctx, req) if err != nil { @@ -861,7 +861,7 @@ func (conn *gRPCQueryClient) HandlePanic(err *error) { } // ReserveBeginExecute implements the queryservice interface -func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { +func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -872,7 +872,7 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: options, + Options: session.GetOptions(), PreQueries: preQueries, PostBeginQueries: postBeginQueries, Query: &querypb.BoundQuery{ @@ -896,7 +896,7 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, target *qu } // ReserveBeginStreamExecute implements the queryservice interface -func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) { +func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (state queryservice.ReservedTransactionState, err error) { // Please see comments in StreamExecute to see how this works. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -917,7 +917,7 @@ func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, targ Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: options, + Options: session.GetOptions(), PreQueries: preQueries, PostBeginQueries: postBeginQueries, Query: &querypb.BoundQuery{ @@ -977,7 +977,7 @@ func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, targ } // ReserveExecute implements the queryservice interface -func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state queryservice.ReservedState, result *sqltypes.Result, err error) { +func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (state queryservice.ReservedState, result *sqltypes.Result, err error) { conn.mu.RLock() defer conn.mu.RUnlock() if conn.cc == nil { @@ -993,7 +993,7 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb BindVariables: bindVariables, }, TransactionId: transactionID, - Options: options, + Options: session.GetOptions(), PreQueries: preQueries, } reply, err := conn.c.ReserveExecute(ctx, req) @@ -1010,7 +1010,7 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, target *querypb } // ReserveStreamExecute implements the queryservice interface -func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) { +func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (state queryservice.ReservedState, err error) { // Please see comments in StreamExecute to see how this works. ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -1031,7 +1031,7 @@ func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, target *q Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: options, + Options: session.GetOptions(), PreQueries: preQueries, Query: &querypb.BoundQuery{ Sql: sql, diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index e20cf36a797..c7e104660d1 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -32,6 +32,7 @@ import ( binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" "vitess.io/vitess/go/vt/vttablet/tabletconntest" @@ -191,22 +192,22 @@ func TestGoRoutineLeakPrevention(t *testing.T) { cc: &grpc.ClientConn{}, c: mqc, } - _ = qc.StreamExecute(context.Background(), nil, "", nil, 0, 0, nil, func(result *sqltypes.Result) error { + _ = qc.StreamExecute(context.Background(), &vtgatepb.Session{}, nil, "", nil, 0, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.BeginStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + _, _ = qc.BeginStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, "", nil, nil, func(result *sqltypes.Result) error { + _, _ = qc.ReserveBeginStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, nil, "", nil, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, "", nil, 0, nil, func(result *sqltypes.Result) error { + _, _ = qc.ReserveStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) diff --git a/go/vt/vttablet/queryservice/fakes/stream_health_query_service.go b/go/vt/vttablet/queryservice/fakes/stream_health_query_service.go index e992c12baca..7d259503eee 100644 --- a/go/vt/vttablet/queryservice/fakes/stream_health_query_service.go +++ b/go/vt/vttablet/queryservice/fakes/stream_health_query_service.go @@ -56,12 +56,12 @@ func NewStreamHealthQueryService(target *querypb.Target) *StreamHealthQueryServi } // Begin implemented as a no op -func (q *StreamHealthQueryService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (q *StreamHealthQueryService) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (queryservice.TransactionState, error) { return queryservice.TransactionState{}, nil } // Execute implemented as a no op -func (q *StreamHealthQueryService) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (q *StreamHealthQueryService) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { return &sqltypes.Result{}, nil } diff --git a/go/vt/vttablet/queryservice/queryservice.go b/go/vt/vttablet/queryservice/queryservice.go index d6972bfb6a3..75d09b1c05f 100644 --- a/go/vt/vttablet/queryservice/queryservice.go +++ b/go/vt/vttablet/queryservice/queryservice.go @@ -19,16 +19,25 @@ limitations under the License. package queryservice import ( - topodatapb "vitess.io/vitess/go/vt/proto/topodata" - "context" + topodatapb "vitess.io/vitess/go/vt/proto/topodata" + "vitess.io/vitess/go/sqltypes" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" ) +// Session represents the current session. +type Session interface { + // GetSessionUUID returns the session's UUID. + GetSessionUUID() string + + // GetOptions returns the current execute options. + GetOptions() *querypb.ExecuteOptions +} + // QueryService is the interface implemented by the tablet's query service. // All streaming methods accept a callback function that will be called for // each response. If the callback returns an error, that error is returned @@ -41,7 +50,7 @@ type QueryService interface { // Transaction management // Begin returns the transaction id to use for further operations - Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (TransactionState, error) + Begin(ctx context.Context, session Session, target *querypb.Target) (TransactionState, error) // Commit commits the current transaction Commit(ctx context.Context, target *querypb.Target, transactionID int64) (int64, error) @@ -80,16 +89,16 @@ type QueryService interface { UnresolvedTransactions(ctx context.Context, target *querypb.Target, abandonAgeSeconds int64) ([]*querypb.TransactionMetadata, error) // Execute for query execution - Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) + Execute(ctx context.Context, session Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) // StreamExecute for query execution with streaming - StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error + StreamExecute(ctx context.Context, session Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error // Combo methods, they also return the transactionID from the // Begin part. If err != nil, the transactionID may still be // non-zero, and needs to be propagated back (like for a DB // Integrity Error) - BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (TransactionState, *sqltypes.Result, error) - BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (TransactionState, error) + BeginExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (TransactionState, *sqltypes.Result, error) + BeginStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (TransactionState, error) // Messaging methods. MessageStream(ctx context.Context, target *querypb.Target, name string, callback func(*sqltypes.Result) error) error @@ -114,13 +123,13 @@ type QueryService interface { // HandlePanic will be called if any of the functions panic. HandlePanic(err *error) - ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (ReservedTransactionState, *sqltypes.Result, error) + ReserveBeginExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (ReservedTransactionState, *sqltypes.Result, error) - ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (ReservedTransactionState, error) + ReserveBeginStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (ReservedTransactionState, error) - ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (ReservedState, *sqltypes.Result, error) + ReserveExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (ReservedState, *sqltypes.Result, error) - ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (ReservedState, error) + ReserveStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (ReservedState, error) Release(ctx context.Context, target *querypb.Target, transactionID, reservedID int64) error diff --git a/go/vt/vttablet/queryservice/wrapped.go b/go/vt/vttablet/queryservice/wrapped.go index 1837fef1af8..4534170d399 100644 --- a/go/vt/vttablet/queryservice/wrapped.go +++ b/go/vt/vttablet/queryservice/wrapped.go @@ -40,7 +40,7 @@ type WrapperFunc func(ctx context.Context, target *querypb.Target, conn QuerySer type WrapOpts struct { InTransaction bool - Options *querypb.ExecuteOptions + Session Session } // Wrap returns a wrapped version of the original QueryService implementation. @@ -116,11 +116,11 @@ type wrappedService struct { wrapper WrapperFunc } -func (ws *wrappedService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state TransactionState, err error) { - opts := WrapOpts{InTransaction: false, Options: options} +func (ws *wrappedService) Begin(ctx context.Context, session Session, target *querypb.Target) (state TransactionState, err error) { + opts := WrapOpts{InTransaction: false, Session: session} err = ws.wrapper(ctx, target, ws.impl, "Begin", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - state, innerErr = conn.Begin(ctx, target, options) + state, innerErr = conn.Begin(ctx, session, target) return canRetry(ctx, innerErr), innerErr }) return state, wrapFatalTxErrorInVTError(err, true, vterrors.VT15001) @@ -238,12 +238,12 @@ func (ws *wrappedService) UnresolvedTransactions(ctx context.Context, target *qu return transactions, err } -func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (qr *sqltypes.Result, err error) { +func (ws *wrappedService) Execute(ctx context.Context, session Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64) (qr *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err = ws.wrapper(ctx, target, ws.impl, "Execute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - qr, innerErr = conn.Execute(ctx, target, query, bindVars, transactionID, reservedID, options) + qr, innerErr = conn.Execute(ctx, session, target, query, bindVars, transactionID, reservedID) // You cannot retry if you're in a transaction. retryable := canRetry(ctx, innerErr) && (!inDedicatedConn) return retryable, innerErr @@ -252,12 +252,12 @@ func (ws *wrappedService) Execute(ctx context.Context, target *querypb.Target, q } // StreamExecute implements the QueryService interface -func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { +func (ws *wrappedService) StreamExecute(ctx context.Context, session Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error { inDedicatedConn := transactionID != 0 || reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err := ws.wrapper(ctx, target, ws.impl, "StreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { streamingStarted := false - innerErr := conn.StreamExecute(ctx, target, query, bindVars, transactionID, reservedID, options, func(qr *sqltypes.Result) error { + innerErr := conn.StreamExecute(ctx, session, target, query, bindVars, transactionID, reservedID, func(qr *sqltypes.Result) error { streamingStarted = true return callback(qr) }) @@ -268,24 +268,24 @@ func (ws *wrappedService) StreamExecute(ctx context.Context, target *querypb.Tar return wrapFatalTxErrorInVTError(err, transactionID != 0, vterrors.VT15001) } -func (ws *wrappedService) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (state TransactionState, qr *sqltypes.Result, err error) { +func (ws *wrappedService) BeginExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64) (state TransactionState, qr *sqltypes.Result, err error) { inDedicatedConn := reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err = ws.wrapper(ctx, target, ws.impl, "BeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - state, qr, innerErr = conn.BeginExecute(ctx, target, preQueries, query, bindVars, reservedID, options) + state, qr, innerErr = conn.BeginExecute(ctx, session, target, preQueries, query, bindVars, reservedID) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr }) return state, qr, wrapFatalTxErrorInVTError(err, true, vterrors.VT15001) } // BeginStreamExecute implements the QueryService interface -func (ws *wrappedService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state TransactionState, err error) { +func (ws *wrappedService) BeginStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (state TransactionState, err error) { inDedicatedConn := reservedID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err = ws.wrapper(ctx, target, ws.impl, "BeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - state, innerErr = conn.BeginStreamExecute(ctx, target, preQueries, query, bindVars, reservedID, options, callback) + state, innerErr = conn.BeginStreamExecute(ctx, session, target, preQueries, query, bindVars, reservedID, callback) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr }) return state, wrapFatalTxErrorInVTError(err, true, vterrors.VT15001) @@ -354,11 +354,11 @@ func (ws *wrappedService) HandlePanic(err *error) { } // ReserveBeginExecute implements the QueryService interface -func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state ReservedTransactionState, res *sqltypes.Result, err error) { - opts := WrapOpts{InTransaction: false, Options: options} +func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (state ReservedTransactionState, res *sqltypes.Result, err error) { + opts := WrapOpts{InTransaction: false, Session: session} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error - state, res, err = conn.ReserveBeginExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options) + state, res, err = conn.ReserveBeginExecute(ctx, session, target, preQueries, postBeginQueries, sql, bindVariables) return canRetry(ctx, err), err }) @@ -366,23 +366,23 @@ func (ws *wrappedService) ReserveBeginExecute(ctx context.Context, target *query } // ReserveBeginStreamExecute implements the QueryService interface -func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { - opts := WrapOpts{InTransaction: false, Options: options} +func (ws *wrappedService) ReserveBeginStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (state ReservedTransactionState, err error) { + opts := WrapOpts{InTransaction: false, Session: session} err = ws.wrapper(ctx, target, ws.impl, "ReserveBeginStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - state, innerErr = conn.ReserveBeginStreamExecute(ctx, target, preQueries, postBeginQueries, sql, bindVariables, options, callback) + state, innerErr = conn.ReserveBeginStreamExecute(ctx, session, target, preQueries, postBeginQueries, sql, bindVariables, callback) return canRetry(ctx, innerErr), innerErr }) return state, err } // ReserveExecute implements the QueryService interface -func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state ReservedState, res *sqltypes.Result, err error) { +func (ws *wrappedService) ReserveExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (state ReservedState, res *sqltypes.Result, err error) { inDedicatedConn := transactionID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err = ws.wrapper(ctx, target, ws.impl, "ReserveExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var err error - state, res, err = conn.ReserveExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options) + state, res, err = conn.ReserveExecute(ctx, session, target, preQueries, sql, bindVariables, transactionID) return canRetry(ctx, err) && !inDedicatedConn, err }) @@ -390,12 +390,12 @@ func (ws *wrappedService) ReserveExecute(ctx context.Context, target *querypb.Ta } // ReserveStreamExecute implements the QueryService interface -func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (state ReservedState, err error) { +func (ws *wrappedService) ReserveStreamExecute(ctx context.Context, session Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (state ReservedState, err error) { inDedicatedConn := transactionID != 0 - opts := WrapOpts{InTransaction: inDedicatedConn, Options: options} + opts := WrapOpts{InTransaction: inDedicatedConn, Session: session} err = ws.wrapper(ctx, target, ws.impl, "ReserveStreamExecute", opts, func(ctx context.Context, target *querypb.Target, conn QueryService) (bool, error) { var innerErr error - state, innerErr = conn.ReserveStreamExecute(ctx, target, preQueries, sql, bindVariables, transactionID, options, callback) + state, innerErr = conn.ReserveStreamExecute(ctx, session, target, preQueries, sql, bindVariables, transactionID, callback) return canRetry(ctx, innerErr) && !inDedicatedConn, innerErr }) return state, err diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index e6a13199204..4b466e6d10d 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -264,7 +264,7 @@ func (sbc *SandboxConn) SetSchemaResult(r []SchemaResult) { } // Execute is part of the QueryService interface. -func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (sbc *SandboxConn) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { sbc.panicIfNeeded() sbc.execMu.Lock() defer sbc.execMu.Unlock() @@ -283,7 +283,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que Sql: query, BindVariables: bv, }) - sbc.Options = append(sbc.Options, options) + sbc.Options = append(sbc.Options, session.GetOptions()) if err := sbc.getError(); err != nil { return nil, err } @@ -297,7 +297,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, target *querypb.Target, que } // StreamExecute is part of the QueryService interface. -func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { +func (sbc *SandboxConn) StreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, query string, bindVars map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error { sbc.panicIfNeeded() sbc.sExecMu.Lock() sbc.ExecCount.Add(1) @@ -309,7 +309,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe Sql: query, BindVariables: bv, }) - sbc.Options = append(sbc.Options, options) + sbc.Options = append(sbc.Options, session.GetOptions()) err := sbc.getError() if err != nil { sbc.sExecMu.Unlock() @@ -338,12 +338,12 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe } // Begin is part of the QueryService interface. -func (sbc *SandboxConn) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (sbc *SandboxConn) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (queryservice.TransactionState, error) { sbc.panicIfNeeded() - return sbc.begin(ctx, target, nil, 0, options) + return sbc.begin(ctx, session, target, nil, 0) } -func (sbc *SandboxConn) begin(ctx context.Context, target *querypb.Target, preQueries []string, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (sbc *SandboxConn) begin(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, reservedID int64) (queryservice.TransactionState, error) { sbc.BeginCount.Add(1) err := sbc.getError() if err != nil { @@ -355,7 +355,7 @@ func (sbc *SandboxConn) begin(ctx context.Context, target *querypb.Target, preQu transactionID = sbc.TransactionID.Add(1) } for _, preQuery := range preQueries { - _, err := sbc.Execute(ctx, target, preQuery, nil, transactionID, reservedID, options) + _, err := sbc.Execute(ctx, session, target, preQuery, nil, transactionID, reservedID) if err != nil { return queryservice.TransactionState{}, err } @@ -494,30 +494,30 @@ func (sbc *SandboxConn) UnresolvedTransactions(context.Context, *querypb.Target, } // BeginExecute is part of the QueryService interface. -func (sbc *SandboxConn) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { +func (sbc *SandboxConn) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, query string, bindVars map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { sbc.panicIfNeeded() - state, err := sbc.begin(ctx, target, preQueries, reservedID, options) + state, err := sbc.begin(ctx, session, target, preQueries, reservedID) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) } if err != nil { return queryservice.TransactionState{}, nil, err } - result, err := sbc.Execute(ctx, target, query, bindVars, state.TransactionID, reservedID, options) + result, err := sbc.Execute(ctx, session, target, query, bindVars, state.TransactionID, reservedID) return state, result, err } // BeginStreamExecute is part of the QueryService interface. -func (sbc *SandboxConn) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { +func (sbc *SandboxConn) BeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { sbc.panicIfNeeded() - state, err := sbc.begin(ctx, target, preQueries, reservedID, options) + state, err := sbc.begin(ctx, session, target, preQueries, reservedID) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) } if err != nil { return queryservice.TransactionState{}, err } - err = sbc.StreamExecute(ctx, target, sql, bindVariables, state.TransactionID, reservedID, options, callback) + err = sbc.StreamExecute(ctx, session, target, sql, bindVariables, state.TransactionID, reservedID, callback) return state, err } @@ -669,10 +669,10 @@ func (sbc *SandboxConn) HandlePanic(err *error) { } // ReserveBeginExecute implements the QueryService interface -func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { +func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { sbc.panicIfNeeded() - reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options) - state, result, err := sbc.BeginExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options) + reservedID := sbc.reserve(ctx, session, target, preQueries, bindVariables, 0) + state, result, err := sbc.BeginExecute(ctx, session, target, postBeginQueries, sql, bindVariables, reservedID) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) } @@ -684,10 +684,10 @@ func (sbc *SandboxConn) ReserveBeginExecute(ctx context.Context, target *querypb } // ReserveBeginStreamExecute is part of the QueryService interface. -func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { +func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { sbc.panicIfNeeded() - reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, 0, options) - state, err := sbc.BeginStreamExecute(ctx, target, postBeginQueries, sql, bindVariables, reservedID, options, callback) + reservedID := sbc.reserve(ctx, session, target, preQueries, bindVariables, 0) + state, err := sbc.BeginStreamExecute(ctx, session, target, postBeginQueries, sql, bindVariables, reservedID, callback) if state.TransactionID != 0 { sbc.setTxReservedID(state.TransactionID, reservedID) } @@ -699,10 +699,10 @@ func (sbc *SandboxConn) ReserveBeginStreamExecute(ctx context.Context, target *q } // ReserveExecute implements the QueryService interface -func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (queryservice.ReservedState, *sqltypes.Result, error) { +func (sbc *SandboxConn) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (queryservice.ReservedState, *sqltypes.Result, error) { sbc.panicIfNeeded() - reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options) - result, err := sbc.Execute(ctx, target, sql, bindVariables, transactionID, reservedID, options) + reservedID := sbc.reserve(ctx, session, target, preQueries, bindVariables, transactionID) + result, err := sbc.Execute(ctx, session, target, sql, bindVariables, transactionID, reservedID) if transactionID != 0 { sbc.setTxReservedID(transactionID, reservedID) } @@ -713,10 +713,10 @@ func (sbc *SandboxConn) ReserveExecute(ctx context.Context, target *querypb.Targ } // ReserveStreamExecute is part of the QueryService interface. -func (sbc *SandboxConn) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { +func (sbc *SandboxConn) ReserveStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { sbc.panicIfNeeded() - reservedID := sbc.reserve(ctx, target, preQueries, bindVariables, transactionID, options) - err := sbc.StreamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, options, callback) + reservedID := sbc.reserve(ctx, session, target, preQueries, bindVariables, transactionID) + err := sbc.StreamExecute(ctx, session, target, sql, bindVariables, transactionID, reservedID, callback) if transactionID != 0 { sbc.setTxReservedID(transactionID, reservedID) } @@ -726,10 +726,10 @@ func (sbc *SandboxConn) ReserveStreamExecute(ctx context.Context, target *queryp }, err } -func (sbc *SandboxConn) reserve(ctx context.Context, target *querypb.Target, preQueries []string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) int64 { +func (sbc *SandboxConn) reserve(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, bindVariables map[string]*querypb.BindVariable, transactionID int64) int64 { sbc.ReserveCount.Add(1) for _, query := range preQueries { - sbc.Execute(ctx, target, query, bindVariables, transactionID, 0, options) + sbc.Execute(ctx, session, target, query, bindVariables, transactionID, 0) } if transactionID != 0 { return transactionID diff --git a/go/vt/vttablet/tabletconntest/fakequeryservice.go b/go/vt/vttablet/tabletconntest/fakequeryservice.go index 13ec5838fe3..19a93592432 100644 --- a/go/vt/vttablet/tabletconntest/fakequeryservice.go +++ b/go/vt/vttablet/tabletconntest/fakequeryservice.go @@ -33,6 +33,7 @@ import ( binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -104,10 +105,12 @@ var TestVTGateCallerID = &querypb.VTGateCallerID{ Username: "test_username", } -// TestExecuteOptions is a test execute options. -var TestExecuteOptions = &querypb.ExecuteOptions{ - IncludedFields: querypb.ExecuteOptions_TYPE_ONLY, - ClientFoundRows: true, +// TestSession is a test execute options. +var TestSession = &vtgatepb.Session{ + Options: &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_TYPE_ONLY, + ClientFoundRows: true, + }, } // TestAsTransaction is a test 'asTransaction' flag. @@ -141,7 +144,7 @@ func (f *FakeQueryService) checkTargetCallerID(ctx context.Context, name string, const beginTransactionID int64 = 9990 // Begin is part of the queryservice.QueryService interface -func (f *FakeQueryService) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (f *FakeQueryService) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (queryservice.TransactionState, error) { if f.HasBeginError { return queryservice.TransactionState{}, f.TabletError } @@ -149,8 +152,8 @@ func (f *FakeQueryService) Begin(ctx context.Context, target *querypb.Target, op panic(errors.New("test-triggered panic")) } f.checkTargetCallerID(ctx, "Begin", target) - if !proto.Equal(options, TestExecuteOptions) { - f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", options, TestExecuteOptions) + if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { + f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) } return queryservice.TransactionState{TransactionID: beginTransactionID, TabletAlias: TestAlias}, nil } @@ -412,7 +415,7 @@ var ExecuteQueryResult = sqltypes.Result{ } // Execute is part of the queryservice.QueryService interface -func (f *FakeQueryService) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (f *FakeQueryService) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { if f.HasError { return nil, f.TabletError } @@ -425,8 +428,8 @@ func (f *FakeQueryService) Execute(ctx context.Context, target *querypb.Target, if !sqltypes.BindVariablesEqual(bindVariables, ExecuteBindVars) { f.t.Errorf("invalid Execute.BindVariables: got %v expected %v", bindVariables, ExecuteBindVars) } - if !proto.Equal(options, TestExecuteOptions) { - f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", options, TestExecuteOptions) + if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { + f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) } f.checkTargetCallerID(ctx, "Execute", target) if transactionID != f.ExpectedTransactionID { @@ -472,7 +475,7 @@ var StreamExecuteQueryResult2 = sqltypes.Result{ } // StreamExecute is part of the queryservice.QueryService interface -func (f *FakeQueryService) StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { +func (f *FakeQueryService) StreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error { if f.Panics && f.StreamExecutePanicsEarly { panic(errors.New("test-triggered panic early")) } @@ -482,8 +485,8 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, target *querypb.Ta if !sqltypes.BindVariablesEqual(bindVariables, StreamExecuteBindVars) { f.t.Errorf("invalid StreamExecute.BindVariables: got %v expected %v", bindVariables, StreamExecuteBindVars) } - if !proto.Equal(options, TestExecuteOptions) { - f.t.Errorf("invalid StreamExecute.ExecuteOptions: got %v expected %v", options, TestExecuteOptions) + if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { + f.t.Errorf("invalid StreamExecute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) } f.checkTargetCallerID(ctx, "StreamExecute", target) if err := callback(&StreamExecuteQueryResult1); err != nil { @@ -567,32 +570,32 @@ var ExecuteBatchQueryResultList = []sqltypes.Result{ } // BeginExecute combines Begin and Execute. -func (f *FakeQueryService) BeginExecute(ctx context.Context, target *querypb.Target, _ []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { - state, err := f.Begin(ctx, target, options) +func (f *FakeQueryService) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, _ []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { + state, err := f.Begin(ctx, session, target) if err != nil { return state, nil, err } // TODO(deepthi): what alias should we actually return here? - result, err := f.Execute(ctx, target, sql, bindVariables, state.TransactionID, reservedID, options) + result, err := f.Execute(ctx, session, target, sql, bindVariables, state.TransactionID, reservedID) return state, result, err } // BeginStreamExecute combines Begin and StreamExecute. -func (f *FakeQueryService) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { - state, err := f.Begin(ctx, target, options) +func (f *FakeQueryService) BeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { + state, err := f.Begin(ctx, session, target) if err != nil { return state, err } for _, preQuery := range preQueries { - _, err := f.Execute(ctx, target, preQuery, nil, state.TransactionID, reservedID, options) + _, err := f.Execute(ctx, session, target, preQuery, nil, state.TransactionID, reservedID) if err != nil { return state, err } } - err = f.StreamExecute(ctx, target, sql, bindVariables, state.TransactionID, reservedID, options, callback) + err = f.StreamExecute(ctx, session, target, sql, bindVariables, state.TransactionID, reservedID, callback) return state, err } @@ -732,27 +735,27 @@ func (f *FakeQueryService) GetServingKeyspaces() []string { } // ReserveBeginExecute satisfies the Gateway interface -func (f *FakeQueryService) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { +func (f *FakeQueryService) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { panic("implement me") } // ReserveBeginStreamExecute satisfies the Gateway interface -func (f *FakeQueryService) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { +func (f *FakeQueryService) ReserveBeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { panic("implement me") } // ReserveExecute implements the QueryService interface -func (f *FakeQueryService) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (queryservice.ReservedState, *sqltypes.Result, error) { +func (f *FakeQueryService) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (queryservice.ReservedState, *sqltypes.Result, error) { panic("implement me") } // ReserveStreamExecute satisfies the Gateway interface -func (f *FakeQueryService) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { +func (f *FakeQueryService) ReserveStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { state, err := f.reserve(transactionID) if err != nil { return state, err } - err = f.StreamExecute(ctx, target, sql, bindVariables, transactionID, state.ReservedID, options, callback) + err = f.StreamExecute(ctx, session, target, sql, bindVariables, transactionID, state.ReservedID, callback) return state, err } diff --git a/go/vt/vttablet/tabletconntest/tabletconntest.go b/go/vt/vttablet/tabletconntest/tabletconntest.go index 91aeceb0dcd..bba5cab9fcc 100644 --- a/go/vt/vttablet/tabletconntest/tabletconntest.go +++ b/go/vt/vttablet/tabletconntest/tabletconntest.go @@ -36,6 +36,7 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/queryservice" "vitess.io/vitess/go/vt/vttablet/tabletconn" @@ -102,7 +103,7 @@ func testBegin(t *testing.T, conn queryservice.QueryService, f *FakeQueryService t.Log("testBegin") ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - state, err := conn.Begin(ctx, TestTarget, TestExecuteOptions) + state, err := conn.Begin(ctx, TestSession, TestTarget) if err != nil { t.Fatalf("Begin failed: %v", err) } @@ -116,7 +117,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe t.Log("testBeginError") f.HasBeginError = true testErrorHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, TestTarget, nil) + _, err := conn.Begin(ctx, executorcontext.NewSafeSession(nil), TestTarget) return err }) f.HasBeginError = false @@ -125,7 +126,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe func testBeginPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { t.Log("testBeginPanics") testPanicHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, TestTarget, nil) + _, err := conn.Begin(ctx, executorcontext.NewSafeSession(nil), TestTarget) return err }) } @@ -436,7 +437,7 @@ func testExecute(t *testing.T, conn queryservice.QueryService, f *FakeQueryServi f.ExpectedTransactionID = ExecuteTransactionID ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - qr, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) + qr, err := conn.Execute(ctx, TestSession, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID) if err != nil { t.Fatalf("Execute failed: %v", err) } @@ -449,7 +450,7 @@ func testExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQuery t.Log("testExecuteError") f.HasError = true testErrorHelper(t, f, "Execute", func(ctx context.Context) error { - _, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) + _, err := conn.Execute(ctx, TestSession, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID) return err }) f.HasError = false @@ -458,7 +459,7 @@ func testExecuteError(t *testing.T, conn queryservice.QueryService, f *FakeQuery func testExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { t.Log("testExecutePanics") testPanicHelper(t, f, "Execute", func(ctx context.Context) error { - _, err := conn.Execute(ctx, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID, TestExecuteOptions) + _, err := conn.Execute(ctx, TestSession, TestTarget, ExecuteQuery, ExecuteBindVars, ExecuteTransactionID, ReserveConnectionID) return err }) } @@ -468,7 +469,7 @@ func testBeginExecute(t *testing.T, conn queryservice.QueryService, f *FakeQuery f.ExpectedTransactionID = beginTransactionID ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - state, qr, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) + state, qr, err := conn.BeginExecute(ctx, TestSession, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID) if err != nil { t.Fatalf("BeginExecute failed: %v", err) } @@ -485,7 +486,7 @@ func testBeginExecuteErrorInBegin(t *testing.T, conn queryservice.QueryService, t.Log("testBeginExecuteErrorInBegin") f.HasBeginError = true testErrorHelper(t, f, "BeginExecute.Begin", func(ctx context.Context) error { - state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) + state, _, err := conn.BeginExecute(ctx, TestSession, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID) if state.TransactionID != 0 { t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted 0", state.TransactionID) } @@ -499,7 +500,7 @@ func testBeginExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService f.HasError = true testErrorHelper(t, f, "BeginExecute.Execute", func(ctx context.Context) error { ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - state, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) + state, _, err := conn.BeginExecute(ctx, TestSession, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID) if state.TransactionID != beginTransactionID { t.Errorf("Unexpected transactionID from BeginExecute: got %v wanted %v", state.TransactionID, beginTransactionID) } @@ -511,7 +512,7 @@ func testBeginExecuteErrorInExecute(t *testing.T, conn queryservice.QueryService func testBeginExecutePanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { t.Log("testBeginExecutePanics") testPanicHelper(t, f, "BeginExecute", func(ctx context.Context) error { - _, _, err := conn.BeginExecute(ctx, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID, TestExecuteOptions) + _, _, err := conn.BeginExecute(ctx, TestSession, TestTarget, nil, ExecuteQuery, ExecuteBindVars, ReserveConnectionID) return err }) } @@ -521,7 +522,7 @@ func testStreamExecute(t *testing.T, conn queryservice.QueryService, f *FakeQuer ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) i := 0 - err := conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + err := conn.StreamExecute(ctx, TestSession, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, func(qr *sqltypes.Result) error { switch i { case 0: if len(qr.Rows) == 0 { @@ -557,7 +558,7 @@ func testStreamExecuteError(t *testing.T, conn queryservice.QueryService, f *Fak testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { f.ErrorWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + return conn.StreamExecute(ctx, TestSession, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.ErrorWait: @@ -586,7 +587,7 @@ func testStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *Fa f.StreamExecutePanicsEarly = true testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error { ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + return conn.StreamExecute(ctx, TestSession, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, func(qr *sqltypes.Result) error { return nil }) }) @@ -596,7 +597,7 @@ func testStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f *Fa testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error { f.PanicWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + return conn.StreamExecute(ctx, TestSession, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.PanicWait: @@ -621,7 +622,7 @@ func testBeginStreamExecute(t *testing.T, conn queryservice.QueryService, f *Fak ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) i := 0 - _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + _, err := conn.BeginStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { switch i { case 0: if len(qr.Rows) == 0 { @@ -656,7 +657,7 @@ func testReserveStreamExecute(t *testing.T, conn queryservice.QueryService, f *F ctx := context.Background() ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) i := 0 - _, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + _, err := conn.ReserveStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { switch i { case 0: if len(qr.Rows) == 0 { @@ -692,7 +693,7 @@ func testBeginStreamExecuteErrorInBegin(t *testing.T, conn queryservice.QuerySer testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { f.ErrorWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + _, err := conn.BeginStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.ErrorWait: @@ -720,7 +721,7 @@ func testBeginStreamExecuteErrorInExecute(t *testing.T, conn queryservice.QueryS testErrorHelper(t, f, "StreamExecute", func(ctx context.Context) error { f.ErrorWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - state, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + state, err := conn.BeginStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.ErrorWait: @@ -749,7 +750,7 @@ func testReserveStreamExecuteErrorInReserve(t *testing.T, conn queryservice.Quer testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error { f.ErrorWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - _, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + _, err := conn.ReserveStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.ErrorWait: @@ -777,7 +778,7 @@ func testReserveStreamExecuteErrorInExecute(t *testing.T, conn queryservice.Quer testErrorHelper(t, f, "ReserveStreamExecute", func(ctx context.Context) error { f.ErrorWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - state, err := conn.ReserveStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + state, err := conn.ReserveStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.ErrorWait: @@ -808,7 +809,7 @@ func testBeginStreamExecutePanics(t *testing.T, conn queryservice.QueryService, f.StreamExecutePanicsEarly = true testPanicHelper(t, f, "StreamExecute.Early", func(ctx context.Context) error { ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - return conn.StreamExecute(ctx, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + return conn.StreamExecute(ctx, TestSession, TestTarget, StreamExecuteQuery, StreamExecuteBindVars, 0, 0, func(qr *sqltypes.Result) error { return nil }) }) @@ -818,7 +819,7 @@ func testBeginStreamExecutePanics(t *testing.T, conn queryservice.QueryService, testPanicHelper(t, f, "StreamExecute.Late", func(ctx context.Context) error { f.PanicWait = make(chan struct{}) ctx = callerid.NewContext(ctx, TestCallerID, TestVTGateCallerID) - _, err := conn.BeginStreamExecute(ctx, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, TestExecuteOptions, func(qr *sqltypes.Result) error { + _, err := conn.BeginStreamExecute(ctx, TestSession, TestTarget, nil, StreamExecuteQuery, StreamExecuteBindVars, 0, func(qr *sqltypes.Result) error { // For some errors, the call can be retried. select { case <-f.PanicWait: diff --git a/go/vt/vttablet/tabletmanager/framework_test.go b/go/vt/vttablet/tabletmanager/framework_test.go index c3ca31e4862..8f4cc38b59b 100644 --- a/go/vt/vttablet/tabletmanager/framework_test.go +++ b/go/vt/vttablet/tabletmanager/framework_test.go @@ -217,7 +217,7 @@ type fakeTabletConn struct { } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (queryservice.TransactionState, error) { +func (ftc *fakeTabletConn) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (queryservice.TransactionState, error) { return queryservice.TransactionState{ TransactionID: 1, }, nil @@ -274,24 +274,24 @@ func (ftc *fakeTabletConn) ReadTransaction(ctx context.Context, target *querypb. } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { +func (ftc *fakeTabletConn) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (*sqltypes.Result, error) { return nil, nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { +func (ftc *fakeTabletConn) StreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) error { return nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) BeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { +func (ftc *fakeTabletConn) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { return queryservice.TransactionState{ TransactionID: 1, }, nil, nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) BeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { +func (ftc *fakeTabletConn) BeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, callback func(*sqltypes.Result) error) (queryservice.TransactionState, error) { return queryservice.TransactionState{ TransactionID: 1, }, nil @@ -327,28 +327,28 @@ func (ftc *fakeTabletConn) HandlePanic(err *error) { } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) ReserveBeginExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { +func (ftc *fakeTabletConn) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (queryservice.ReservedTransactionState, *sqltypes.Result, error) { return queryservice.ReservedTransactionState{ ReservedID: 1, }, nil, nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) ReserveBeginStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { +func (ftc *fakeTabletConn) ReserveBeginStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error) (queryservice.ReservedTransactionState, error) { return queryservice.ReservedTransactionState{ ReservedID: 1, }, nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) ReserveExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (queryservice.ReservedState, *sqltypes.Result, error) { +func (ftc *fakeTabletConn) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (queryservice.ReservedState, *sqltypes.Result, error) { return queryservice.ReservedState{ ReservedID: 1, }, nil, nil } // fakeTabletConn implements the QueryService interface. -func (ftc *fakeTabletConn) ReserveStreamExecute(ctx context.Context, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { +func (ftc *fakeTabletConn) ReserveStreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, preQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, callback func(*sqltypes.Result) error) (queryservice.ReservedState, error) { return queryservice.ReservedState{ ReservedID: 1, }, nil diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index b390ed86ba6..96b054ac241 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -25,6 +25,7 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/executorcontext" querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" @@ -293,6 +294,6 @@ func (tm *TabletManager) ExecuteQuery(ctx context.Context, req *tabletmanagerdat if err != nil { return nil, err } - result, err := tm.QueryServiceControl.QueryService().Execute(ctx, target, uq, nil, 0, 0, nil) + result, err := tm.QueryServiceControl.QueryService().Execute(ctx, executorcontext.NewSafeSession(nil), target, uq, nil, 0, 0) return sqltypes.ResultToProto3(result), err } diff --git a/go/vt/vttablet/tabletserver/bench_test.go b/go/vt/vttablet/tabletserver/bench_test.go index fd2d86c2812..9e8a2e7740b 100644 --- a/go/vt/vttablet/tabletserver/bench_test.go +++ b/go/vt/vttablet/tabletserver/bench_test.go @@ -26,6 +26,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) // Benchmark run on 6/27/17, with optimized byte-level operations @@ -36,8 +37,10 @@ import ( // BenchmarkExecuteVarBinary-4 100 14610045 ns/op // BenchmarkExecuteExpression-4 1000 1047798 ns/op -var benchQuery = "select a from test_table where v = :vtg1 and v0 = :vtg2 and v1 = :vtg3 and v2 = :vtg4 and v3 = :vtg5 and v4 = :vtg6 and v5 = :vtg7 and v6 = :vtg8 and v7 = :vtg9 and v8 = :vtg10 and v9 = :vtg11" -var benchVarValue []byte +var ( + benchQuery = "select a from test_table where v = :vtg1 and v0 = :vtg2 and v1 = :vtg3 and v2 = :vtg4 and v3 = :vtg5 and v4 = :vtg6 and v5 = :vtg7 and v6 = :vtg8 and v7 = :vtg9 and v8 = :vtg10 and v9 = :vtg11" + benchVarValue []byte +) func init() { // benchQuerySize is the approximate size of the query. @@ -71,7 +74,7 @@ func BenchmarkExecuteVarBinary(b *testing.B) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.SetAllowAll(true) for i := 0; i < b.N; i++ { - if _, err := tsv.Execute(ctx, &target, benchQuery, bv, 0, 0, nil); err != nil { + if _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, benchQuery, bv, 0, 0); err != nil { panic(err) } } @@ -98,7 +101,7 @@ func BenchmarkExecuteExpression(b *testing.B) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.SetAllowAll(true) for i := 0; i < b.N; i++ { - if _, err := tsv.Execute(ctx, &target, benchQuery, bv, 0, 0, nil); err != nil { + if _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, benchQuery, bv, 0, 0); err != nil { panic(err) } } diff --git a/go/vt/vttablet/tabletserver/dt_executor_test.go b/go/vt/vttablet/tabletserver/dt_executor_test.go index e1a7c80656a..b02046d1106 100644 --- a/go/vt/vttablet/tabletserver/dt_executor_test.go +++ b/go/vt/vttablet/tabletserver/dt_executor_test.go @@ -38,6 +38,7 @@ import ( "vitess.io/vitess/go/streamlog" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" "vitess.io/vitess/go/vt/vttablet/tabletserver/schema" @@ -771,7 +772,7 @@ func newNoTwopcExecutor(t *testing.T, ctx context.Context) (txe *DTExecutor, tsv func newTxForPrep(ctx context.Context, tsv *TabletServer) int64 { txid := newTransaction(tsv, nil) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - _, err := tsv.Execute(ctx, &target, "update test_table set name = 2 where pk = 1", nil, txid, 0, nil) + _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set name = 2 where pk = 1", nil, txid, 0) if err != nil { panic(err) } diff --git a/go/vt/vttablet/tabletserver/query_executor_test.go b/go/vt/vttablet/tabletserver/query_executor_test.go index ae590d45b0e..24a60093283 100644 --- a/go/vt/vttablet/tabletserver/query_executor_test.go +++ b/go/vt/vttablet/tabletserver/query_executor_test.go @@ -55,6 +55,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" tableaclpb "vitess.io/vitess/go/vt/proto/tableacl" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -95,229 +96,230 @@ func TestQueryExecutorPlans(t *testing.T) { outsideTxErr bool // TxThrottler allows the test case to override the transaction throttler txThrottler txthrottler.TxThrottler - }{{ - input: "select * from t", - dbResponses: []dbResponse{{ - query: "select * from t limit 10001", - result: selectResult, - }}, - resultWant: selectResult, - planWant: "Select", - logWant: "select * from t limit 10001", - inTxWant: "select * from t limit 10001", - }, { - input: "select * from t limit 1", - dbResponses: []dbResponse{{ - query: "select * from t limit 1", - result: selectResult, - }}, - resultWant: selectResult, - planWant: "Select", - logWant: "select * from t limit 1", - inTxWant: "select * from t limit 1", - }, { - input: "show engines", - dbResponses: []dbResponse{{ - query: "show engines", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "Show", - logWant: "show engines", - }, { - input: "repair t", - dbResponses: []dbResponse{{ - query: "repair t", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "OtherAdmin", - logWant: "repair t", - }, { - input: "insert into test_table(a) values(1)", - dbResponses: []dbResponse{{ - query: "insert into test_table(a) values (1)", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "Insert", - logWant: "insert into test_table(a) values (1)", - }, { - input: "replace into test_table(a) values(1)", - dbResponses: []dbResponse{{ - query: "replace into test_table(a) values (1)", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "Insert", - logWant: "replace into test_table(a) values (1)", - }, { - input: "update test_table set a=1", - dbResponses: []dbResponse{{ - query: "update test_table set a = 1 limit 10001", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "UpdateLimit", - // The UpdateLimit query will not use autocommit because - // it needs to roll back on failure. - logWant: "begin; update test_table set a = 1 limit 10001; commit", - inTxWant: "update test_table set a = 1 limit 10001", - }, { - input: "select a, b from test_table", - passThrough: true, - inDMLExec: true, - dbResponses: []dbResponse{{ - query: "select a, b from test_table", - result: selectResult, - }}, - resultWant: selectResult, - planWant: "SelectNoLimit", - logWant: "select a, b from test_table", - outsideTxErr: true, - errorWant: "[BUG] SelectNoLimit unexpected plan type", - }, { - input: "update test_table set a=1", - passThrough: true, - dbResponses: []dbResponse{{ - query: "update test_table set a = 1", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "Update", - logWant: "update test_table set a = 1", - }, { - input: "delete from test_table", - dbResponses: []dbResponse{{ - query: "delete from test_table limit 10001", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "DeleteLimit", - // The DeleteLimit query will not use autocommit because - // it needs to roll back on failure. - logWant: "begin; delete from test_table limit 10001; commit", - inTxWant: "delete from test_table limit 10001", - }, { - input: "delete from test_table", - passThrough: true, - dbResponses: []dbResponse{{ - query: "delete from test_table", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "Delete", - logWant: "delete from test_table", - }, { - input: "alter table test_table add zipcode int", - dbResponses: []dbResponse{{ - query: "alter table test_table add column zipcode int", - result: dmlResult, - }}, - resultWant: dmlResult, - planWant: "DDL", - logWant: "alter table test_table add column zipcode int", - onlyInTxErr: true, - errorWant: "DDL statement executed inside a transaction", - }, { - input: "savepoint a", - dbResponses: []dbResponse{{ - query: "savepoint a", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "Savepoint", - logWant: "savepoint a", - inTxWant: "savepoint a", - }, { - input: "create index a on user(id)", - dbResponses: []dbResponse{{ - query: "alter table `user` add key a (id)", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "DDL", - logWant: "alter table `user` add key a (id)", - inTxWant: "alter table `user` add key a (id)", - onlyInTxErr: true, - errorWant: "DDL statement executed inside a transaction", - }, { - input: "create index a on user(id1 + id2)", - dbResponses: []dbResponse{{ - query: "create index a on user(id1 + id2)", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "DDL", - logWant: "create index a on user(id1 + id2)", - inTxWant: "create index a on user(id1 + id2)", - onlyInTxErr: true, - errorWant: "DDL statement executed inside a transaction", - }, { - input: "ROLLBACK work to SAVEPOINT a", - dbResponses: []dbResponse{{ - query: "ROLLBACK work to SAVEPOINT a", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "RollbackSavepoint", - logWant: "ROLLBACK work to SAVEPOINT a", - inTxWant: "ROLLBACK work to SAVEPOINT a", - }, { - input: "RELEASE savepoint a", - dbResponses: []dbResponse{{ - query: "RELEASE savepoint a", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "Release", - logWant: "RELEASE savepoint a", - inTxWant: "RELEASE savepoint a", - }, { - input: "show create database db_name", - dbResponses: []dbResponse{{ - query: "show create database ks", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "Show", - logWant: "show create database ks", - }, { - input: "show create database mysql", - dbResponses: []dbResponse{{ - query: "show create database mysql", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "Show", - logWant: "show create database mysql", - }, { - input: "show create table mysql.user", - dbResponses: []dbResponse{{ - query: "show create table mysql.`user`", - result: emptyResult, - }}, - resultWant: emptyResult, - planWant: "Show", - logWant: "show create table mysql.`user`", - }, { - input: "update test_table set a=1", - dbResponses: []dbResponse{{ - query: "update test_table set a = 1 limit 10001", - result: dmlResult, - }}, - errorWant: "Transaction throttled", - txThrottler: &mockTxThrottler{true}, - }, { - input: "update test_table set a=1", - passThrough: true, - dbResponses: []dbResponse{{ - query: "update test_table set a = 1 limit 10001", - result: dmlResult, - }}, - errorWant: "Transaction throttled", - txThrottler: &mockTxThrottler{true}, - }, + }{ + { + input: "select * from t", + dbResponses: []dbResponse{{ + query: "select * from t limit 10001", + result: selectResult, + }}, + resultWant: selectResult, + planWant: "Select", + logWant: "select * from t limit 10001", + inTxWant: "select * from t limit 10001", + }, { + input: "select * from t limit 1", + dbResponses: []dbResponse{{ + query: "select * from t limit 1", + result: selectResult, + }}, + resultWant: selectResult, + planWant: "Select", + logWant: "select * from t limit 1", + inTxWant: "select * from t limit 1", + }, { + input: "show engines", + dbResponses: []dbResponse{{ + query: "show engines", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "Show", + logWant: "show engines", + }, { + input: "repair t", + dbResponses: []dbResponse{{ + query: "repair t", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "OtherAdmin", + logWant: "repair t", + }, { + input: "insert into test_table(a) values(1)", + dbResponses: []dbResponse{{ + query: "insert into test_table(a) values (1)", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "Insert", + logWant: "insert into test_table(a) values (1)", + }, { + input: "replace into test_table(a) values(1)", + dbResponses: []dbResponse{{ + query: "replace into test_table(a) values (1)", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "Insert", + logWant: "replace into test_table(a) values (1)", + }, { + input: "update test_table set a=1", + dbResponses: []dbResponse{{ + query: "update test_table set a = 1 limit 10001", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "UpdateLimit", + // The UpdateLimit query will not use autocommit because + // it needs to roll back on failure. + logWant: "begin; update test_table set a = 1 limit 10001; commit", + inTxWant: "update test_table set a = 1 limit 10001", + }, { + input: "select a, b from test_table", + passThrough: true, + inDMLExec: true, + dbResponses: []dbResponse{{ + query: "select a, b from test_table", + result: selectResult, + }}, + resultWant: selectResult, + planWant: "SelectNoLimit", + logWant: "select a, b from test_table", + outsideTxErr: true, + errorWant: "[BUG] SelectNoLimit unexpected plan type", + }, { + input: "update test_table set a=1", + passThrough: true, + dbResponses: []dbResponse{{ + query: "update test_table set a = 1", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "Update", + logWant: "update test_table set a = 1", + }, { + input: "delete from test_table", + dbResponses: []dbResponse{{ + query: "delete from test_table limit 10001", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "DeleteLimit", + // The DeleteLimit query will not use autocommit because + // it needs to roll back on failure. + logWant: "begin; delete from test_table limit 10001; commit", + inTxWant: "delete from test_table limit 10001", + }, { + input: "delete from test_table", + passThrough: true, + dbResponses: []dbResponse{{ + query: "delete from test_table", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "Delete", + logWant: "delete from test_table", + }, { + input: "alter table test_table add zipcode int", + dbResponses: []dbResponse{{ + query: "alter table test_table add column zipcode int", + result: dmlResult, + }}, + resultWant: dmlResult, + planWant: "DDL", + logWant: "alter table test_table add column zipcode int", + onlyInTxErr: true, + errorWant: "DDL statement executed inside a transaction", + }, { + input: "savepoint a", + dbResponses: []dbResponse{{ + query: "savepoint a", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "Savepoint", + logWant: "savepoint a", + inTxWant: "savepoint a", + }, { + input: "create index a on user(id)", + dbResponses: []dbResponse{{ + query: "alter table `user` add key a (id)", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "DDL", + logWant: "alter table `user` add key a (id)", + inTxWant: "alter table `user` add key a (id)", + onlyInTxErr: true, + errorWant: "DDL statement executed inside a transaction", + }, { + input: "create index a on user(id1 + id2)", + dbResponses: []dbResponse{{ + query: "create index a on user(id1 + id2)", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "DDL", + logWant: "create index a on user(id1 + id2)", + inTxWant: "create index a on user(id1 + id2)", + onlyInTxErr: true, + errorWant: "DDL statement executed inside a transaction", + }, { + input: "ROLLBACK work to SAVEPOINT a", + dbResponses: []dbResponse{{ + query: "ROLLBACK work to SAVEPOINT a", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "RollbackSavepoint", + logWant: "ROLLBACK work to SAVEPOINT a", + inTxWant: "ROLLBACK work to SAVEPOINT a", + }, { + input: "RELEASE savepoint a", + dbResponses: []dbResponse{{ + query: "RELEASE savepoint a", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "Release", + logWant: "RELEASE savepoint a", + inTxWant: "RELEASE savepoint a", + }, { + input: "show create database db_name", + dbResponses: []dbResponse{{ + query: "show create database ks", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "Show", + logWant: "show create database ks", + }, { + input: "show create database mysql", + dbResponses: []dbResponse{{ + query: "show create database mysql", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "Show", + logWant: "show create database mysql", + }, { + input: "show create table mysql.user", + dbResponses: []dbResponse{{ + query: "show create table mysql.`user`", + result: emptyResult, + }}, + resultWant: emptyResult, + planWant: "Show", + logWant: "show create table mysql.`user`", + }, { + input: "update test_table set a=1", + dbResponses: []dbResponse{{ + query: "update test_table set a = 1 limit 10001", + result: dmlResult, + }}, + errorWant: "Transaction throttled", + txThrottler: &mockTxThrottler{true}, + }, { + input: "update test_table set a=1", + passThrough: true, + dbResponses: []dbResponse{{ + query: "update test_table set a = 1 limit 10001", + result: dmlResult, + }}, + errorWant: "Transaction throttled", + txThrottler: &mockTxThrottler{true}, + }, } for _, tcase := range testcases { t.Run(tcase.input, func(t *testing.T) { @@ -352,7 +354,7 @@ func TestQueryExecutorPlans(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) if !tcase.outsideTxErr && tcase.errorWant != "" && !tcase.onlyInTxErr { require.EqualError(t, err, tcase.errorWant) return @@ -449,7 +451,7 @@ func TestQueryExecutorQueryAnnotation(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong alias returned by Begin") @@ -516,7 +518,7 @@ func TestQueryExecutorSelectImpossible(t *testing.T) { assert.Equal(t, tcase.planWant, qre.logStats.PlanType, tcase.input) assert.Equal(t, tcase.logWant, qre.logStats.RewrittenSQL(), tcase.input) target := tsv.sm.Target() - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong tablet alias from Begin") @@ -647,7 +649,7 @@ func TestQueryExecutorLimitFailure(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong tablet alias from Begin") @@ -1638,7 +1640,7 @@ func newTestTabletServer(ctx context.Context, flags executorFlags, db *fakesqldb func newTransaction(tsv *TabletServer, options *querypb.ExecuteOptions) int64 { target := tsv.sm.Target() - state, err := tsv.Begin(context.Background(), target, options) + state, err := tsv.Begin(context.Background(), &vtgatepb.Session{Options: options}, target) if err != nil { panic(vterrors.Wrap(err, "failed to start a transaction")) } diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 12a33bad575..8ba36743902 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -32,6 +32,7 @@ import ( "syscall" "time" + "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/tabletserver/querythrottler" "vitess.io/vitess/go/acl" @@ -463,12 +464,12 @@ func (tsv *TabletServer) IsHealthy() error { if topoproto.IsServingType(tsv.sm.Target().TabletType) { _, err := tsv.Execute( tabletenv.LocalContext(), + nilSession(), nil, "/* health */ select 1 from dual", nil, 0, 0, - nil, ) return err } @@ -535,8 +536,8 @@ func (tsv *TabletServer) SchemaEngine() *schema.Engine { } // Begin starts a new transaction. This is allowed only if the state is StateServing. -func (tsv *TabletServer) Begin(ctx context.Context, target *querypb.Target, options *querypb.ExecuteOptions) (state queryservice.TransactionState, err error) { - return tsv.begin(ctx, target, nil, 0, nil, options) +func (tsv *TabletServer) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (state queryservice.TransactionState, err error) { + return tsv.begin(ctx, target, nil, 0, nil, session.GetOptions()) } func (tsv *TabletServer) begin( @@ -903,7 +904,7 @@ func (tsv *TabletServer) UnresolvedTransactions(ctx context.Context, target *que } // Execute executes the query and returns the result as response. -func (tsv *TabletServer) Execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64, options *querypb.ExecuteOptions) (result *sqltypes.Result, err error) { +func (tsv *TabletServer) Execute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID, reservedID int64) (result *sqltypes.Result, err error) { span, ctx := trace.NewSpan(ctx, "TabletServer.Execute") trace.AnnotateSQL(span, sqlparser.Preview(sql)) defer span.Finish() @@ -912,7 +913,7 @@ func (tsv *TabletServer) Execute(ctx context.Context, target *querypb.Target, sq return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] transactionID and reserveID must match if both are non-zero") } - return tsv.execute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, options) + return tsv.execute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, session.GetOptions()) } func (tsv *TabletServer) execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, settings []string, options *querypb.ExecuteOptions) (result *sqltypes.Result, err error) { @@ -1009,12 +1010,12 @@ func smallerTimeout(t1, t2 time.Duration) time.Duration { // StreamExecute executes the query and streams the result. // The first QueryResult will have Fields set (and Rows nil). // The subsequent QueryResult will have Rows set (and Fields nil). -func (tsv *TabletServer) StreamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) (err error) { +func (tsv *TabletServer) StreamExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, callback func(*sqltypes.Result) error) (err error) { if transactionID != 0 && reservedID != 0 && transactionID != reservedID { return vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] transactionID and reserveID must match if both are non-zero") } - return tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, options, callback) + return tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, session.GetOptions(), callback) } func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, settings []string, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { @@ -1077,10 +1078,10 @@ func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Targ } // BeginExecute combines Begin and Execute. -func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Target, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, options *querypb.ExecuteOptions) (queryservice.TransactionState, *sqltypes.Result, error) { +func (tsv *TabletServer) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { // Disable hot row protection in case of reserve connection. if tsv.enableHotRowProtection && reservedID == 0 { - txDone, err := tsv.beginWaitForSameRangeTransactions(ctx, target, options, sql, bindVariables) + txDone, err := tsv.beginWaitForSameRangeTransactions(ctx, target, session.GetOptions(), sql, bindVariables) if err != nil { return queryservice.TransactionState{}, nil, err } @@ -1089,32 +1090,32 @@ func (tsv *TabletServer) BeginExecute(ctx context.Context, target *querypb.Targe } } - state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, options) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, session.GetOptions()) if err != nil { return state, nil, err } - result, err := tsv.Execute(ctx, target, sql, bindVariables, state.TransactionID, reservedID, options) + result, err := tsv.Execute(ctx, session, target, sql, bindVariables, state.TransactionID, reservedID) return state, result, err } // BeginStreamExecute combines Begin and StreamExecute. func (tsv *TabletServer) BeginStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (queryservice.TransactionState, error) { - state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, options) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, session.GetOptions()) if err != nil { return state, err } - err = tsv.StreamExecute(ctx, target, sql, bindVariables, state.TransactionID, reservedID, options, callback) + err = tsv.StreamExecute(ctx, session, target, sql, bindVariables, state.TransactionID, reservedID, callback) return state, err } @@ -1271,7 +1272,7 @@ func (tsv *TabletServer) execDML(ctx context.Context, target *querypb.Target, qu return 0, err } - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, nilSession(), target) if err != nil { return 0, err } @@ -1282,7 +1283,7 @@ func (tsv *TabletServer) execDML(ctx context.Context, target *querypb.Target, qu tsv.Rollback(ctx, target, state.TransactionID) } }() - qr, err := tsv.Execute(ctx, target, query, bv, state.TransactionID, 0, nil) + qr, err := tsv.Execute(ctx, nilSession(), target, query, bv, state.TransactionID, 0) if err != nil { return 0, err } @@ -1335,7 +1336,8 @@ func (tsv *TabletServer) VStreamResults(ctx context.Context, target *querypb.Tar } // ReserveBeginExecute implements the QueryService interface -func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { +func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { + options := session.GetOptions() state, result, err = tsv.beginExecuteWithSettings(ctx, target, settings, postBeginQueries, sql, bindVariables, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. @@ -1396,7 +1398,6 @@ func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *queryp return nil }, ) - if err != nil { return state, nil, err } @@ -1411,14 +1412,15 @@ func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, target *queryp // ReserveBeginStreamExecute combines Begin and StreamExecute. func (tsv *TabletServer) ReserveBeginStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedTransactionState, err error) { + options := session.GetOptions() txState, err := tsv.begin(ctx, target, postBeginQueries, 0, settings, options) if err != nil { return txToReserveState(txState), err @@ -1429,7 +1431,8 @@ func (tsv *TabletServer) ReserveBeginStreamExecute( } // ReserveExecute implements the QueryService interface -func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Target, settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, options *querypb.ExecuteOptions) (state queryservice.ReservedState, result *sqltypes.Result, err error) { +func (tsv *TabletServer) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (state queryservice.ReservedState, result *sqltypes.Result, err error) { + options := session.GetOptions() result, err = tsv.executeWithSettings(ctx, target, settings, sql, bindVariables, transactionID, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. @@ -1464,7 +1467,6 @@ func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Tar return nil }, ) - if err != nil { return state, nil, err } @@ -1476,15 +1478,15 @@ func (tsv *TabletServer) ReserveExecute(ctx context.Context, target *querypb.Tar // ReserveStreamExecute combines Begin and StreamExecute. func (tsv *TabletServer) ReserveStreamExecute( ctx context.Context, + session queryservice.Session, target *querypb.Target, settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, - options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedState, err error) { - return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, settings, options, callback) + return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, settings, session.GetOptions(), callback) } // Release implements the QueryService interface @@ -2114,3 +2116,8 @@ func skipQueryPlanCache(options *querypb.ExecuteOptions) bool { func (tsv *TabletServer) getShard() string { return tsv.sm.Target().Shard } + +// nilSession is a helper that returns an empty session. +func nilSession() *executorcontext.SafeSession { + return executorcontext.NewSafeSession(nil) +} diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index 62ae91a922a..6863a655a6c 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -59,6 +59,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -137,7 +138,7 @@ func TestBeginOnReplica(t *testing.T) { options := querypb.ExecuteOptions{ TransactionIsolation: querypb.ExecuteOptions_CONSISTENT_SNAPSHOT_READ_ONLY, } - state, err := tsv.Begin(ctx, &target, &options) + state, err := tsv.Begin(ctx, &vtgatepb.Session{Options: &options}, &target) require.NoError(t, err, "failed to create read only tx on replica") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong tablet alias from Begin") _, err = tsv.Rollback(ctx, &target, state.TransactionID) @@ -145,7 +146,7 @@ func TestBeginOnReplica(t *testing.T) { // test that we can still create transactions even in read-only mode options = querypb.ExecuteOptions{} - state, err = tsv.Begin(ctx, &target, &options) + state, err = tsv.Begin(ctx, &vtgatepb.Session{Options: &options}, &target) require.NoError(t, err, "expected write tx to be allowed") _, err = tsv.Rollback(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -163,14 +164,14 @@ func TestTabletServerPrimaryToReplica(t *testing.T) { tsv.te.shutdownGracePeriod = 1 tsv.sm.shutdownGracePeriod = 1 target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state1, err := tsv.Begin(ctx, &target, nil) + state1, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &target, "update test_table set `name` = 2 where pk = 1", nil, state1.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state1.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state1.TransactionID, "aa") require.NoError(t, err) - state2, err := tsv.Begin(ctx, &target, nil) + state2, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) // This makes txid2 busy @@ -246,7 +247,8 @@ func TestTabletServerRedoLogIsKeptBetweenRestarts(t *testing.T) { got := tsv.te.preparedPool.conns["dtid0"].TxProperties().Queries want := []tx.Query{{ Sql: "update test_table set `name` = 2 where pk = 1 limit 10001", - Tables: []string{"test_table"}}} + Tables: []string{"test_table"}, + }} utils.MustMatch(t, want, got, "Prepared queries") turnOffTxEngine() assert.Empty(t, tsv.te.preparedPool.conns, "tsv.te.preparedPool.conns") @@ -286,7 +288,8 @@ func TestTabletServerRedoLogIsKeptBetweenRestarts(t *testing.T) { got = tsv.te.preparedPool.conns["a:b:10"].TxProperties().Queries want = []tx.Query{{ Sql: "update test_table set `name` = 2 where pk = 1 limit 10001", - Tables: []string{"test_table"}}} + Tables: []string{"test_table"}, + }} utils.MustMatch(t, want, got, "Prepared queries") wantFailed := map[string]error{ "bogus": errPrepFailed, // The query is rejected by database so added to failed list. @@ -474,8 +477,8 @@ func TestTabletServerBeginFail(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} ctx, cancel = context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - tsv.Begin(ctx, &target, nil) - _, err := tsv.Begin(ctx, &target, nil) + tsv.Begin(ctx, &vtgatepb.Session{}, &target) + _, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.EqualError(t, err, "transaction pool aborting request due to already expired context", "Begin err") } @@ -498,9 +501,9 @@ func TestTabletServerCommitTransaction(t *testing.T) { db.AddQuery(executeSQL, executeSQLResult) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &target, executeSQL, nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) _, err = tsv.Commit(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -540,12 +543,12 @@ func TestTabletServerRollback(t *testing.T) { db.AddQuery(executeSQL, executeSQLResult) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) if err != nil { t.Fatalf("call TabletServer.Begin failed: %v", err) } - _, err = tsv.Execute(ctx, &target, executeSQL, nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) _, err = tsv.Rollback(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -558,9 +561,9 @@ func TestTabletServerPrepare(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) defer tsv.RollbackPrepared(ctx, &target, "aa", 0) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") @@ -574,9 +577,9 @@ func TestTabletServerCommitPrepared(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") require.NoError(t, err) @@ -623,12 +626,12 @@ func TestTabletServerWithNilTarget(t *testing.T) { expectedCount := tsv.stats.QueryTimingsByTabletType.Counts()[fullKey] - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) - _, err = tsv.Execute(ctx, target, executeSQL, nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) @@ -638,7 +641,7 @@ func TestTabletServerWithNilTarget(t *testing.T) { expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) - state, err = tsv.Begin(ctx, target, nil) + state, err = tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) @@ -651,7 +654,7 @@ func TestTabletServerWithNilTarget(t *testing.T) { // Finally be sure that we return an error now as expected when NOT // using a local context but passing a nil target. nonLocalCtx := context.Background() - _, err = tsv.Begin(nonLocalCtx, target, nil) + _, err = tsv.Begin(nonLocalCtx, &vtgatepb.Session{}, target) require.True(t, errors.Is(err, ErrNoTarget)) _, err = tsv.resolveTargetType(nonLocalCtx, target) require.True(t, errors.Is(err, ErrNoTarget)) @@ -773,11 +776,11 @@ func TestTabletServerReserveConnection(t *testing.T) { options := &querypb.ExecuteOptions{} // reserve a connection - state, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, 0, options) + state, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{Options: options}, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) // run a query in it - _, err = tsv.Execute(ctx, &target, "select 42", nil, 0, state.ReservedID, options) + _, err = tsv.Execute(ctx, &vtgatepb.Session{Options: options}, &target, "select 42", nil, 0, state.ReservedID) require.NoError(t, err) // release the connection @@ -797,7 +800,7 @@ func TestTabletServerExecNonExistentConnection(t *testing.T) { options := &querypb.ExecuteOptions{} // run a query with a non-existent reserved id - _, err := tsv.Execute(ctx, &target, "select 42", nil, 0, 123456, options) + _, err := tsv.Execute(ctx, &vtgatepb.Session{Options: options}, &target, "select 42", nil, 0, 123456) require.Error(t, err) } @@ -828,7 +831,7 @@ func TestMakeSureToCloseDbConnWhenBeginQueryFails(t *testing.T) { options := &querypb.ExecuteOptions{} // run a query with a non-existent reserved id - _, _, err := tsv.ReserveBeginExecute(ctx, &target, []string{}, nil, "select 42", nil, options) + _, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{Options: options}, &target, []string{}, nil, "select 42", nil) require.Error(t, err) } @@ -844,7 +847,7 @@ func TestTabletServerReserveAndBeginCommit(t *testing.T) { options := &querypb.ExecuteOptions{} // reserve a connection and a transaction - state, _, err := tsv.ReserveBeginExecute(ctx, &target, nil, nil, "set sql_mode = ''", nil, options) + state, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{Options: options}, &target, nil, nil, "set sql_mode = ''", nil) require.NoError(t, err) defer func() { // fallback so the test finishes quickly @@ -852,13 +855,13 @@ func TestTabletServerReserveAndBeginCommit(t *testing.T) { }() // run a query in it - _, err = tsv.Execute(ctx, &target, "select 42", nil, state.TransactionID, state.ReservedID, options) + _, err = tsv.Execute(ctx, &vtgatepb.Session{Options: options}, &target, "select 42", nil, state.TransactionID, state.ReservedID) require.NoError(t, err) // run a query in a non-existent connection - _, err = tsv.Execute(ctx, &target, "select 42", nil, state.TransactionID, state.ReservedID+100, options) + _, err = tsv.Execute(ctx, &vtgatepb.Session{Options: options}, &target, "select 42", nil, state.TransactionID, state.ReservedID+100) require.Error(t, err) - _, err = tsv.Execute(ctx, &target, "select 42", nil, state.TransactionID+100, state.ReservedID, options) + _, err = tsv.Execute(ctx, &vtgatepb.Session{Options: options}, &target, "select 42", nil, state.TransactionID+100, state.ReservedID) require.Error(t, err) // commit @@ -868,7 +871,7 @@ func TestTabletServerReserveAndBeginCommit(t *testing.T) { rID := newRID // begin and rollback - beginState, _, err := tsv.BeginExecute(ctx, &target, nil, "select 42", nil, rID, options) + beginState, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{Options: options}, &target, nil, "select 42", nil, rID) require.NoError(t, err) assert.Equal(t, newRID, beginState.TransactionID) rID = newRID @@ -894,9 +897,9 @@ func TestTabletServerRollbackPrepared(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0, nil) + _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") require.NoError(t, err) @@ -924,7 +927,7 @@ func TestTabletServerStreamExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - if err := tsv.StreamExecute(ctx, &target, executeSQL, nil, 0, 0, nil, callback); err != nil { + if err := tsv.StreamExecute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, 0, 0, callback); err != nil { t.Fatalf("TabletServer.StreamExecute should success: %s, but get error: %v", executeSQL, err) } @@ -954,7 +957,7 @@ func TestTabletServerStreamExecuteComments(t *testing.T) { ch := tabletenv.StatsLogger.Subscribe("test stats logging") defer tabletenv.StatsLogger.Unsubscribe(ch) - if err := tsv.StreamExecute(ctx, &target, executeSQL, nil, 0, 0, nil, callback); err != nil { + if err := tsv.StreamExecute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, 0, 0, callback); err != nil { t.Fatalf("TabletServer.StreamExecute should success: %s, but get error: %v", executeSQL, err) } @@ -990,7 +993,7 @@ func TestTabletServerBeginStreamExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - state, err := tsv.BeginStreamExecute(ctx, &target, nil, executeSQL, nil, 0, nil, callback) + state, err := tsv.BeginStreamExecute(ctx, &vtgatepb.Session{}, &target, nil, executeSQL, nil, 0, callback) if err != nil { t.Fatalf("TabletServer.BeginStreamExecute should success: %s, but get error: %v", executeSQL, err) @@ -1023,7 +1026,7 @@ func TestTabletServerBeginStreamExecuteWithError(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - state, err := tsv.BeginStreamExecute(ctx, &target, nil, executeSQL, nil, 0, nil, callback) + state, err := tsv.BeginStreamExecute(ctx, &vtgatepb.Session{}, &target, nil, executeSQL, nil, 0, callback) require.Error(t, err) err = tsv.Release(ctx, &target, state.TransactionID, 0) require.NoError(t, err) @@ -1090,7 +1093,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &target, nil, q1, bvTx1, 0, nil) + state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1105,7 +1108,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { defer wg.Done() <-tx1Started - state2, _, err := tsv.BeginExecute(ctx, &target, nil, q2, bvTx2, 0, nil) + state2, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q2, err) } @@ -1125,7 +1128,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { defer wg.Done() <-tx1Started - state3, _, err := tsv.BeginExecute(ctx, &target, nil, q3, bvTx3, 0, nil) + state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -1160,7 +1163,7 @@ func TestDMLQueryWithoutWhereClause(t *testing.T) { db.AddQuery(q+" limit 10001", &sqltypes.Result{}) - state, _, err := tsv.BeginExecute(ctx, &target, nil, q, nil, 0, nil) + state, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q, nil, 0) require.NoError(t, err) _, err = tsv.Commit(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -1234,7 +1237,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &target, nil, q1, bvTx1, 0, nil) + state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1253,7 +1256,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // In that case, we would see less than 3 pending transactions. <-tx1Started - state2, _, err := tsv.BeginExecute(ctx, &target, nil, q2, bvTx2, 0, nil) + state2, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q2, err) } @@ -1272,7 +1275,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // In that case, we would see less than 3 pending transactions. <-tx1Started - state3, _, err := tsv.BeginExecute(ctx, &target, nil, q3, bvTx3, 0, nil) + state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -1370,7 +1373,7 @@ func TestSerializeTransactionsSameRow_TooManyPendingRequests(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &target, nil, q1, bvTx1, 0, nil) + state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1386,7 +1389,7 @@ func TestSerializeTransactionsSameRow_TooManyPendingRequests(t *testing.T) { defer close(tx2Failed) <-tx1Started - _, _, err := tsv.BeginExecute(ctx, &target, nil, q2, bvTx2, 0, nil) + _, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) if err == nil || vterrors.Code(err) != vtrpcpb.Code_RESOURCE_EXHAUSTED || err.Error() != "hot row protection: too many queued transactions (1 >= 1) for the same row (table + WHERE clause: 'test_table where pk = 1 and `name` = 1')" { t.Errorf("tx2 should have failed because there are too many pending requests: %v", err) } @@ -1459,7 +1462,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &target, nil, q1, bvTx1, 0, nil) + state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1479,7 +1482,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { // Wait until tx1 has started to make the test deterministic. <-tx1Started - _, _, err := tsv.BeginExecute(ctxTx2, &target, nil, q2, bvTx2, 0, nil) + _, _, err := tsv.BeginExecute(ctxTx2, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) if err == nil || vterrors.Code(err) != vtrpcpb.Code_CANCELED || err.Error() != "context canceled" { t.Errorf("tx2 should have failed because the context was canceled: %v", err) } @@ -1496,7 +1499,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { t.Error(err) } - state3, _, err := tsv.BeginExecute(ctx, &target, nil, q3, bvTx3, 0, nil) + state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -2146,6 +2149,7 @@ var aclJSON1 = `{ } ] }` + var aclJSON2 = `{ "table_groups": [ { @@ -2156,6 +2160,7 @@ var aclJSON2 = `{ } ] }` + var aclJSONOverlapError = `{ "table_groups": [ { @@ -2316,7 +2321,7 @@ func TestReserveBeginExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.AddQueryPattern("set @@sql_mode = ''", &sqltypes.Result{}) - state, _, err := tsv.ReserveBeginExecute(ctx, &target, nil, nil, "set @@sql_mode = ''", nil, &querypb.ExecuteOptions{}) + state, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set @@sql_mode = ''", nil) require.NoError(t, err) assert.Greater(t, state.TransactionID, int64(0), "transactionID") @@ -2343,7 +2348,7 @@ func TestReserveExecute_WithoutTx(t *testing.T) { db.AddQueryPattern("set sql_mode = ''", &sqltypes.Result{}) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, 0, &querypb.ExecuteOptions{}) + state, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) assert.NotEqual(t, int64(0), state.ReservedID, "reservedID should not be zero") expected := []string{ @@ -2367,12 +2372,12 @@ func TestReserveExecute_WithTx(t *testing.T) { db.AddQueryPattern("set sql_mode = ''", &sqltypes.Result{}) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - beginState, err := tsv.Begin(ctx, &target, &querypb.ExecuteOptions{}) + beginState, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) require.NoError(t, err) require.NotEqual(t, int64(0), beginState.TransactionID) db.ResetQueryLog() - reserveState, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, beginState.TransactionID, &querypb.ExecuteOptions{}) + reserveState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, beginState.TransactionID) require.NoError(t, err) defer tsv.Release(ctx, &target, beginState.TransactionID, reserveState.ReservedID) assert.Equal(t, beginState.TransactionID, reserveState.ReservedID, "reservedID should be equal to transactionID") @@ -2431,19 +2436,19 @@ func TestRelease(t *testing.T) { switch { case test.begin && test.reserve: - state, _, err := tsv.ReserveBeginExecute(ctx, &target, nil, nil, "set sql_mode = ''", nil, &querypb.ExecuteOptions{}) + state, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set sql_mode = ''", nil) require.NoError(t, err) transactionID = state.TransactionID reservedID = state.ReservedID require.NotEqual(t, int64(0), transactionID) require.NotEqual(t, int64(0), reservedID) case test.begin: - state, _, err := tsv.BeginExecute(ctx, &target, nil, "select 42", nil, 0, &querypb.ExecuteOptions{}) + state, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, 0) require.NoError(t, err) transactionID = state.TransactionID require.NotEqual(t, int64(0), transactionID) case test.reserve: - state, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, 0, &querypb.ExecuteOptions{}) + state, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) reservedID = state.ReservedID require.NotEqual(t, int64(0), reservedID) @@ -2478,27 +2483,27 @@ func TestReserveStats(t *testing.T) { ctx = callerid.NewContext(ctx, nil, callerID) // Starts reserved connection and transaction - rbeState, _, err := tsv.ReserveBeginExecute(ctx, &target, nil, nil, "set sql_mode = ''", nil, &querypb.ExecuteOptions{}) + rbeState, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set sql_mode = ''", nil) require.NoError(t, err) assert.EqualValues(t, 1, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Starts reserved connection - reState, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, 0, &querypb.ExecuteOptions{}) + reState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Use previous reserved connection to start transaction - reBeState, _, err := tsv.BeginExecute(ctx, &target, nil, "select 42", nil, reState.ReservedID, &querypb.ExecuteOptions{}) + reBeState, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, reState.ReservedID) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Starts transaction. - beState, _, err := tsv.BeginExecute(ctx, &target, nil, "select 42", nil, 0, &querypb.ExecuteOptions{}) + beState, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, 0) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Reserved the connection on previous transaction - beReState, _, err := tsv.ReserveExecute(ctx, &target, nil, "set sql_mode = ''", nil, beState.TransactionID, &querypb.ExecuteOptions{}) + beReState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, beState.TransactionID) require.NoError(t, err) assert.EqualValues(t, 3, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) @@ -2544,11 +2549,11 @@ func TestDatabaseNameReplaceByKeyspaceNameExecuteMethod(t *testing.T) { target := tsv.sm.target // Testing Execute Method - state, err := tsv.Begin(ctx, target, nil) + state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) require.NoError(t, err) - res, err := tsv.Execute(ctx, target, executeSQL, nil, state.TransactionID, 0, &querypb.ExecuteOptions{ + res, err := tsv.Execute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, - }) + }}, target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) for _, field := range res.Fields { require.Equal(t, "keyspaceName", field.Database) @@ -2590,9 +2595,9 @@ func TestDatabaseNameReplaceByKeyspaceNameStreamExecuteMethod(t *testing.T) { } return nil } - err := tsv.StreamExecute(ctx, target, executeSQL, nil, 0, 0, &querypb.ExecuteOptions{ + err := tsv.StreamExecute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, - }, callback) + }}, target, executeSQL, nil, 0, 0, callback) require.NoError(t, err) } @@ -2621,9 +2626,9 @@ func TestDatabaseNameReplaceByKeyspaceNameBeginExecuteMethod(t *testing.T) { target := tsv.sm.target // Test BeginExecute Method - state, res, err := tsv.BeginExecute(ctx, target, nil, executeSQL, nil, 0, &querypb.ExecuteOptions{ + state, res, err := tsv.BeginExecute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, - }) + }}, target, nil, executeSQL, nil, 0) require.NoError(t, err) for _, field := range res.Fields { require.Equal(t, "keyspaceName", field.Database) @@ -2661,9 +2666,9 @@ func TestDatabaseNameReplaceByKeyspaceNameReserveExecuteMethod(t *testing.T) { target := tsv.sm.target // Test ReserveExecute - _, res, err := tsv.ReserveExecute(ctx, target, nil, executeSQL, nil, 0, &querypb.ExecuteOptions{ + _, res, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, - }) + }}, target, nil, executeSQL, nil, 0) require.NoError(t, err) for _, field := range res.Fields { require.Equal(t, "keyspaceName", field.Database) @@ -2695,9 +2700,9 @@ func TestDatabaseNameReplaceByKeyspaceNameReserveBeginExecuteMethod(t *testing.T target := tsv.sm.target // Test for ReserveBeginExecute - state, res, err := tsv.ReserveBeginExecute(ctx, target, nil, nil, executeSQL, nil, &querypb.ExecuteOptions{ + state, res, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, - }) + }}, target, nil, nil, executeSQL, nil) require.NoError(t, err) for _, field := range res.Fields { require.Equal(t, "keyspaceName", field.Database) diff --git a/go/vtbench/client.go b/go/vtbench/client.go index 70c9615b321..e141dba605d 100644 --- a/go/vtbench/client.go +++ b/go/vtbench/client.go @@ -55,7 +55,6 @@ func (c *mysqlClientConn) connect(ctx context.Context, cp ConnParams) error { Pass: cp.Password, UnixSocket: cp.UnixSocket, }) - if err != nil { return err } @@ -162,5 +161,5 @@ func (c *grpcVttabletConn) connect(ctx context.Context, cp ConnParams) error { } func (c *grpcVttabletConn) execute(ctx context.Context, query string, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return c.qs.Execute(ctx, &c.target, query, bindVars, 0, 0, nil) + return c.qs.Execute(ctx, nil, &c.target, query, bindVars, 0, 0) } diff --git a/proto/query.proto b/proto/query.proto index 798fd6d4996..6c15ff22bfb 100644 --- a/proto/query.proto +++ b/proto/query.proto @@ -379,9 +379,6 @@ message ExecuteOptions { // transaction_timeout specifies the transaction timeout in milliseconds. If not set, the default timeout is used. optional int64 transaction_timeout = 20; - - // SessionUUID is the UUID of the current session. - string SessionUUID = 21; } // Field describes a single column returned by a query diff --git a/web/vtadmin/src/proto/vtadmin.d.ts b/web/vtadmin/src/proto/vtadmin.d.ts index 1a36dbc57f4..419e558161f 100644 --- a/web/vtadmin/src/proto/vtadmin.d.ts +++ b/web/vtadmin/src/proto/vtadmin.d.ts @@ -42226,9 +42226,6 @@ export namespace query { /** ExecuteOptions transaction_timeout */ transaction_timeout?: (number|Long|null); - - /** ExecuteOptions SessionUUID */ - SessionUUID?: (string|null); } /** Represents an ExecuteOptions. */ @@ -42288,9 +42285,6 @@ export namespace query { /** ExecuteOptions transaction_timeout. */ public transaction_timeout?: (number|Long|null); - /** ExecuteOptions SessionUUID. */ - public SessionUUID: string; - /** ExecuteOptions timeout. */ public timeout?: "authoritative_timeout"; diff --git a/web/vtadmin/src/proto/vtadmin.js b/web/vtadmin/src/proto/vtadmin.js index 6f54f72b5f0..854d7430112 100644 --- a/web/vtadmin/src/proto/vtadmin.js +++ b/web/vtadmin/src/proto/vtadmin.js @@ -100782,7 +100782,6 @@ export const query = $root.query = (() => { * @property {boolean|null} [fetch_last_insert_id] ExecuteOptions fetch_last_insert_id * @property {boolean|null} [in_dml_execution] ExecuteOptions in_dml_execution * @property {number|Long|null} [transaction_timeout] ExecuteOptions transaction_timeout - * @property {string|null} [SessionUUID] ExecuteOptions SessionUUID */ /** @@ -100929,14 +100928,6 @@ export const query = $root.query = (() => { */ ExecuteOptions.prototype.transaction_timeout = null; - /** - * ExecuteOptions SessionUUID. - * @member {string} SessionUUID - * @memberof query.ExecuteOptions - * @instance - */ - ExecuteOptions.prototype.SessionUUID = ""; - // OneOf field names bound to virtual getters and setters let $oneOfFields; @@ -101017,8 +101008,6 @@ export const query = $root.query = (() => { writer.uint32(/* id 19, wireType 0 =*/152).bool(message.in_dml_execution); if (message.transaction_timeout != null && Object.hasOwnProperty.call(message, "transaction_timeout")) writer.uint32(/* id 20, wireType 0 =*/160).int64(message.transaction_timeout); - if (message.SessionUUID != null && Object.hasOwnProperty.call(message, "SessionUUID")) - writer.uint32(/* id 21, wireType 2 =*/170).string(message.SessionUUID); return writer; }; @@ -101124,10 +101113,6 @@ export const query = $root.query = (() => { message.transaction_timeout = reader.int64(); break; } - case 21: { - message.SessionUUID = reader.string(); - break; - } default: reader.skipType(tag & 7); break; @@ -101267,9 +101252,6 @@ export const query = $root.query = (() => { if (!$util.isInteger(message.transaction_timeout) && !(message.transaction_timeout && $util.isInteger(message.transaction_timeout.low) && $util.isInteger(message.transaction_timeout.high))) return "transaction_timeout: integer|Long expected"; } - if (message.SessionUUID != null && message.hasOwnProperty("SessionUUID")) - if (!$util.isString(message.SessionUUID)) - return "SessionUUID: string expected"; return null; }; @@ -101495,8 +101477,6 @@ export const query = $root.query = (() => { message.transaction_timeout = object.transaction_timeout; else if (typeof object.transaction_timeout === "object") message.transaction_timeout = new $util.LongBits(object.transaction_timeout.low >>> 0, object.transaction_timeout.high >>> 0).toNumber(); - if (object.SessionUUID != null) - message.SessionUUID = String(object.SessionUUID); return message; }; @@ -101533,7 +101513,6 @@ export const query = $root.query = (() => { object.priority = ""; object.fetch_last_insert_id = false; object.in_dml_execution = false; - object.SessionUUID = ""; } if (message.included_fields != null && message.hasOwnProperty("included_fields")) object.included_fields = options.enums === String ? $root.query.ExecuteOptions.IncludedFields[message.included_fields] === undefined ? message.included_fields : $root.query.ExecuteOptions.IncludedFields[message.included_fields] : message.included_fields; @@ -101585,8 +101564,6 @@ export const query = $root.query = (() => { if (options.oneofs) object._transaction_timeout = "transaction_timeout"; } - if (message.SessionUUID != null && message.hasOwnProperty("SessionUUID")) - object.SessionUUID = message.SessionUUID; return object; }; From 00601120a0036e35548dbb925dc7c3e60e3bbd76 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 13:23:04 -0500 Subject: [PATCH 52/67] fix e2e test Signed-off-by: Mohamed Hamza --- go/vt/vtgate/plugin_mysql_server.go | 3 +- go/vt/vttablet/endtoend/framework/client.go | 37 +++++++++++++-------- go/vtbench/client.go | 3 +- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index cfeb014d588..67d94b5ef6e 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -181,8 +181,7 @@ var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) // this function is here to make this logic easy to test by decoupling the logic from the `trace.NewSpan` and `trace.NewFromString` functions func startSpanTestable(ctx context.Context, query, label string, newSpan func(context.Context, string) (trace.Span, context.Context), - newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error), -) (trace.Span, context.Context, error) { + newSpanFromString func(context.Context, string, string) (trace.Span, context.Context, error)) (trace.Span, context.Context, error) { _, comments := sqlparser.SplitMarginComments(query) match := r.FindStringSubmatch(comments.Leading) span, ctx := getSpan(ctx, match, newSpan, label, newSpanFromString) diff --git a/go/vt/vttablet/endtoend/framework/client.go b/go/vt/vttablet/endtoend/framework/client.go index 59def25ab61..94bb3ccd154 100644 --- a/go/vt/vttablet/endtoend/framework/client.go +++ b/go/vt/vttablet/endtoend/framework/client.go @@ -28,6 +28,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -99,11 +100,9 @@ func (client *QueryClient) Begin(clientFoundRows bool) error { if client.transactionID != 0 { return errors.New("already in transaction") } - var options *querypb.ExecuteOptions - if clientFoundRows { - options = &querypb.ExecuteOptions{ClientFoundRows: clientFoundRows} - } - state, err := client.server.Begin(client.ctx, client.target, options) + + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{ClientFoundRows: clientFoundRows}} + state, err := client.server.Begin(client.ctx, session, client.target) if err != nil { return err } @@ -202,14 +201,16 @@ func (client *QueryClient) BeginExecute(query string, bindvars map[string]*query if client.transactionID != 0 { return nil, errors.New("already in transaction") } + + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} state, qr, err := client.server.BeginExecute( client.ctx, + session, client.target, preQueries, query, bindvars, client.reservedID, - &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}, ) client.transactionID = state.TransactionID client.sessionStateChanges = state.SessionStateChanges @@ -221,14 +222,15 @@ func (client *QueryClient) BeginExecute(query string, bindvars map[string]*query // ExecuteWithOptions executes a query using 'options'. func (client *QueryClient) ExecuteWithOptions(query string, bindvars map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { + session := &vtgatepb.Session{Options: options} return client.server.Execute( client.ctx, + session, client.target, query, bindvars, client.transactionID, client.reservedID, - options, ) } @@ -240,13 +242,14 @@ func (client *QueryClient) StreamExecute(query string, bindvars map[string]*quer // StreamExecuteWithOptions executes a query & returns the results using 'options'. func (client *QueryClient) StreamExecuteWithOptions(query string, bindvars map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { result := &sqltypes.Result{} + session := &vtgatepb.Session{Options: options} err := client.server.StreamExecute(client.ctx, + session, client.target, query, bindvars, client.transactionID, client.reservedID, - options, func(res *sqltypes.Result) error { if result.Fields == nil { result.Fields = res.Fields @@ -263,14 +266,15 @@ func (client *QueryClient) StreamExecuteWithOptions(query string, bindvars map[s // StreamBeginExecuteWithOptions starts a tx and executes a query using 'options', returning the results . func (client *QueryClient) StreamBeginExecuteWithOptions(query string, preQueries []string, bindvars map[string]*querypb.BindVariable, options *querypb.ExecuteOptions) (*sqltypes.Result, error) { result := &sqltypes.Result{} + session := &vtgatepb.Session{Options: options} state, err := client.server.BeginStreamExecute( client.ctx, + session, client.target, preQueries, query, bindvars, client.reservedID, - options, func(res *sqltypes.Result) error { if result.Fields == nil { result.Fields = res.Fields @@ -289,7 +293,8 @@ func (client *QueryClient) StreamBeginExecuteWithOptions(query string, preQuerie // Stream streams the results of a query. func (client *QueryClient) Stream(query string, bindvars map[string]*querypb.BindVariable, sendFunc func(*sqltypes.Result) error) error { - return client.server.StreamExecute(client.ctx, client.target, query, bindvars, 0, 0, &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}, sendFunc) + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} + return client.server.StreamExecute(client.ctx, session, client.target, query, bindvars, 0, 0, sendFunc) } // MessageStream streams messages from the message table. @@ -316,7 +321,8 @@ func (client *QueryClient) ReserveExecute(query string, preQueries []string, bin if client.reservedID != 0 { return nil, errors.New("already reserved a connection") } - state, qr, err := client.server.ReserveExecute(client.ctx, client.target, preQueries, query, bindvars, client.transactionID, &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}) + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} + state, qr, err := client.server.ReserveExecute(client.ctx, session, client.target, preQueries, query, bindvars, client.transactionID) client.reservedID = state.ReservedID if err != nil { return nil, err @@ -330,7 +336,8 @@ func (client *QueryClient) ReserveStreamExecute(query string, preQueries []strin return nil, errors.New("already reserved a connection") } result := &sqltypes.Result{} - state, err := client.server.ReserveStreamExecute(client.ctx, client.target, preQueries, query, bindvars, client.transactionID, &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}, + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} + state, err := client.server.ReserveStreamExecute(client.ctx, session, client.target, preQueries, query, bindvars, client.transactionID, func(res *sqltypes.Result) error { if result.Fields == nil { result.Fields = res.Fields @@ -353,7 +360,8 @@ func (client *QueryClient) ReserveBeginExecute(query string, preQueries []string if client.transactionID != 0 { return nil, errors.New("already in transaction") } - state, qr, err := client.server.ReserveBeginExecute(client.ctx, client.target, preQueries, postBeginQueries, query, bindvars, &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}) + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} + state, qr, err := client.server.ReserveBeginExecute(client.ctx, session, client.target, preQueries, postBeginQueries, query, bindvars) client.transactionID = state.TransactionID client.reservedID = state.ReservedID client.sessionStateChanges = state.SessionStateChanges @@ -372,7 +380,8 @@ func (client *QueryClient) ReserveBeginStreamExecute(query string, preQueries [] return nil, errors.New("already in transaction") } result := &sqltypes.Result{} - state, err := client.server.ReserveBeginStreamExecute(client.ctx, client.target, preQueries, postBeginQueries, query, bindvars, &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}, + session := &vtgatepb.Session{Options: &querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_ALL}} + state, err := client.server.ReserveBeginStreamExecute(client.ctx, session, client.target, preQueries, postBeginQueries, query, bindvars, func(res *sqltypes.Result) error { if result.Fields == nil { result.Fields = res.Fields diff --git a/go/vtbench/client.go b/go/vtbench/client.go index e141dba605d..c610f4ea1db 100644 --- a/go/vtbench/client.go +++ b/go/vtbench/client.go @@ -29,6 +29,7 @@ import ( "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/topo/topoproto" + "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/vtgateconn" "vitess.io/vitess/go/vt/vttablet/queryservice" "vitess.io/vitess/go/vt/vttablet/tabletconn" @@ -161,5 +162,5 @@ func (c *grpcVttabletConn) connect(ctx context.Context, cp ConnParams) error { } func (c *grpcVttabletConn) execute(ctx context.Context, query string, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return c.qs.Execute(ctx, nil, &c.target, query, bindVars, 0, 0) + return c.qs.Execute(ctx, executorcontext.NewSafeSession(nil), &c.target, query, bindVars, 0, 0) } From c00ef7bd90314c7796f891de32680ba9da8bbe63 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 13:39:35 -0500 Subject: [PATCH 53/67] fix import cycle Signed-off-by: Mohamed Hamza --- go/vt/vttablet/tabletconntest/tabletconntest.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/vt/vttablet/tabletconntest/tabletconntest.go b/go/vt/vttablet/tabletconntest/tabletconntest.go index bba5cab9fcc..e28aecbe47e 100644 --- a/go/vt/vttablet/tabletconntest/tabletconntest.go +++ b/go/vt/vttablet/tabletconntest/tabletconntest.go @@ -36,12 +36,12 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/queryservice" "vitess.io/vitess/go/vt/vttablet/tabletconn" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -117,7 +117,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe t.Log("testBeginError") f.HasBeginError = true testErrorHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, executorcontext.NewSafeSession(nil), TestTarget) + _, err := conn.Begin(ctx, &vtgatepb.Session{}, TestTarget) return err }) f.HasBeginError = false @@ -126,7 +126,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe func testBeginPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { t.Log("testBeginPanics") testPanicHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, executorcontext.NewSafeSession(nil), TestTarget) + _, err := conn.Begin(ctx, &vtgatepb.Session{}, TestTarget) return err }) } From 2ade6a9e7b0e4df0ea28a5b9136ea40430f370b9 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 14:04:08 -0500 Subject: [PATCH 54/67] try to fix actions Signed-off-by: Mohamed Hamza --- .github/actions/tune-os/action.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/actions/tune-os/action.yml b/.github/actions/tune-os/action.yml index 618e3bfedd2..8a4f01ef206 100644 --- a/.github/actions/tune-os/action.yml +++ b/.github/actions/tune-os/action.yml @@ -10,6 +10,10 @@ runs: sudo apt-get update && sudo apt-get install -y eatmydata echo "/usr/\${LIB}/libeatmydata.so" | sudo tee -a /etc/ld.so.preload + if systemctl list-unit-files snap.amazon-ssm-agent.amazon-ssm-agent.service &>/dev/null; then + sudo systemctl stop snap.amazon-ssm-agent.amazon-ssm-agent.service + fi + # Increase the range of ephemeral ports. sudo sysctl -w net.ipv4.ip_local_port_range="22768 65535" From 6531827242291bf92bf4f31a48daa62229d73067 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 15:03:43 -0500 Subject: [PATCH 55/67] fix scatter conn test Signed-off-by: Mohamed Hamza --- go/vt/vtgate/scatter_conn.go | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index 7fde259f1fd..f58525710d8 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -18,6 +18,7 @@ package vtgate import ( "context" + "fmt" "io" "runtime/debug" "sync" @@ -166,6 +167,7 @@ func (stc *ScatterConn) ExecuteMultiShard( go stc.runLockQuery(ctx, session) } + fmt.Println("HELLO") if session.Options != nil { session.Options.FetchLastInsertId = fetchLastInsertID } @@ -180,15 +182,20 @@ func (stc *ScatterConn) ExecuteMultiShard( var ( innerqr *sqltypes.Result err error + opts *querypb.ExecuteOptions alias *topodatapb.TabletAlias qs queryservice.QueryService ) transactionID := info.transactionID reservedID := info.reservedID - if session.GetOptions() == nil && fetchLastInsertID { - session = econtext.NewSafeSession(session.Session) - session.SetOptions(&querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID}) + if session != nil && session.Session != nil { + opts = session.Session.Options + } + + if opts == nil && fetchLastInsertID { + opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} + session = econtext.NewSafeSession(&vtgatepb.Session{Options: opts}) } if autocommit { @@ -407,15 +414,20 @@ func (stc *ScatterConn) StreamExecuteMulti( func(rs *srvtopo.ResolvedShard, i int, info *shardActionInfo) (*shardActionInfo, error) { var ( err error + opts *querypb.ExecuteOptions alias *topodatapb.TabletAlias qs queryservice.QueryService ) transactionID := info.transactionID reservedID := info.reservedID - if session.GetOptions() == nil && fetchLastInsertID { - session = econtext.NewSafeSession(session.Session) - session.SetOptions(&querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID}) + if session != nil && session.Session != nil { + opts = session.Session.Options + } + + if opts == nil && fetchLastInsertID { + opts = &querypb.ExecuteOptions{FetchLastInsertId: fetchLastInsertID} + session = econtext.NewSafeSession(&vtgatepb.Session{Options: opts}) } if autocommit { From 74dbe47283e8e31f856e8ca09753aabe93652991 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 15:08:04 -0500 Subject: [PATCH 56/67] try to fix actions Signed-off-by: Mohamed Hamza --- .github/actions/setup-mysql/action.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/actions/setup-mysql/action.yml b/.github/actions/setup-mysql/action.yml index 444c2133a82..5d7e368731e 100644 --- a/.github/actions/setup-mysql/action.yml +++ b/.github/actions/setup-mysql/action.yml @@ -11,6 +11,15 @@ runs: shell: bash run: | export DEBIAN_FRONTEND="noninteractive" + + # Wait for any existing apt processes to finish + echo "Checking for existing apt processes..." + while sudo fuser /var/lib/dpkg/lock-frontend >/dev/null 2>&1 || sudo fuser /var/lib/apt/lists/lock >/dev/null 2>&1 ; do + echo "Waiting for other apt processes to finish..." + sleep 2 + done + echo "No apt locks detected, proceeding..." + sudo apt-get update # Uninstall any previously installed MySQL first From 61439c59c4c0b8f508ae53a1d8939b4501df11b0 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Fri, 12 Dec 2025 17:10:50 -0500 Subject: [PATCH 57/67] remove debug message Signed-off-by: Mohamed Hamza --- go/vt/vtgate/scatter_conn.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index f58525710d8..822febd2745 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -18,7 +18,6 @@ package vtgate import ( "context" - "fmt" "io" "runtime/debug" "sync" @@ -167,7 +166,6 @@ func (stc *ScatterConn) ExecuteMultiShard( go stc.runLockQuery(ctx, session) } - fmt.Println("HELLO") if session.Options != nil { session.Options.FetchLastInsertId = fetchLastInsertID } From 62d289d61f021a0cce9d2abe59b0f20c3bf629ce Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 09:08:55 -0500 Subject: [PATCH 58/67] undo ci changes for flakiness Signed-off-by: Mohamed Hamza --- .github/actions/setup-mysql/action.yml | 9 --------- .github/actions/tune-os/action.yml | 4 ---- 2 files changed, 13 deletions(-) diff --git a/.github/actions/setup-mysql/action.yml b/.github/actions/setup-mysql/action.yml index 5d7e368731e..444c2133a82 100644 --- a/.github/actions/setup-mysql/action.yml +++ b/.github/actions/setup-mysql/action.yml @@ -11,15 +11,6 @@ runs: shell: bash run: | export DEBIAN_FRONTEND="noninteractive" - - # Wait for any existing apt processes to finish - echo "Checking for existing apt processes..." - while sudo fuser /var/lib/dpkg/lock-frontend >/dev/null 2>&1 || sudo fuser /var/lib/apt/lists/lock >/dev/null 2>&1 ; do - echo "Waiting for other apt processes to finish..." - sleep 2 - done - echo "No apt locks detected, proceeding..." - sudo apt-get update # Uninstall any previously installed MySQL first diff --git a/.github/actions/tune-os/action.yml b/.github/actions/tune-os/action.yml index 8a4f01ef206..618e3bfedd2 100644 --- a/.github/actions/tune-os/action.yml +++ b/.github/actions/tune-os/action.yml @@ -10,10 +10,6 @@ runs: sudo apt-get update && sudo apt-get install -y eatmydata echo "/usr/\${LIB}/libeatmydata.so" | sudo tee -a /etc/ld.so.preload - if systemctl list-unit-files snap.amazon-ssm-agent.amazon-ssm-agent.service &>/dev/null; then - sudo systemctl stop snap.amazon-ssm-agent.amazon-ssm-agent.service - fi - # Increase the range of ephemeral ports. sudo sysctl -w net.ipv4.ip_local_port_range="22768 65535" From 091aba78f5c596f2fbe9611683d09e82b016c544 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 18:36:08 -0500 Subject: [PATCH 59/67] add nil session checks Signed-off-by: Mohamed Hamza --- go/vt/vttablet/grpctabletconn/conn.go | 27 ++++++++---- go/vt/vttablet/sandboxconn/sandboxconn.go | 13 +++++- .../tabletconntest/fakequeryservice.go | 21 +++++++--- go/vt/vttablet/tabletserver/tabletserver.go | 42 +++++++++++-------- go/vtbench/client.go | 3 +- 5 files changed, 70 insertions(+), 36 deletions(-) diff --git a/go/vt/vttablet/grpctabletconn/conn.go b/go/vt/vttablet/grpctabletconn/conn.go index 0bf73bff7a0..ce48d0d430e 100644 --- a/go/vt/vttablet/grpctabletconn/conn.go +++ b/go/vt/vttablet/grpctabletconn/conn.go @@ -128,7 +128,7 @@ func (conn *gRPCQueryClient) Execute(ctx context.Context, session queryservice.S BindVariables: bindVars, }, TransactionId: transactionID, - Options: session.GetOptions(), + Options: getOptions(session), ReservedId: reservedID, } er, err := conn.c.Execute(ctx, req) @@ -166,7 +166,7 @@ func (conn *gRPCQueryClient) StreamExecute(ctx context.Context, session queryser Sql: query, BindVariables: bindVars, }, - Options: session.GetOptions(), + Options: getOptions(session), TransactionId: transactionID, ReservedId: reservedID, } @@ -209,7 +209,7 @@ func (conn *gRPCQueryClient) Begin(ctx context.Context, session queryservice.Ses Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: session.GetOptions(), + Options: getOptions(session), } br, err := conn.c.Begin(ctx, req) if err != nil { @@ -480,7 +480,7 @@ func (conn *gRPCQueryClient) BeginExecute(ctx context.Context, session queryserv BindVariables: bindVars, }, ReservedId: reservedID, - Options: session.GetOptions(), + Options: getOptions(session), } reply, err := conn.c.BeginExecute(ctx, req) if err != nil { @@ -524,7 +524,7 @@ func (conn *gRPCQueryClient) BeginStreamExecute(ctx context.Context, session que BindVariables: bindVars, }, ReservedId: reservedID, - Options: session.GetOptions(), + Options: getOptions(session), } stream, err := conn.c.BeginStreamExecute(ctx, req) if err != nil { @@ -872,7 +872,7 @@ func (conn *gRPCQueryClient) ReserveBeginExecute(ctx context.Context, session qu Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: session.GetOptions(), + Options: getOptions(session), PreQueries: preQueries, PostBeginQueries: postBeginQueries, Query: &querypb.BoundQuery{ @@ -917,7 +917,7 @@ func (conn *gRPCQueryClient) ReserveBeginStreamExecute(ctx context.Context, sess Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: session.GetOptions(), + Options: getOptions(session), PreQueries: preQueries, PostBeginQueries: postBeginQueries, Query: &querypb.BoundQuery{ @@ -993,7 +993,7 @@ func (conn *gRPCQueryClient) ReserveExecute(ctx context.Context, session queryse BindVariables: bindVariables, }, TransactionId: transactionID, - Options: session.GetOptions(), + Options: getOptions(session), PreQueries: preQueries, } reply, err := conn.c.ReserveExecute(ctx, req) @@ -1031,7 +1031,7 @@ func (conn *gRPCQueryClient) ReserveStreamExecute(ctx context.Context, session q Target: target, EffectiveCallerId: callerid.EffectiveCallerIDFromContext(ctx), ImmediateCallerId: callerid.ImmediateCallerIDFromContext(ctx), - Options: session.GetOptions(), + Options: getOptions(session), PreQueries: preQueries, Query: &querypb.BoundQuery{ Sql: sql, @@ -1167,3 +1167,12 @@ func (conn *gRPCQueryClient) Close(ctx context.Context) error { func (conn *gRPCQueryClient) Tablet() *topodatapb.Tablet { return conn.tablet } + +// getOptions safely extracts ExecuteOptions from a session, returning nil if session is nil. +func getOptions(session queryservice.Session) *querypb.ExecuteOptions { + if session == nil { + return nil + } + + return session.GetOptions() +} diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index 4b466e6d10d..92f96f114de 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -283,7 +283,7 @@ func (sbc *SandboxConn) Execute(ctx context.Context, session queryservice.Sessio Sql: query, BindVariables: bv, }) - sbc.Options = append(sbc.Options, session.GetOptions()) + sbc.Options = append(sbc.Options, getOptions(session)) if err := sbc.getError(); err != nil { return nil, err } @@ -309,7 +309,7 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, session queryservice. Sql: query, BindVariables: bv, }) - sbc.Options = append(sbc.Options, session.GetOptions()) + sbc.Options = append(sbc.Options, getOptions(session)) err := sbc.getError() if err != nil { sbc.sExecMu.Unlock() @@ -894,3 +894,12 @@ func (sbc *SandboxConn) panicIfNeeded() { panic(sbc.panicThis) } } + +// getOptions safely extracts ExecuteOptions from a session, returning nil if session is nil. +func getOptions(session queryservice.Session) *querypb.ExecuteOptions { + if session == nil { + return nil + } + + return session.GetOptions() +} diff --git a/go/vt/vttablet/tabletconntest/fakequeryservice.go b/go/vt/vttablet/tabletconntest/fakequeryservice.go index 19a93592432..6ba49809b6a 100644 --- a/go/vt/vttablet/tabletconntest/fakequeryservice.go +++ b/go/vt/vttablet/tabletconntest/fakequeryservice.go @@ -152,8 +152,8 @@ func (f *FakeQueryService) Begin(ctx context.Context, session queryservice.Sessi panic(errors.New("test-triggered panic")) } f.checkTargetCallerID(ctx, "Begin", target) - if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { - f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) + if !proto.Equal(getOptions(session), TestSession.GetOptions()) { + f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", getOptions(session), TestSession.GetOptions()) } return queryservice.TransactionState{TransactionID: beginTransactionID, TabletAlias: TestAlias}, nil } @@ -428,8 +428,8 @@ func (f *FakeQueryService) Execute(ctx context.Context, session queryservice.Ses if !sqltypes.BindVariablesEqual(bindVariables, ExecuteBindVars) { f.t.Errorf("invalid Execute.BindVariables: got %v expected %v", bindVariables, ExecuteBindVars) } - if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { - f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) + if !proto.Equal(getOptions(session), TestSession.GetOptions()) { + f.t.Errorf("invalid Execute.ExecuteOptions: got %v expected %v", getOptions(session), TestSession.GetOptions()) } f.checkTargetCallerID(ctx, "Execute", target) if transactionID != f.ExpectedTransactionID { @@ -485,8 +485,8 @@ func (f *FakeQueryService) StreamExecute(ctx context.Context, session queryservi if !sqltypes.BindVariablesEqual(bindVariables, StreamExecuteBindVars) { f.t.Errorf("invalid StreamExecute.BindVariables: got %v expected %v", bindVariables, StreamExecuteBindVars) } - if !proto.Equal(session.GetOptions(), TestSession.GetOptions()) { - f.t.Errorf("invalid StreamExecute.ExecuteOptions: got %v expected %v", session.GetOptions(), TestSession.GetOptions()) + if !proto.Equal(getOptions(session), TestSession.GetOptions()) { + f.t.Errorf("invalid StreamExecute.ExecuteOptions: got %v expected %v", getOptions(session), TestSession.GetOptions()) } f.checkTargetCallerID(ctx, "StreamExecute", target) if err := callback(&StreamExecuteQueryResult1); err != nil { @@ -787,3 +787,12 @@ func CreateFakeServer(t testing.TB) *FakeQueryService { t: t, } } + +// getOptions safely extracts ExecuteOptions from a session, returning nil if session is nil. +func getOptions(session queryservice.Session) *querypb.ExecuteOptions { + if session == nil { + return nil + } + + return session.GetOptions() +} diff --git a/go/vt/vttablet/tabletserver/tabletserver.go b/go/vt/vttablet/tabletserver/tabletserver.go index 8ba36743902..60f5ab2d436 100644 --- a/go/vt/vttablet/tabletserver/tabletserver.go +++ b/go/vt/vttablet/tabletserver/tabletserver.go @@ -32,7 +32,6 @@ import ( "syscall" "time" - "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vttablet/tabletserver/querythrottler" "vitess.io/vitess/go/acl" @@ -464,7 +463,7 @@ func (tsv *TabletServer) IsHealthy() error { if topoproto.IsServingType(tsv.sm.Target().TabletType) { _, err := tsv.Execute( tabletenv.LocalContext(), - nilSession(), + nil, nil, "/* health */ select 1 from dual", nil, @@ -537,7 +536,7 @@ func (tsv *TabletServer) SchemaEngine() *schema.Engine { // Begin starts a new transaction. This is allowed only if the state is StateServing. func (tsv *TabletServer) Begin(ctx context.Context, session queryservice.Session, target *querypb.Target) (state queryservice.TransactionState, err error) { - return tsv.begin(ctx, target, nil, 0, nil, session.GetOptions()) + return tsv.begin(ctx, target, nil, 0, nil, getOptions(session)) } func (tsv *TabletServer) begin( @@ -913,7 +912,7 @@ func (tsv *TabletServer) Execute(ctx context.Context, session queryservice.Sessi return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] transactionID and reserveID must match if both are non-zero") } - return tsv.execute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, session.GetOptions()) + return tsv.execute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, getOptions(session)) } func (tsv *TabletServer) execute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, settings []string, options *querypb.ExecuteOptions) (result *sqltypes.Result, err error) { @@ -1015,7 +1014,7 @@ func (tsv *TabletServer) StreamExecute(ctx context.Context, session queryservice return vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] transactionID and reserveID must match if both are non-zero") } - return tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, session.GetOptions(), callback) + return tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, reservedID, nil, getOptions(session), callback) } func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Target, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64, reservedID int64, settings []string, options *querypb.ExecuteOptions, callback func(*sqltypes.Result) error) error { @@ -1079,9 +1078,11 @@ func (tsv *TabletServer) streamExecute(ctx context.Context, target *querypb.Targ // BeginExecute combines Begin and Execute. func (tsv *TabletServer) BeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable, reservedID int64) (queryservice.TransactionState, *sqltypes.Result, error) { + options := getOptions(session) + // Disable hot row protection in case of reserve connection. if tsv.enableHotRowProtection && reservedID == 0 { - txDone, err := tsv.beginWaitForSameRangeTransactions(ctx, target, session.GetOptions(), sql, bindVariables) + txDone, err := tsv.beginWaitForSameRangeTransactions(ctx, target, options, sql, bindVariables) if err != nil { return queryservice.TransactionState{}, nil, err } @@ -1090,7 +1091,7 @@ func (tsv *TabletServer) BeginExecute(ctx context.Context, session queryservice. } } - state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, session.GetOptions()) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, options) if err != nil { return state, nil, err } @@ -1110,7 +1111,7 @@ func (tsv *TabletServer) BeginStreamExecute( reservedID int64, callback func(*sqltypes.Result) error, ) (queryservice.TransactionState, error) { - state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, session.GetOptions()) + state, err := tsv.begin(ctx, target, postBeginQueries, reservedID, nil, getOptions(session)) if err != nil { return state, err } @@ -1272,7 +1273,7 @@ func (tsv *TabletServer) execDML(ctx context.Context, target *querypb.Target, qu return 0, err } - state, err := tsv.Begin(ctx, nilSession(), target) + state, err := tsv.Begin(ctx, nil, target) if err != nil { return 0, err } @@ -1283,7 +1284,7 @@ func (tsv *TabletServer) execDML(ctx context.Context, target *querypb.Target, qu tsv.Rollback(ctx, target, state.TransactionID) } }() - qr, err := tsv.Execute(ctx, nilSession(), target, query, bv, state.TransactionID, 0) + qr, err := tsv.Execute(ctx, nil, target, query, bv, state.TransactionID, 0) if err != nil { return 0, err } @@ -1337,7 +1338,8 @@ func (tsv *TabletServer) VStreamResults(ctx context.Context, target *querypb.Tar // ReserveBeginExecute implements the QueryService interface func (tsv *TabletServer) ReserveBeginExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, settings []string, postBeginQueries []string, sql string, bindVariables map[string]*querypb.BindVariable) (state queryservice.ReservedTransactionState, result *sqltypes.Result, err error) { - options := session.GetOptions() + options := getOptions(session) + state, result, err = tsv.beginExecuteWithSettings(ctx, target, settings, postBeginQueries, sql, bindVariables, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. @@ -1420,7 +1422,8 @@ func (tsv *TabletServer) ReserveBeginStreamExecute( bindVariables map[string]*querypb.BindVariable, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedTransactionState, err error) { - options := session.GetOptions() + options := getOptions(session) + txState, err := tsv.begin(ctx, target, postBeginQueries, 0, settings, options) if err != nil { return txToReserveState(txState), err @@ -1432,7 +1435,8 @@ func (tsv *TabletServer) ReserveBeginStreamExecute( // ReserveExecute implements the QueryService interface func (tsv *TabletServer) ReserveExecute(ctx context.Context, session queryservice.Session, target *querypb.Target, settings []string, sql string, bindVariables map[string]*querypb.BindVariable, transactionID int64) (state queryservice.ReservedState, result *sqltypes.Result, err error) { - options := session.GetOptions() + options := getOptions(session) + result, err = tsv.executeWithSettings(ctx, target, settings, sql, bindVariables, transactionID, options) // If there is an error and the error message is about allowing query in reserved connection only, // then we do not return an error from here and continue to use the reserved connection path. @@ -1486,7 +1490,7 @@ func (tsv *TabletServer) ReserveStreamExecute( transactionID int64, callback func(*sqltypes.Result) error, ) (state queryservice.ReservedState, err error) { - return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, settings, session.GetOptions(), callback) + return state, tsv.streamExecute(ctx, target, sql, bindVariables, transactionID, 0, settings, getOptions(session), callback) } // Release implements the QueryService interface @@ -2117,7 +2121,11 @@ func (tsv *TabletServer) getShard() string { return tsv.sm.Target().Shard } -// nilSession is a helper that returns an empty session. -func nilSession() *executorcontext.SafeSession { - return executorcontext.NewSafeSession(nil) +// getOptions safely extracts ExecuteOptions from a session, returning nil if session is nil. +func getOptions(session queryservice.Session) *querypb.ExecuteOptions { + if session == nil { + return nil + } + + return session.GetOptions() } diff --git a/go/vtbench/client.go b/go/vtbench/client.go index c610f4ea1db..e141dba605d 100644 --- a/go/vtbench/client.go +++ b/go/vtbench/client.go @@ -29,7 +29,6 @@ import ( "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/topo/topoproto" - "vitess.io/vitess/go/vt/vtgate/executorcontext" "vitess.io/vitess/go/vt/vtgate/vtgateconn" "vitess.io/vitess/go/vt/vttablet/queryservice" "vitess.io/vitess/go/vt/vttablet/tabletconn" @@ -162,5 +161,5 @@ func (c *grpcVttabletConn) connect(ctx context.Context, cp ConnParams) error { } func (c *grpcVttabletConn) execute(ctx context.Context, query string, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - return c.qs.Execute(ctx, executorcontext.NewSafeSession(nil), &c.target, query, bindVars, 0, 0) + return c.qs.Execute(ctx, nil, &c.target, query, bindVars, 0, 0) } From 6f706d67099ef50ff9ac46e5eab9a3e9c8f1c853 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 18:55:18 -0500 Subject: [PATCH 60/67] remove nil safe sessions Signed-off-by: Mohamed Hamza --- go/vt/vtgate/executor.go | 2 +- go/vt/vttablet/tabletmanager/rpc_query.go | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index e3ed264f44b..59bdeb838bd 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1012,7 +1012,7 @@ func (e *Executor) ShowVitessReplicationStatus(ctx context.Context, filter *sqlp replicaSQLRunningField = "Slave_SQL_Running" secondsBehindSourceField = "Seconds_Behind_Master" } - results, err := e.txConn.tabletGateway.Execute(ctx, econtext.NewSafeSession(nil), ts.Target, sql, nil, 0, 0) + results, err := e.txConn.tabletGateway.Execute(ctx, nil, ts.Target, sql, nil, 0, 0) if err != nil || results == nil { log.Warningf("Could not get replication status from %s: %v", tabletHostPort, err) } else if row := results.Named().Row(); row != nil { diff --git a/go/vt/vttablet/tabletmanager/rpc_query.go b/go/vt/vttablet/tabletmanager/rpc_query.go index 96b054ac241..3a3fdc39b2b 100644 --- a/go/vt/vttablet/tabletmanager/rpc_query.go +++ b/go/vt/vttablet/tabletmanager/rpc_query.go @@ -25,7 +25,6 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vtgate/executorcontext" querypb "vitess.io/vitess/go/vt/proto/query" tabletmanagerdatapb "vitess.io/vitess/go/vt/proto/tabletmanagerdata" @@ -294,6 +293,6 @@ func (tm *TabletManager) ExecuteQuery(ctx context.Context, req *tabletmanagerdat if err != nil { return nil, err } - result, err := tm.QueryServiceControl.QueryService().Execute(ctx, executorcontext.NewSafeSession(nil), target, uq, nil, 0, 0) + result, err := tm.QueryServiceControl.QueryService().Execute(ctx, nil, target, uq, nil, 0, 0) return sqltypes.ResultToProto3(result), err } From 53664df9db31df9aa2d6fdcfec817749d513dc67 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 19:00:40 -0500 Subject: [PATCH 61/67] simplify pick opts Signed-off-by: Mohamed Hamza --- go/vt/vtgate/tabletgateway.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index c684b2d0d41..c153aad4f4a 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -463,12 +463,12 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // Get the tablet from the balancer if enabled if useBalancer { - var sessionUUID string + pickOpts := balancer.PickOpts{} if opts.Session != nil { - sessionUUID = opts.Session.GetSessionUUID() + pickOpts.SessionUUID = opts.Session.GetSessionUUID() } - tablet := gw.balancer.Pick(target, tablets, balancer.PickOpts{SessionUUID: sessionUUID}) + tablet := gw.balancer.Pick(target, tablets, pickOpts) if tablet != nil { return tablet } From de491d13e4f9cae5f4eb763f396a2a6a67895fc2 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 19:14:47 -0500 Subject: [PATCH 62/67] simplify empty sessions Signed-off-by: Mohamed Hamza --- go/vt/vtgate/tabletgateway_flaky_test.go | 10 +- go/vt/vtgate/tabletgateway_test.go | 12 +-- go/vt/vttablet/grpctabletconn/conn_test.go | 9 +- .../vttablet/tabletconntest/tabletconntest.go | 5 +- go/vt/vttablet/tabletserver/bench_test.go | 9 +- .../vttablet/tabletserver/dt_executor_test.go | 2 +- .../tabletserver/query_executor_test.go | 8 +- .../tabletserver/tabletserver_test.go | 96 +++++++++---------- 8 files changed, 74 insertions(+), 77 deletions(-) diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index da0a7c7cd9f..f7d18965b36 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -95,7 +95,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { sbc.SetResults([]*sqltypes.Result{sqlResult1}) // run a query that we indeed get the result added to the sandbox connection back - res, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + res, err := tg.Execute(ctx, nil, target, "query", nil, 0, 0) require.NoError(t, err) require.Equal(t, res, sqlResult1) @@ -115,7 +115,7 @@ func TestGatewayBufferingWhenPrimarySwitchesServingState(t *testing.T) { // execute the query in a go routine since it should be buffered, and check that it eventually succeed queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + res, err = tg.Execute(ctx, nil, target, "query", nil, 0, 0) queryChan <- struct{}{} }() @@ -187,7 +187,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // run a query that we indeed get the result added to the sandbox connection back // this also checks that the query reaches the primary tablet and not the replica - res, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + res, err := tg.Execute(ctx, nil, target, "query", nil, 0, 0) require.NoError(t, err) require.Equal(t, res, sqlResult1) @@ -225,7 +225,7 @@ func TestGatewayBufferingWhileReparenting(t *testing.T) { // execute the query in a go routine since it should be buffered, and check that it eventually succeed queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + res, err = tg.Execute(ctx, nil, target, "query", nil, 0, 0) queryChan <- struct{}{} }() @@ -333,7 +333,7 @@ func TestInconsistentStateDetectedBuffering(t *testing.T) { var err error queryChan := make(chan struct{}) go func() { - res, err = tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + res, err = tg.Execute(ctx, nil, target, "query", nil, 0, 0) queryChan <- struct{}{} }() diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index 6b92febf83a..e370f62546c 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -44,14 +44,14 @@ import ( func TestTabletGatewayExecute(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0) + _, err := tg.Execute(ctx, nil, target, "query", nil, 0, 0) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { assert.Equal(t, want, sc.ExecCount.Load()) }) testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 1, 0) + _, err := tg.Execute(ctx, nil, target, "query", nil, 1, 0) return err }) } @@ -59,7 +59,7 @@ func TestTabletGatewayExecute(t *testing.T) { func TestTabletGatewayExecuteStream(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - err := tg.StreamExecute(ctx, &vtgatepb.Session{}, target, "query", nil, 0, 0, func(qr *sqltypes.Result) error { + err := tg.StreamExecute(ctx, nil, target, "query", nil, 0, 0, func(qr *sqltypes.Result) error { return nil }) return err @@ -72,7 +72,7 @@ func TestTabletGatewayExecuteStream(t *testing.T) { func TestTabletGatewayBegin(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, err := tg.Begin(ctx, &vtgatepb.Session{}, target) + _, err := tg.Begin(ctx, nil, target) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { @@ -99,7 +99,7 @@ func TestTabletGatewayRollback(t *testing.T) { func TestTabletGatewayBeginExecute(t *testing.T) { ctx := utils.LeakCheckContext(t) testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error { - _, _, err := tg.BeginExecute(ctx, &vtgatepb.Session{}, target, nil, "query", nil, 0) + _, _, err := tg.BeginExecute(ctx, nil, target, nil, "query", nil, 0) return err }, func(t *testing.T, sc *sandboxconn.SandboxConn, want int64) { @@ -191,7 +191,7 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) { defer tg.Close(ctx) _ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil) - _, err := tg.Execute(ctx, &vtgatepb.Session{}, target, "query", nil, 1, 0) + _, err := tg.Execute(ctx, nil, target, "query", nil, 1, 0) verifyContainsError(t, err, "query service can only be used for non-transactional queries on replicas", vtrpcpb.Code_INTERNAL) } diff --git a/go/vt/vttablet/grpctabletconn/conn_test.go b/go/vt/vttablet/grpctabletconn/conn_test.go index c7e104660d1..a7da2194d11 100644 --- a/go/vt/vttablet/grpctabletconn/conn_test.go +++ b/go/vt/vttablet/grpctabletconn/conn_test.go @@ -32,7 +32,6 @@ import ( binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" queryservicepb "vitess.io/vitess/go/vt/proto/queryservice" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttablet/grpcqueryservice" "vitess.io/vitess/go/vt/vttablet/tabletconntest" @@ -192,22 +191,22 @@ func TestGoRoutineLeakPrevention(t *testing.T) { cc: &grpc.ClientConn{}, c: mqc, } - _ = qc.StreamExecute(context.Background(), &vtgatepb.Session{}, nil, "", nil, 0, 0, func(result *sqltypes.Result) error { + _ = qc.StreamExecute(context.Background(), nil, nil, "", nil, 0, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.BeginStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { + _, _ = qc.BeginStreamExecute(context.Background(), nil, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.ReserveBeginStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, nil, "", nil, func(result *sqltypes.Result) error { + _, _ = qc.ReserveBeginStreamExecute(context.Background(), nil, nil, nil, nil, "", nil, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) - _, _ = qc.ReserveStreamExecute(context.Background(), &vtgatepb.Session{}, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { + _, _ = qc.ReserveStreamExecute(context.Background(), nil, nil, nil, "", nil, 0, func(result *sqltypes.Result) error { return nil }) require.Error(t, mqc.lastCallCtx.Err()) diff --git a/go/vt/vttablet/tabletconntest/tabletconntest.go b/go/vt/vttablet/tabletconntest/tabletconntest.go index e28aecbe47e..a10b2b42bff 100644 --- a/go/vt/vttablet/tabletconntest/tabletconntest.go +++ b/go/vt/vttablet/tabletconntest/tabletconntest.go @@ -41,7 +41,6 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -117,7 +116,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe t.Log("testBeginError") f.HasBeginError = true testErrorHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, &vtgatepb.Session{}, TestTarget) + _, err := conn.Begin(ctx, nil, TestTarget) return err }) f.HasBeginError = false @@ -126,7 +125,7 @@ func testBeginError(t *testing.T, conn queryservice.QueryService, f *FakeQuerySe func testBeginPanics(t *testing.T, conn queryservice.QueryService, f *FakeQueryService) { t.Log("testBeginPanics") testPanicHelper(t, f, "Begin", func(ctx context.Context) error { - _, err := conn.Begin(ctx, &vtgatepb.Session{}, TestTarget) + _, err := conn.Begin(ctx, nil, TestTarget) return err }) } diff --git a/go/vt/vttablet/tabletserver/bench_test.go b/go/vt/vttablet/tabletserver/bench_test.go index 9e8a2e7740b..d2cd12c853f 100644 --- a/go/vt/vttablet/tabletserver/bench_test.go +++ b/go/vt/vttablet/tabletserver/bench_test.go @@ -26,7 +26,6 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) // Benchmark run on 6/27/17, with optimized byte-level operations @@ -73,8 +72,8 @@ func BenchmarkExecuteVarBinary(b *testing.B) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.SetAllowAll(true) - for i := 0; i < b.N; i++ { - if _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, benchQuery, bv, 0, 0); err != nil { + for b.Loop() { + if _, err := tsv.Execute(ctx, nil, &target, benchQuery, bv, 0, 0); err != nil { panic(err) } } @@ -100,8 +99,8 @@ func BenchmarkExecuteExpression(b *testing.B) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.SetAllowAll(true) - for i := 0; i < b.N; i++ { - if _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, benchQuery, bv, 0, 0); err != nil { + for b.Loop() { + if _, err := tsv.Execute(ctx, nil, &target, benchQuery, bv, 0, 0); err != nil { panic(err) } } diff --git a/go/vt/vttablet/tabletserver/dt_executor_test.go b/go/vt/vttablet/tabletserver/dt_executor_test.go index b02046d1106..03b74044148 100644 --- a/go/vt/vttablet/tabletserver/dt_executor_test.go +++ b/go/vt/vttablet/tabletserver/dt_executor_test.go @@ -772,7 +772,7 @@ func newNoTwopcExecutor(t *testing.T, ctx context.Context) (txe *DTExecutor, tsv func newTxForPrep(ctx context.Context, tsv *TabletServer) int64 { txid := newTransaction(tsv, nil) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - _, err := tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set name = 2 where pk = 1", nil, txid, 0) + _, err := tsv.Execute(ctx, nil, &target, "update test_table set name = 2 where pk = 1", nil, txid, 0) if err != nil { panic(err) } diff --git a/go/vt/vttablet/tabletserver/query_executor_test.go b/go/vt/vttablet/tabletserver/query_executor_test.go index 24a60093283..2e94f0eaa33 100644 --- a/go/vt/vttablet/tabletserver/query_executor_test.go +++ b/go/vt/vttablet/tabletserver/query_executor_test.go @@ -354,7 +354,7 @@ func TestQueryExecutorPlans(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) if !tcase.outsideTxErr && tcase.errorWant != "" && !tcase.onlyInTxErr { require.EqualError(t, err, tcase.errorWant) return @@ -451,7 +451,7 @@ func TestQueryExecutorQueryAnnotation(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong alias returned by Begin") @@ -518,7 +518,7 @@ func TestQueryExecutorSelectImpossible(t *testing.T) { assert.Equal(t, tcase.planWant, qre.logStats.PlanType, tcase.input) assert.Equal(t, tcase.logWant, qre.logStats.RewrittenSQL(), tcase.input) target := tsv.sm.Target() - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong tablet alias from Begin") @@ -649,7 +649,7 @@ func TestQueryExecutorLimitFailure(t *testing.T) { // Test inside a transaction. target := tsv.sm.Target() - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) require.NoError(t, err) require.NotNil(t, state.TabletAlias, "alias should not be nil") assert.Equal(t, tsv.alias, state.TabletAlias, "Wrong tablet alias from Begin") diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index 6863a655a6c..a98c691af00 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -164,14 +164,14 @@ func TestTabletServerPrimaryToReplica(t *testing.T) { tsv.te.shutdownGracePeriod = 1 tsv.sm.shutdownGracePeriod = 1 target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state1, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state1, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state1.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, "update test_table set `name` = 2 where pk = 1", nil, state1.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state1.TransactionID, "aa") require.NoError(t, err) - state2, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state2, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) // This makes txid2 busy @@ -477,8 +477,8 @@ func TestTabletServerBeginFail(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} ctx, cancel = context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - tsv.Begin(ctx, &vtgatepb.Session{}, &target) - _, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + tsv.Begin(ctx, nil, &target) + _, err := tsv.Begin(ctx, nil, &target) require.EqualError(t, err, "transaction pool aborting request due to already expired context", "Begin err") } @@ -501,9 +501,9 @@ func TestTabletServerCommitTransaction(t *testing.T) { db.AddQuery(executeSQL, executeSQLResult) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) _, err = tsv.Commit(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -543,12 +543,12 @@ func TestTabletServerRollback(t *testing.T) { db.AddQuery(executeSQL, executeSQLResult) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) if err != nil { t.Fatalf("call TabletServer.Begin failed: %v", err) } - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) _, err = tsv.Rollback(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -561,9 +561,9 @@ func TestTabletServerPrepare(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) defer tsv.RollbackPrepared(ctx, &target, "aa", 0) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") @@ -577,9 +577,9 @@ func TestTabletServerCommitPrepared(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") require.NoError(t, err) @@ -626,12 +626,12 @@ func TestTabletServerWithNilTarget(t *testing.T) { expectedCount := tsv.stats.QueryTimingsByTabletType.Counts()[fullKey] - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, target, executeSQL, nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, target, executeSQL, nil, state.TransactionID, 0) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) @@ -641,7 +641,7 @@ func TestTabletServerWithNilTarget(t *testing.T) { expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) - state, err = tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err = tsv.Begin(ctx, nil, target) require.NoError(t, err) expectedCount++ require.Equal(t, expectedCount, tsv.stats.QueryTimingsByTabletType.Counts()[fullKey]) @@ -654,7 +654,7 @@ func TestTabletServerWithNilTarget(t *testing.T) { // Finally be sure that we return an error now as expected when NOT // using a local context but passing a nil target. nonLocalCtx := context.Background() - _, err = tsv.Begin(nonLocalCtx, &vtgatepb.Session{}, target) + _, err = tsv.Begin(nonLocalCtx, nil, target) require.True(t, errors.Is(err, ErrNoTarget)) _, err = tsv.resolveTargetType(nonLocalCtx, target) require.True(t, errors.Is(err, ErrNoTarget)) @@ -897,9 +897,9 @@ func TestTabletServerRollbackPrepared(t *testing.T) { _, tsv, _, closer := newTestTxExecutor(t, ctx) defer closer() target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + state, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) - _, err = tsv.Execute(ctx, &vtgatepb.Session{}, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) + _, err = tsv.Execute(ctx, nil, &target, "update test_table set `name` = 2 where pk = 1", nil, state.TransactionID, 0) require.NoError(t, err) err = tsv.Prepare(ctx, &target, state.TransactionID, "aa") require.NoError(t, err) @@ -927,7 +927,7 @@ func TestTabletServerStreamExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - if err := tsv.StreamExecute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, 0, 0, callback); err != nil { + if err := tsv.StreamExecute(ctx, nil, &target, executeSQL, nil, 0, 0, callback); err != nil { t.Fatalf("TabletServer.StreamExecute should success: %s, but get error: %v", executeSQL, err) } @@ -957,7 +957,7 @@ func TestTabletServerStreamExecuteComments(t *testing.T) { ch := tabletenv.StatsLogger.Subscribe("test stats logging") defer tabletenv.StatsLogger.Unsubscribe(ch) - if err := tsv.StreamExecute(ctx, &vtgatepb.Session{}, &target, executeSQL, nil, 0, 0, callback); err != nil { + if err := tsv.StreamExecute(ctx, nil, &target, executeSQL, nil, 0, 0, callback); err != nil { t.Fatalf("TabletServer.StreamExecute should success: %s, but get error: %v", executeSQL, err) } @@ -993,7 +993,7 @@ func TestTabletServerBeginStreamExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - state, err := tsv.BeginStreamExecute(ctx, &vtgatepb.Session{}, &target, nil, executeSQL, nil, 0, callback) + state, err := tsv.BeginStreamExecute(ctx, nil, &target, nil, executeSQL, nil, 0, callback) if err != nil { t.Fatalf("TabletServer.BeginStreamExecute should success: %s, but get error: %v", executeSQL, err) @@ -1026,7 +1026,7 @@ func TestTabletServerBeginStreamExecuteWithError(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} callback := func(*sqltypes.Result) error { return nil } - state, err := tsv.BeginStreamExecute(ctx, &vtgatepb.Session{}, &target, nil, executeSQL, nil, 0, callback) + state, err := tsv.BeginStreamExecute(ctx, nil, &target, nil, executeSQL, nil, 0, callback) require.Error(t, err) err = tsv.Release(ctx, &target, state.TransactionID, 0) require.NoError(t, err) @@ -1093,7 +1093,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) + state1, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1108,7 +1108,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { defer wg.Done() <-tx1Started - state2, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) + state2, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q2, bvTx2, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q2, err) } @@ -1128,7 +1128,7 @@ func TestSerializeTransactionsSameRow(t *testing.T) { defer wg.Done() <-tx1Started - state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) + state3, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -1163,7 +1163,7 @@ func TestDMLQueryWithoutWhereClause(t *testing.T) { db.AddQuery(q+" limit 10001", &sqltypes.Result{}) - state, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q, nil, 0) + state, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q, nil, 0) require.NoError(t, err) _, err = tsv.Commit(ctx, &target, state.TransactionID) require.NoError(t, err) @@ -1237,7 +1237,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) + state1, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1256,7 +1256,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // In that case, we would see less than 3 pending transactions. <-tx1Started - state2, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) + state2, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q2, bvTx2, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q2, err) } @@ -1275,7 +1275,7 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // In that case, we would see less than 3 pending transactions. <-tx1Started - state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) + state3, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -1373,7 +1373,7 @@ func TestSerializeTransactionsSameRow_TooManyPendingRequests(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) + state1, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1389,7 +1389,7 @@ func TestSerializeTransactionsSameRow_TooManyPendingRequests(t *testing.T) { defer close(tx2Failed) <-tx1Started - _, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) + _, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q2, bvTx2, 0) if err == nil || vterrors.Code(err) != vtrpcpb.Code_RESOURCE_EXHAUSTED || err.Error() != "hot row protection: too many queued transactions (1 >= 1) for the same row (table + WHERE clause: 'test_table where pk = 1 and `name` = 1')" { t.Errorf("tx2 should have failed because there are too many pending requests: %v", err) } @@ -1462,7 +1462,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { go func() { defer wg.Done() - state1, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q1, bvTx1, 0) + state1, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q1, bvTx1, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q1, err) } @@ -1482,7 +1482,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { // Wait until tx1 has started to make the test deterministic. <-tx1Started - _, _, err := tsv.BeginExecute(ctxTx2, &vtgatepb.Session{}, &target, nil, q2, bvTx2, 0) + _, _, err := tsv.BeginExecute(ctxTx2, nil, &target, nil, q2, bvTx2, 0) if err == nil || vterrors.Code(err) != vtrpcpb.Code_CANCELED || err.Error() != "context canceled" { t.Errorf("tx2 should have failed because the context was canceled: %v", err) } @@ -1499,7 +1499,7 @@ func TestSerializeTransactionsSameRow_RequestCanceled(t *testing.T) { t.Error(err) } - state3, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, q3, bvTx3, 0) + state3, _, err := tsv.BeginExecute(ctx, nil, &target, nil, q3, bvTx3, 0) if err != nil { t.Errorf("failed to execute query: %s: %s", q3, err) } @@ -2321,7 +2321,7 @@ func TestReserveBeginExecute(t *testing.T) { target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} db.AddQueryPattern("set @@sql_mode = ''", &sqltypes.Result{}) - state, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set @@sql_mode = ''", nil) + state, _, err := tsv.ReserveBeginExecute(ctx, nil, &target, nil, nil, "set @@sql_mode = ''", nil) require.NoError(t, err) assert.Greater(t, state.TransactionID, int64(0), "transactionID") @@ -2348,7 +2348,7 @@ func TestReserveExecute_WithoutTx(t *testing.T) { db.AddQueryPattern("set sql_mode = ''", &sqltypes.Result{}) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - state, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) + state, _, err := tsv.ReserveExecute(ctx, nil, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) assert.NotEqual(t, int64(0), state.ReservedID, "reservedID should not be zero") expected := []string{ @@ -2372,12 +2372,12 @@ func TestReserveExecute_WithTx(t *testing.T) { db.AddQueryPattern("set sql_mode = ''", &sqltypes.Result{}) target := querypb.Target{TabletType: topodatapb.TabletType_PRIMARY} - beginState, err := tsv.Begin(ctx, &vtgatepb.Session{}, &target) + beginState, err := tsv.Begin(ctx, nil, &target) require.NoError(t, err) require.NotEqual(t, int64(0), beginState.TransactionID) db.ResetQueryLog() - reserveState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, beginState.TransactionID) + reserveState, _, err := tsv.ReserveExecute(ctx, nil, &target, nil, "set sql_mode = ''", nil, beginState.TransactionID) require.NoError(t, err) defer tsv.Release(ctx, &target, beginState.TransactionID, reserveState.ReservedID) assert.Equal(t, beginState.TransactionID, reserveState.ReservedID, "reservedID should be equal to transactionID") @@ -2436,19 +2436,19 @@ func TestRelease(t *testing.T) { switch { case test.begin && test.reserve: - state, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set sql_mode = ''", nil) + state, _, err := tsv.ReserveBeginExecute(ctx, nil, &target, nil, nil, "set sql_mode = ''", nil) require.NoError(t, err) transactionID = state.TransactionID reservedID = state.ReservedID require.NotEqual(t, int64(0), transactionID) require.NotEqual(t, int64(0), reservedID) case test.begin: - state, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, 0) + state, _, err := tsv.BeginExecute(ctx, nil, &target, nil, "select 42", nil, 0) require.NoError(t, err) transactionID = state.TransactionID require.NotEqual(t, int64(0), transactionID) case test.reserve: - state, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) + state, _, err := tsv.ReserveExecute(ctx, nil, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) reservedID = state.ReservedID require.NotEqual(t, int64(0), reservedID) @@ -2483,27 +2483,27 @@ func TestReserveStats(t *testing.T) { ctx = callerid.NewContext(ctx, nil, callerID) // Starts reserved connection and transaction - rbeState, _, err := tsv.ReserveBeginExecute(ctx, &vtgatepb.Session{}, &target, nil, nil, "set sql_mode = ''", nil) + rbeState, _, err := tsv.ReserveBeginExecute(ctx, nil, &target, nil, nil, "set sql_mode = ''", nil) require.NoError(t, err) assert.EqualValues(t, 1, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Starts reserved connection - reState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, 0) + reState, _, err := tsv.ReserveExecute(ctx, nil, &target, nil, "set sql_mode = ''", nil, 0) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Use previous reserved connection to start transaction - reBeState, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, reState.ReservedID) + reBeState, _, err := tsv.BeginExecute(ctx, nil, &target, nil, "select 42", nil, reState.ReservedID) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Starts transaction. - beState, _, err := tsv.BeginExecute(ctx, &vtgatepb.Session{}, &target, nil, "select 42", nil, 0) + beState, _, err := tsv.BeginExecute(ctx, nil, &target, nil, "select 42", nil, 0) require.NoError(t, err) assert.EqualValues(t, 2, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) // Reserved the connection on previous transaction - beReState, _, err := tsv.ReserveExecute(ctx, &vtgatepb.Session{}, &target, nil, "set sql_mode = ''", nil, beState.TransactionID) + beReState, _, err := tsv.ReserveExecute(ctx, nil, &target, nil, "set sql_mode = ''", nil, beState.TransactionID) require.NoError(t, err) assert.EqualValues(t, 3, tsv.te.txPool.env.Stats().UserActiveReservedCount.Counts()["test"]) @@ -2549,7 +2549,7 @@ func TestDatabaseNameReplaceByKeyspaceNameExecuteMethod(t *testing.T) { target := tsv.sm.target // Testing Execute Method - state, err := tsv.Begin(ctx, &vtgatepb.Session{}, target) + state, err := tsv.Begin(ctx, nil, target) require.NoError(t, err) res, err := tsv.Execute(ctx, &vtgatepb.Session{Options: &querypb.ExecuteOptions{ IncludedFields: querypb.ExecuteOptions_ALL, From 19e11c16b6c9f7880a99bd2822028509b6845454 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Mon, 15 Dec 2025 19:21:41 -0500 Subject: [PATCH 63/67] goimports Signed-off-by: Mohamed Hamza --- go/vt/vtgate/tabletgateway_flaky_test.go | 1 - go/vt/vtgate/tabletgateway_test.go | 1 - go/vt/vttablet/tabletserver/dt_executor_test.go | 1 - 3 files changed, 3 deletions(-) diff --git a/go/vt/vtgate/tabletgateway_flaky_test.go b/go/vt/vtgate/tabletgateway_flaky_test.go index f7d18965b36..ef15d344762 100644 --- a/go/vt/vtgate/tabletgateway_flaky_test.go +++ b/go/vt/vtgate/tabletgateway_flaky_test.go @@ -32,7 +32,6 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) // TestGatewayBufferingWhenPrimarySwitchesServingState is used to test that the buffering mechanism buffers the queries when a primary goes to a non serving state and diff --git a/go/vt/vtgate/tabletgateway_test.go b/go/vt/vtgate/tabletgateway_test.go index e370f62546c..c7a7eebd98e 100644 --- a/go/vt/vtgate/tabletgateway_test.go +++ b/go/vt/vtgate/tabletgateway_test.go @@ -34,7 +34,6 @@ import ( "vitess.io/vitess/go/vt/discovery" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/topo" "vitess.io/vitess/go/vt/vterrors" diff --git a/go/vt/vttablet/tabletserver/dt_executor_test.go b/go/vt/vttablet/tabletserver/dt_executor_test.go index 03b74044148..5f24ef2a0f4 100644 --- a/go/vt/vttablet/tabletserver/dt_executor_test.go +++ b/go/vt/vttablet/tabletserver/dt_executor_test.go @@ -38,7 +38,6 @@ import ( "vitess.io/vitess/go/streamlog" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" - vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" "vitess.io/vitess/go/vt/vttablet/tabletserver/schema" From a4eb6e1b6b45c2dab41574cd22d37dfa94b48b5c Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 16 Dec 2025 10:28:52 -0500 Subject: [PATCH 64/67] Update go/vt/vtgate/balancer/balancer.go Co-authored-by: Matt Lord Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index a785c6f0382..92470c45406 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -154,7 +154,7 @@ type PickOpts struct { // - "prefer-cell": Flow-based balancer that maintains cell affinity while balancing load // - See the RFC here: https://github.com/vitessio/vitess/issues/12241 // - "random": Random balancer that uniformly distributes load without cell affinity -// - "session": Session balancer that pins a session to the same tablet for the duration of the session. +// - "session": Session balancer that pins a session to the same tablet for the duration of the session. If the tablet goes away, the session is automatically and transparently migrated to another tablet of the same type. // // Note: "cell" mode is handled by the gateway and does not create a balancer instance. // operates as a round robin inside of the vtgate's cell From f63dcd6aef23e9a0f2c02eee21e1157e7476cf6b Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Tue, 16 Dec 2025 11:49:15 -0500 Subject: [PATCH 65/67] switch pick opts to functional options Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/balancer.go | 10 +-- go/vt/vtgate/balancer/balancer_test.go | 8 +-- go/vt/vtgate/balancer/options.go | 62 +++++++++++++++++++ go/vt/vtgate/balancer/random_balancer.go | 2 +- go/vt/vtgate/balancer/random_balancer_test.go | 8 +-- go/vt/vtgate/balancer/session.go | 7 ++- go/vt/vtgate/balancer/session_test.go | 32 ++++------ go/vt/vtgate/tabletgateway.go | 28 +++++---- 8 files changed, 104 insertions(+), 53 deletions(-) create mode 100644 go/vt/vtgate/balancer/options.go diff --git a/go/vt/vtgate/balancer/balancer.go b/go/vt/vtgate/balancer/balancer.go index 92470c45406..7f5631bd959 100644 --- a/go/vt/vtgate/balancer/balancer.go +++ b/go/vt/vtgate/balancer/balancer.go @@ -137,18 +137,12 @@ func GetAvailableModeNames() []string { type TabletBalancer interface { // Pick is the main entry point to the balancer. Returns the best tablet out of the list // for a given query to maintain the desired balanced allocation over multiple executions. - Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth + Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts ...PickOption) *discovery.TabletHealth // DebugHandler provides a summary of tablet balancer state DebugHandler(w http.ResponseWriter, r *http.Request) } -// PickOpts are balancer options that are passed into Pick. -type PickOpts struct { - // SessionUUID is the the current session UUID. - SessionUUID string -} - // NewTabletBalancer creates a new tablet balancer based on the specified mode. // Supported modes: // - "prefer-cell": Flow-based balancer that maintains cell affinity while balancing load @@ -245,7 +239,7 @@ func (b *flowBalancer) DebugHandler(w http.ResponseWriter, _ *http.Request) { // Given the total allocation for the set of tablets, choose the best target // by a weighted random sample so that over time the system will achieve the // desired balanced allocation. -func (b *flowBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ PickOpts) *discovery.TabletHealth { +func (b *flowBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ ...PickOption) *discovery.TabletHealth { numTablets := len(tablets) if numTablets == 0 { return nil diff --git a/go/vt/vtgate/balancer/balancer_test.go b/go/vt/vtgate/balancer/balancer_test.go index 452dbc5b2bb..c2c732ebcda 100644 --- a/go/vt/vtgate/balancer/balancer_test.go +++ b/go/vt/vtgate/balancer/balancer_test.go @@ -298,7 +298,7 @@ func TestBalancedPick(t *testing.T) { b := newFlowBalancer(localCell, vtGateCells).(*flowBalancer) for i := 0; i < N/len(vtGateCells); i++ { - th := b.Pick(target, tablets, PickOpts{}) + th := b.Pick(target, tablets) if i == 0 { t.Logf("Target Flows %v, Balancer: %s\n", expectedPerCell, b.print()) } @@ -336,7 +336,7 @@ func TestTopologyChanged(t *testing.T) { tablets = tablets[0:2] for i := 0; i < N; i++ { - th := b.Pick(target, tablets, PickOpts{}) + th := b.Pick(target, tablets) allocation, totalAllocation := b.getAllocation(target, tablets) assert.Equalf(t, ALLOCATION/2, totalAllocation, "totalAllocation mismatch %s", b.print()) @@ -346,7 +346,7 @@ func TestTopologyChanged(t *testing.T) { // Run again with the full topology. Now traffic should go to cell b for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, PickOpts{}) + th := b.Pick(target, allTablets) allocation, totalAllocation := b.getAllocation(target, allTablets) @@ -359,7 +359,7 @@ func TestTopologyChanged(t *testing.T) { newTablet := createTestTablet("b") allTablets[2] = newTablet for i := 0; i < N; i++ { - th := b.Pick(target, allTablets, PickOpts{}) + th := b.Pick(target, allTablets) allocation, totalAllocation := b.getAllocation(target, allTablets) diff --git a/go/vt/vtgate/balancer/options.go b/go/vt/vtgate/balancer/options.go new file mode 100644 index 00000000000..a9313a3cb0b --- /dev/null +++ b/go/vt/vtgate/balancer/options.go @@ -0,0 +1,62 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package balancer + +// pickOptions configure a Pick call. pickOptions are set by the PickOption +// values passed to the Pick function. +type pickOptions struct { + sessionUUID string +} + +// PickOption configures how we perform the pick operation. +type PickOption interface { + apply(*pickOptions) +} + +// funcPickOption wraps a function that modifies pickOptions into an +// implementation of the PickOption interface. +type funcPickOption struct { + f func(*pickOptions) +} + +func (fpo *funcPickOption) apply(po *pickOptions) { + fpo.f(po) +} + +func newFuncPickOption(f func(*pickOptions)) *funcPickOption { + return &funcPickOption{ + f: f, + } +} + +// WithSessionUUID allows you to specify the session UUID for the session balancer. +func WithSessionUUID(sessionUUID string) PickOption { + return newFuncPickOption(func(o *pickOptions) { + o.sessionUUID = sessionUUID + }) +} + +// getOptions applies the given options to a new pickOptions struct +// and returns it. +func getOptions(opts []PickOption) *pickOptions { + options := &pickOptions{} + for _, opt := range opts { + opt.apply(options) + } + + return options +} diff --git a/go/vt/vtgate/balancer/random_balancer.go b/go/vt/vtgate/balancer/random_balancer.go index 825dfe4dc84..52c28792af7 100644 --- a/go/vt/vtgate/balancer/random_balancer.go +++ b/go/vt/vtgate/balancer/random_balancer.go @@ -73,7 +73,7 @@ type randomBalancer struct { // Pick returns a random tablet from the list with uniform probability (1/N). // If vtGateCells is configured, only tablets in those cells are considered. -func (b *randomBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ PickOpts) *discovery.TabletHealth { +func (b *randomBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, _ ...PickOption) *discovery.TabletHealth { if len(tablets) == 0 { return nil } diff --git a/go/vt/vtgate/balancer/random_balancer_test.go b/go/vt/vtgate/balancer/random_balancer_test.go index 5bd74330405..fce3306507f 100644 --- a/go/vt/vtgate/balancer/random_balancer_test.go +++ b/go/vt/vtgate/balancer/random_balancer_test.go @@ -47,7 +47,7 @@ func TestRandomBalancerUniformDistribution(t *testing.T) { pickCounts := make(map[uint32]int) for i := 0; i < numPicks; i++ { - th := b.Pick(target, tablets, PickOpts{}) + th := b.Pick(target, tablets) require.NotNil(t, th, "Pick should not return nil") pickCounts[th.Tablet.Alias.Uid]++ } @@ -66,7 +66,7 @@ func TestRandomBalancerPickEmpty(t *testing.T) { target := &querypb.Target{Keyspace: "k", Shard: "s", TabletType: topodatapb.TabletType_REPLICA} b := newRandomBalancer("cell1", []string{}) - th := b.Pick(target, []*discovery.TabletHealth{}, PickOpts{}) + th := b.Pick(target, []*discovery.TabletHealth{}) assert.Nil(t, th, "Pick should return nil for empty tablet list") } @@ -80,7 +80,7 @@ func TestRandomBalancerPickSingle(t *testing.T) { // Pick multiple times, should always return the same tablet for i := 0; i < 100; i++ { - th := b.Pick(target, tablets, PickOpts{}) + th := b.Pick(target, tablets) require.NotNil(t, th, "Pick should not return nil") assert.Equal(t, tablets[0].Tablet.Alias.Uid, th.Tablet.Alias.Uid, "Pick should return the only available tablet") @@ -133,7 +133,7 @@ func TestRandomBalancerCellFiltering(t *testing.T) { pickCounts := make(map[string]int) for i := 0; i < numPicks; i++ { - th := b.Pick(target, tablets, PickOpts{}) + th := b.Pick(target, tablets) require.NotNil(t, th) pickCounts[th.Tablet.Alias.Cell]++ } diff --git a/go/vt/vtgate/balancer/session.go b/go/vt/vtgate/balancer/session.go index ab4c0f9312c..7f70bc6dcac 100644 --- a/go/vt/vtgate/balancer/session.go +++ b/go/vt/vtgate/balancer/session.go @@ -44,8 +44,9 @@ func newSessionBalancer(localCell string) TabletBalancer { // // For a given session, it will return the same tablet for its duration, with preference to tablets // in the local cell. -func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts PickOpts) *discovery.TabletHealth { - if opts.SessionUUID == "" { +func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.TabletHealth, opts ...PickOption) *discovery.TabletHealth { + options := getOptions(opts) + if options.sessionUUID == "" { return nil } @@ -55,7 +56,7 @@ func (b *SessionBalancer) Pick(target *querypb.Target, tablets []*discovery.Tabl for _, tablet := range tablets { alias := tabletAlias(tablet) - weight := tabletWeight(alias, opts.SessionUUID) + weight := tabletWeight(alias, options.sessionUUID) if b.isLocal(tablet) && ((maxLocalTablet == nil) || (weight > maxLocalWeight)) { maxLocalWeight = weight diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index 4d78cb64154..cbf9b7e3c35 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -48,8 +48,7 @@ func TestPickNoTablets(t *testing.T) { Cell: "local", } - opts := buildOpts("a") - result := b.Pick(target, nil, opts) + result := b.Pick(target, nil, WithSessionUUID("a")) require.Nil(t, result) } @@ -102,17 +101,15 @@ func TestPickLocalOnly(t *testing.T) { } // Pick for a specific session UUID - opts := buildOpts("a") - picked1 := b.Pick(target, tablets, opts) + picked1 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, picked1) // Pick again with same session hash, should return same tablet - picked2 := b.Pick(target, tablets, opts) + picked2 := b.Pick(target, tablets, WithSessionUUID("a")) require.Equal(t, picked1, picked2, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked2))) // Pick with different session hash, empirically know that it should return tablet2 - opts = buildOpts("b") - picked3 := b.Pick(target, tablets, opts) + picked3 := b.Pick(target, tablets, WithSessionUUID("b")) require.NotNil(t, picked3) require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s for both", tabletAlias(picked3))) } @@ -184,8 +181,7 @@ func TestPickPreferLocal(t *testing.T) { } // Pick should prefer local cell - opts := buildOpts("a") - picked1 := b.Pick(target, tablets, opts) + picked1 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) } @@ -238,8 +234,7 @@ func TestPickNoLocal(t *testing.T) { } // Pick should return external cell since there are no local cells - opts := buildOpts("a") - picked1 := b.Pick(target, tablets, opts) + picked1 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) } @@ -274,8 +269,8 @@ func TestPickNoOpts(t *testing.T) { }, } - // Test with empty opts - result := b.Pick(target, tablets, PickOpts{}) + // Test with no opts (no session UUID) + result := b.Pick(target, tablets) require.Nil(t, result) } @@ -328,8 +323,7 @@ func TestPickInvalidTablets(t *testing.T) { } // Get a tablet regularly - opts := buildOpts("a") - tablet := b.Pick(target, tablets, opts) + tablet := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, tablet) // Filter out the returned tablet as invalid @@ -338,15 +332,11 @@ func TestPickInvalidTablets(t *testing.T) { }) // Pick should now return a different tablet - tablet2 := b.Pick(target, tablets, opts) + tablet2 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, tablet2) require.NotEqual(t, tablet, tablet2) // Filter out the last tablet, Pick should return nothing - tablet3 := b.Pick(target, []*discovery.TabletHealth{}, opts) + tablet3 := b.Pick(target, []*discovery.TabletHealth{}, WithSessionUUID("a")) require.Nil(t, tablet3) } - -func buildOpts(uuid string) PickOpts { - return PickOpts{SessionUUID: uuid} -} diff --git a/go/vt/vtgate/tabletgateway.go b/go/vt/vtgate/tabletgateway.go index c153aad4f4a..bfe9508691c 100644 --- a/go/vt/vtgate/tabletgateway.go +++ b/go/vt/vtgate/tabletgateway.go @@ -109,6 +109,9 @@ type TabletGateway struct { // balancer used for routing to tablets balancer balancer.TabletBalancer + + // balancerMode is the current tablet balancer mode. + balancerMode balancer.Mode } func createHealthCheck(ctx context.Context, retryDelay, timeout time.Duration, ts *topo.Server, cell, cellsToWatch string) discovery.HealthCheck { @@ -182,38 +185,37 @@ func (gw *TabletGateway) setupBalancer() { } // Determine the effective mode: new flag takes precedence, then deprecated flag, then default - var mode balancer.Mode if balancerModeFlag != "" { // Explicit new flag - mode = balancer.ParseMode(balancerModeFlag) + gw.balancerMode = balancer.ParseMode(balancerModeFlag) } else if balancerEnabled { // Deprecated flag for backwards compatibility log.Warning("Flag --enable-balancer is deprecated. Please use --vtgate-balancer-mode=prefer-cell instead.") - mode = balancer.ModePreferCell + gw.balancerMode = balancer.ModePreferCell } else { // Default: no flags set - mode = balancer.ModeCell + gw.balancerMode = balancer.ModeCell } // Cell mode uses the default shuffleTablets behavior, no balancer needed - if mode == balancer.ModeCell { + if gw.balancerMode == balancer.ModeCell { log.Info("Tablet balancer using 'cell' mode (shuffle tablets in local cell)") return } // Validate mode-specific requirements - if mode == balancer.ModePreferCell && len(balancerVtgateCells) == 0 { + if gw.balancerMode == balancer.ModePreferCell && len(balancerVtgateCells) == 0 { log.Exitf("--balancer-vtgate-cells is required when using --vtgate-balancer-mode=prefer-cell") } // Create the balancer for prefer-cell or random modes var err error - gw.balancer, err = balancer.NewTabletBalancer(mode, gw.localCell, balancerVtgateCells) + gw.balancer, err = balancer.NewTabletBalancer(gw.balancerMode, gw.localCell, balancerVtgateCells) if err != nil { log.Exitf("Failed to create tablet balancer: %v", err) } - log.Infof("Tablet balancer enabled with mode: %s", mode) + log.Infof("Tablet balancer enabled with mode: %s", gw.balancerMode) } // QueryServiceByAlias satisfies the Gateway interface @@ -463,12 +465,14 @@ func (gw *TabletGateway) getBalancerTablet(target *querypb.Target, tablets []*di // Get the tablet from the balancer if enabled if useBalancer { - pickOpts := balancer.PickOpts{} - if opts.Session != nil { - pickOpts.SessionUUID = opts.Session.GetSessionUUID() + var pickOpts []balancer.PickOption + + // Add the session UUID to the options if the session balancer is enabled and a session is present. + if gw.balancerMode == balancer.ModeSession && opts.Session != nil { + pickOpts = append(pickOpts, balancer.WithSessionUUID(opts.Session.GetSessionUUID())) } - tablet := gw.balancer.Pick(target, tablets, pickOpts) + tablet := gw.balancer.Pick(target, tablets, pickOpts...) if tablet != nil { return tablet } From a3b213a8c70cef5687fc1eb01bfd89398d2dfe26 Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 17 Dec 2025 08:47:07 -0500 Subject: [PATCH 66/67] Update options.go Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/options.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/balancer/options.go b/go/vt/vtgate/balancer/options.go index a9313a3cb0b..9ca065b7285 100644 --- a/go/vt/vtgate/balancer/options.go +++ b/go/vt/vtgate/balancer/options.go @@ -1,5 +1,5 @@ /* -Copyright 2024 The Vitess Authors. +Copyright 2025 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 5d6d428cfd565dca2ccf122ce91e9438f722359f Mon Sep 17 00:00:00 2001 From: Mohamed Hamza Date: Wed, 17 Dec 2025 13:42:56 -0500 Subject: [PATCH 67/67] Test connection stickiness Signed-off-by: Mohamed Hamza --- go/vt/vtgate/balancer/session_test.go | 31 +++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/go/vt/vtgate/balancer/session_test.go b/go/vt/vtgate/balancer/session_test.go index cbf9b7e3c35..3b127221bcf 100644 --- a/go/vt/vtgate/balancer/session_test.go +++ b/go/vt/vtgate/balancer/session_test.go @@ -29,16 +29,13 @@ import ( "vitess.io/vitess/go/vt/topo/topoproto" ) -func createSessionBalancer(t *testing.T) *SessionBalancer { +func createSessionBalancer(t *testing.T) TabletBalancer { t.Helper() - b := newSessionBalancer("local") - sb := b.(*SessionBalancer) - - return sb + return newSessionBalancer("local") } -func TestPickNoTablets(t *testing.T) { +func TestSessionPickNoTablets(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{ @@ -52,7 +49,7 @@ func TestPickNoTablets(t *testing.T) { require.Nil(t, result) } -func TestPickLocalOnly(t *testing.T) { +func TestSessionPickLocalOnly(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{ @@ -114,7 +111,7 @@ func TestPickLocalOnly(t *testing.T) { require.NotEqual(t, picked2, picked3, fmt.Sprintf("expected different tablets, got %s for both", tabletAlias(picked3))) } -func TestPickPreferLocal(t *testing.T) { +func TestSessionPickPreferLocal(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{ @@ -184,9 +181,15 @@ func TestPickPreferLocal(t *testing.T) { picked1 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, picked1) require.Equal(t, "local", picked1.Target.Cell) + + // Pick should pick the same tablet consistently + for range 20 { + picked := b.Pick(target, tablets, WithSessionUUID("a")) + require.Equal(t, picked1, picked, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked))) + } } -func TestPickNoLocal(t *testing.T) { +func TestSessionPickNoLocal(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{ @@ -237,9 +240,15 @@ func TestPickNoLocal(t *testing.T) { picked1 := b.Pick(target, tablets, WithSessionUUID("a")) require.NotNil(t, picked1) require.Equal(t, "external", picked1.Target.Cell) + + // Pick should pick the same tablet consistently + for range 20 { + picked := b.Pick(target, tablets, WithSessionUUID("a")) + require.Equal(t, picked1, picked, fmt.Sprintf("expected %s, got %s", tabletAlias(picked1), tabletAlias(picked))) + } } -func TestPickNoOpts(t *testing.T) { +func TestSessionPickNoOpts(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{ @@ -274,7 +283,7 @@ func TestPickNoOpts(t *testing.T) { require.Nil(t, result) } -func TestPickInvalidTablets(t *testing.T) { +func TestSessionPickInvalidTablets(t *testing.T) { b := createSessionBalancer(t) target := &querypb.Target{