Skip to content

Commit c6e8a79

Browse files
authored
fix: enbuged for Pub/Sub (#222)
1 parent c72a3d3 commit c6e8a79

File tree

2 files changed

+102
-60
lines changed

2 files changed

+102
-60
lines changed

lib/redis_client/cluster/pub_sub.rb

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,45 @@
33
class RedisClient
44
class Cluster
55
class PubSub
6-
MAX_THREADS = Integer(ENV.fetch('REDIS_CLIENT_MAX_THREADS', 5))
6+
class State
7+
def initialize(client)
8+
@client = client
9+
@worker = nil
10+
end
11+
12+
def call(command)
13+
@client.call_v(command)
14+
end
15+
16+
def close
17+
@worker.exit if @worker&.alive?
18+
@client.close
19+
end
20+
21+
def take_message(timeout)
22+
@worker = subscribe(@client, timeout) if @worker.nil?
23+
return if @worker.join(0.01).nil?
24+
25+
message = @worker[:reply]
26+
@worker = nil
27+
message
28+
end
29+
30+
private
31+
32+
def subscribe(client, timeout)
33+
Thread.new(client, timeout) do |pubsub, to|
34+
Thread.current[:reply] = pubsub.next_event(to)
35+
rescue StandardError => e
36+
Thread.current[:reply] = e
37+
end
38+
end
39+
end
740

841
def initialize(router, command_builder)
942
@router = router
1043
@command_builder = command_builder
11-
@pubsub_states = {}
12-
@messages = []
44+
@states = {}
1345
end
1446

1547
def call(*args, **kwargs)
@@ -21,46 +53,39 @@ def call_v(command)
2153
end
2254

2355
def close
24-
@pubsub_states.each_value(&:close)
25-
@pubsub_states.clear
26-
@messages.clear
56+
@states.each_value(&:close)
57+
@states.clear
2758
end
2859

2960
def next_event(timeout = nil)
30-
return if @pubsub_states.empty?
31-
return @messages.shift unless @messages.empty?
61+
return if @states.empty?
3262

33-
collect_messages(timeout)
34-
@messages.shift
63+
max_duration = calc_max_duration(timeout)
64+
starting = obtain_current_time
65+
loop do
66+
break if max_duration > 0 && obtain_current_time - starting > max_duration
67+
68+
@states.each_value do |pubsub|
69+
message = pubsub.take_message(timeout)
70+
return message if message
71+
end
72+
end
3573
end
3674

3775
private
3876

3977
def _call(command)
4078
node_key = @router.find_node_key(command)
41-
pubsub = if @pubsub_states.key?(node_key)
42-
@pubsub_states[node_key]
43-
else
44-
@pubsub_states[node_key] = @router.find_node(node_key).pubsub
45-
end
46-
pubsub.call_v(command)
79+
@states[node_key] = State.new(@router.find_node(node_key).pubsub) unless @states.key?(node_key)
80+
@states[node_key].call(command)
4781
end
4882

49-
def collect_messages(timeout)
50-
@pubsub_states.each_slice(MAX_THREADS) do |chuncked_pubsub_states|
51-
threads = chuncked_pubsub_states.map do |_, v|
52-
Thread.new(v) do |pubsub|
53-
Thread.current[:reply] = pubsub.next_event(timeout)
54-
rescue StandardError => e
55-
Thread.current[:reply] = e
56-
end
57-
end
83+
def obtain_current_time
84+
Process.clock_gettime(Process::CLOCK_MONOTONIC, :microsecond)
85+
end
5886

59-
threads.each do |t|
60-
t.join
61-
@messages << t[:reply] unless t[:reply].nil?
62-
end
63-
end
87+
def calc_max_duration(timeout)
88+
timeout.nil? || timeout < 0 ? 0 : timeout * 1_000_000
6489
end
6590
end
6691
end

test/redis_client/test_cluster.rb

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,28 @@ def test_global_pubsub
194194
assert_equal(['subscribe', channel, 1], pubsub.next_event(TEST_TIMEOUT_SEC))
195195
Fiber.yield(channel)
196196
Fiber.yield(pubsub.next_event(TEST_TIMEOUT_SEC))
197-
pubsub.call('UNSUBSCRIBE')
198197
pubsub.close
199198
end
200199

201200
channel = sub.resume(@client.pubsub)
202-
publish_messages do |cli|
203-
cli.call('PUBLISH', channel, 'hello global world')
201+
publish_messages { |cli| cli.call('PUBLISH', channel, 'hello global world') }
202+
assert_equal(['message', channel, 'hello global world'], sub.resume)
203+
end
204+
205+
def test_global_pubsub_without_timeout
206+
sub = Fiber.new do |pubsub|
207+
pubsub.call('SUBSCRIBE', 'my-global-not-published-channel', 'my-global-published-channel')
208+
want = [%w[subscribe my-global-not-published-channel], %w[subscribe my-global-published-channel]]
209+
got = collect_messages(pubsub, size: 2, timeout: nil).map { |e| e.take(2) }.sort_by { |e| e[1].to_s }
210+
assert_equal(want, got)
211+
Fiber.yield('my-global-published-channel')
212+
Fiber.yield(collect_messages(pubsub, size: 1, timeout: nil).first)
213+
pubsub.close
204214
end
205215

206-
assert_equal(['message', channel, 'hello global world'], sub.resume)
216+
channel = sub.resume(@client.pubsub)
217+
publish_messages { |cli| cli.call('PUBLISH', channel, 'hello global published world') }
218+
assert_equal(['message', channel, 'hello global published world'], sub.resume)
207219
end
208220

209221
def test_global_pubsub_with_multiple_channels
@@ -214,19 +226,15 @@ def test_global_pubsub_with_multiple_channels
214226

215227
sub = Fiber.new do |pubsub|
216228
pubsub.call('SUBSCRIBE', *Array.new(10) { |i| "g-chan#{i}" })
217-
got = collect_messages(pubsub).sort_by { |e| e[1].to_s }
229+
got = collect_messages(pubsub, size: 10).sort_by { |e| e[1].to_s }
218230
10.times { |i| assert_equal(['subscribe', "g-chan#{i}", i + 1], got[i]) }
219231
Fiber.yield
220-
Fiber.yield(collect_messages(pubsub))
221-
pubsub.call('UNSUBSCRIBE')
232+
Fiber.yield(collect_messages(pubsub, size: 10))
222233
pubsub.close
223234
end
224235

225236
sub.resume(@client.pubsub)
226-
publish_messages do |cli|
227-
cli.pipelined { |pi| 10.times { |i| pi.call('PUBLISH', "g-chan#{i}", i) } }
228-
end
229-
237+
publish_messages { |cli| cli.pipelined { |pi| 10.times { |i| pi.call('PUBLISH', "g-chan#{i}", i) } } }
230238
got = sub.resume.sort_by { |e| e[1].to_s }
231239
10.times { |i| assert_equal(['message', "g-chan#{i}", i.to_s], got[i]) }
232240
end
@@ -243,16 +251,34 @@ def test_sharded_pubsub
243251
assert_equal(['ssubscribe', channel, 1], pubsub.next_event(TEST_TIMEOUT_SEC))
244252
Fiber.yield(channel)
245253
Fiber.yield(pubsub.next_event(TEST_TIMEOUT_SEC))
246-
pubsub.call('SUNSUBSCRIBE')
247254
pubsub.close
248255
end
249256

250257
channel = sub.resume(@client.pubsub)
251-
publish_messages do |cli|
252-
cli.call('SPUBLISH', channel, 'hello sharded world')
258+
publish_messages { |cli| cli.call('SPUBLISH', channel, 'hello sharded world') }
259+
assert_equal(['smessage', channel, 'hello sharded world'], sub.resume)
260+
end
261+
262+
def test_sharded_pubsub_without_timeout
263+
if TEST_REDIS_MAJOR_VERSION < 7
264+
skip('Sharded Pub/Sub is supported by Redis 7+.')
265+
return
253266
end
254267

255-
assert_equal(['smessage', channel, 'hello sharded world'], sub.resume)
268+
sub = Fiber.new do |pubsub|
269+
pubsub.call('SSUBSCRIBE', 'my-sharded-not-published-channel')
270+
pubsub.call('SSUBSCRIBE', 'my-sharded-published-channel')
271+
want = [%w[ssubscribe my-sharded-not-published-channel], %w[ssubscribe my-sharded-published-channel]]
272+
got = collect_messages(pubsub, size: 2, timeout: nil).map { |e| e.take(2) }.sort_by { |e| e[1].to_s }
273+
assert_equal(want, got)
274+
Fiber.yield('my-sharded-published-channel')
275+
Fiber.yield(collect_messages(pubsub, size: 1, timeout: nil).first)
276+
pubsub.close
277+
end
278+
279+
channel = sub.resume(@client.pubsub)
280+
publish_messages { |cli| cli.call('SPUBLISH', channel, 'hello sharded published world') }
281+
assert_equal(['smessage', channel, 'hello sharded published world'], sub.resume)
256282
end
257283

258284
def test_sharded_pubsub_with_multiple_channels
@@ -268,19 +294,15 @@ def test_sharded_pubsub_with_multiple_channels
268294

269295
sub = Fiber.new do |pubsub|
270296
10.times { |i| pubsub.call('SSUBSCRIBE', "s-chan#{i}") }
271-
got = collect_messages(pubsub).sort_by { |e| e[1].to_s }
297+
got = collect_messages(pubsub, size: 10).sort_by { |e| e[1].to_s }
272298
10.times { |i| assert_equal(['ssubscribe', "s-chan#{i}"], got[i].take(2)) }
273299
Fiber.yield
274-
Fiber.yield(collect_messages(pubsub))
275-
pubsub.call('SUNSUBSCRIBE')
300+
Fiber.yield(collect_messages(pubsub, size: 10))
276301
pubsub.close
277302
end
278303

279304
sub.resume(@client.pubsub)
280-
publish_messages do |cli|
281-
cli.pipelined { |pi| 10.times { |i| pi.call('SPUBLISH', "s-chan#{i}", i) } }
282-
end
283-
305+
publish_messages { |cli| cli.pipelined { |pi| 10.times { |i| pi.call('SPUBLISH', "s-chan#{i}", i) } } }
284306
got = sub.resume.sort_by { |e| e[1].to_s }
285307
10.times { |i| assert_equal(['smessage', "s-chan#{i}", i.to_s], got[i]) }
286308
end
@@ -386,7 +408,7 @@ def wait_for_replication
386408
@client&.blocking_call(client_side_timeout, 'WAIT', TEST_REPLICA_SIZE, server_side_timeout)
387409
end
388410

389-
def collect_messages(pubsub, max_attempts: 30, timeout: 1.0)
411+
def collect_messages(pubsub, size:, max_attempts: 30, timeout: 1.0)
390412
messages = []
391413
attempts = 0
392414
loop do
@@ -396,14 +418,9 @@ def collect_messages(pubsub, max_attempts: 30, timeout: 1.0)
396418
reply = pubsub.next_event(timeout)
397419
break if reply.nil?
398420

399-
if reply.first.is_a?(Array)
400-
messages += reply
401-
else
402-
messages << reply
403-
end
421+
messages << reply
422+
break messages if messages.size == size
404423
end
405-
406-
messages
407424
end
408425

409426
def publish_messages

0 commit comments

Comments
 (0)