Skip to content

Commit 1349de3

Browse files
committed
feat: expose shard information in redis.Ring
- Add GetShards() method to retrieve a list of active shard clients. - Add GetShardByKey(key string) method to get the shard client for a specific key. - These methods enable users to manage Pub/Sub operations more effectively by accessing shard-specific clients.
1 parent d7ba255 commit 1349de3

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

ring.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,3 +847,22 @@ func (c *Ring) Close() error {
847847

848848
return c.sharding.Close()
849849
}
850+
851+
func (c *Ring) GetShards() []*Client {
852+
shards := c.sharding.List()
853+
clients := make([]*Client, 0, len(shards))
854+
for _, shard := range shards {
855+
if shard.IsUp() {
856+
clients = append(clients, shard.Client)
857+
}
858+
}
859+
return clients
860+
}
861+
862+
func (c *Ring) GetShardByKey(key string) (*Client, error) {
863+
shard, err := c.sharding.GetByKey(key)
864+
if err != nil {
865+
return nil, err
866+
}
867+
return shard.Client, nil
868+
}

ring_test.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,3 +766,74 @@ var _ = Describe("Ring Tx timeout", func() {
766766
testTimeout()
767767
})
768768
})
769+
770+
var _ = Describe("Ring GetShards and GetShardByKey", func() {
771+
var ring *redis.Ring
772+
773+
BeforeEach(func() {
774+
ring = redis.NewRing(&redis.RingOptions{
775+
Addrs: map[string]string{
776+
"shard1": ":6379",
777+
"shard2": ":6380",
778+
},
779+
})
780+
})
781+
782+
AfterEach(func() {
783+
Expect(ring.Close()).NotTo(HaveOccurred())
784+
})
785+
786+
It("GetShards returns active shard clients", func() {
787+
shards := ring.GetShards()
788+
if len(shards) == 0 {
789+
// Expected if Redis servers are not running
790+
Skip("No active shards found (Redis servers not running)")
791+
} else {
792+
Expect(len(shards)).To(BeNumerically(">", 0))
793+
for _, client := range shards {
794+
Expect(client).NotTo(BeNil())
795+
}
796+
}
797+
})
798+
799+
It("GetShardByKey returns correct shard for keys", func() {
800+
testKeys := []string{"key1", "key2", "user:123", "channel:test"}
801+
802+
for _, key := range testKeys {
803+
client, err := ring.GetShardByKey(key)
804+
Expect(err).NotTo(HaveOccurred())
805+
Expect(client).NotTo(BeNil())
806+
}
807+
})
808+
809+
It("GetShardByKey is consistent for same key", func() {
810+
key := "test:consistency"
811+
812+
var firstClient *redis.Client
813+
for i := 0; i < 5; i++ {
814+
client, err := ring.GetShardByKey(key)
815+
Expect(err).NotTo(HaveOccurred())
816+
Expect(client).NotTo(BeNil())
817+
818+
if i == 0 {
819+
firstClient = client
820+
} else {
821+
Expect(client.String()).To(Equal(firstClient.String()))
822+
}
823+
}
824+
})
825+
826+
It("GetShardByKey distributes keys across shards", func() {
827+
testKeys := []string{"key1", "key2", "key3", "key4", "key5"}
828+
shardMap := make(map[string]int)
829+
830+
for _, key := range testKeys {
831+
client, err := ring.GetShardByKey(key)
832+
Expect(err).NotTo(HaveOccurred())
833+
shardMap[client.String()]++
834+
}
835+
836+
Expect(len(shardMap)).To(BeNumerically(">=", 1))
837+
Expect(len(shardMap)).To(BeNumerically("<=", 2)) // At most 2 shards (our setup)
838+
})
839+
})

0 commit comments

Comments
 (0)