Skip to content

Commit a8d509a

Browse files
authored
Merge pull request #2242 from karthik2804/allow_multiple_redis_server
allow redis trigger to connect to multiple servers
2 parents aeb51e8 + 473383a commit a8d509a

File tree

4 files changed

+85
-48
lines changed

4 files changed

+85
-48
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/trigger-redis/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ spin-trigger = { path = "../trigger" }
1919
spin-world = { path = "../world" }
2020
redis = { version = "0.21", features = ["tokio-comp"] }
2121
tracing = { workspace = true }
22+
tokio = { version = "1.23", features = ["full"] }
2223

2324
[dev-dependencies]
2425
spin-testing = { path = "../testing" }

crates/trigger-redis/src/lib.rs

Lines changed: 81 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,27 @@
22
33
mod spin;
44

5-
use std::collections::HashMap;
6-
75
use anyhow::{anyhow, Context, Result};
86
use futures::{future::join_all, StreamExt};
97
use redis::{Client, ConnectionLike};
108
use serde::{de::IgnoredAny, Deserialize, Serialize};
119
use spin_core::async_trait;
1210
use spin_trigger::{cli::NoArgs, TriggerAppEngine, TriggerExecutor};
11+
use std::collections::HashMap;
12+
use std::sync::Arc;
1313

1414
use crate::spin::SpinRedisExecutor;
1515

1616
pub(crate) type RuntimeData = ();
1717
pub(crate) type Store = spin_core::Store<RuntimeData>;
1818

19+
type ChannelComponents = HashMap<String, Vec<String>>;
1920
/// The Spin Redis trigger.
21+
#[derive(Clone)]
2022
pub struct RedisTrigger {
21-
engine: TriggerAppEngine<Self>,
22-
// Redis address to connect to
23-
address: String,
24-
// Mapping of subscription channels to component IDs
25-
channel_components: HashMap<String, Vec<String>>,
23+
engine: Arc<TriggerAppEngine<Self>>,
24+
// Mapping of server url with subscription channel and associated component IDs
25+
server_channels: HashMap<String, ChannelComponents>,
2626
}
2727

2828
/// Redis trigger configuration.
@@ -33,6 +33,8 @@ pub struct RedisTriggerConfig {
3333
pub component: String,
3434
/// Channel to subscribe to
3535
pub channel: String,
36+
/// optional overide address for trigger
37+
pub address: Option<String>,
3638
/// Trigger executor (currently unused)
3739
#[serde(default, skip_serializing)]
3840
pub executor: IgnoredAny,
@@ -52,32 +54,90 @@ impl TriggerExecutor for RedisTrigger {
5254
type RunConfig = NoArgs;
5355

5456
async fn new(engine: TriggerAppEngine<Self>) -> Result<Self> {
55-
let address = engine
57+
let default_address: String = engine
5658
.trigger_metadata::<TriggerMetadata>()?
5759
.unwrap_or_default()
5860
.address;
59-
let address_expr = spin_expressions::Template::new(address)?;
60-
let address = engine.resolve_template(&address_expr)?;
61+
let default_address_expr = spin_expressions::Template::new(default_address)?;
62+
let default_address = engine.resolve_template(&default_address_expr)?;
6163

62-
let mut channel_components: HashMap<String, Vec<String>> = HashMap::new();
64+
let mut server_channels: HashMap<String, ChannelComponents> = HashMap::new();
6365

6466
for (_, config) in engine.trigger_configs() {
65-
channel_components
67+
let address = config.address.clone().unwrap_or(default_address.clone());
68+
let address_expr = spin_expressions::Template::new(address)?;
69+
let address = engine.resolve_template(&address_expr)?;
70+
let server = server_channels.entry(address).or_default();
71+
server
6672
.entry(config.channel.clone())
6773
.or_default()
6874
.push(config.component.clone());
6975
}
7076
Ok(Self {
71-
engine,
72-
address,
73-
channel_components,
77+
engine: Arc::new(engine),
78+
server_channels,
7479
})
7580
}
7681

7782
/// Run the Redis trigger indefinitely.
7883
async fn run(self, _config: Self::RunConfig) -> Result<()> {
79-
let address = &self.address;
84+
let tasks: Vec<_> = self
85+
.server_channels
86+
.clone()
87+
.into_iter()
88+
.map(|(server_address, channel_components)| {
89+
let trigger = self.clone();
90+
tokio::spawn(async move {
91+
trigger
92+
.run_listener(server_address.clone(), channel_components.clone())
93+
.await
94+
})
95+
})
96+
.collect();
97+
98+
// wait for the first handle to be returned and drop the rest
99+
let (_, _, rest) = futures::future::select_all(tasks).await;
100+
drop(rest);
101+
102+
Ok(())
103+
}
104+
}
80105

106+
impl RedisTrigger {
107+
// Handle the message.
108+
async fn handle(
109+
&self,
110+
address: &str,
111+
channel_components: &ChannelComponents,
112+
msg: redis::Msg,
113+
) -> Result<()> {
114+
let channel = msg.get_channel_name();
115+
tracing::info!("Received message on channel {address}:{:?}", channel);
116+
117+
if let Some(component_ids) = channel_components.get(channel) {
118+
let futures = component_ids.iter().map(|id| {
119+
tracing::trace!("Executing Redis component {id:?}");
120+
SpinRedisExecutor.execute(&self.engine, id, channel, msg.get_payload_bytes())
121+
});
122+
let results: Vec<_> = join_all(futures).await.into_iter().collect();
123+
let errors = results
124+
.into_iter()
125+
.filter_map(|r| r.err())
126+
.collect::<Vec<_>>();
127+
if !errors.is_empty() {
128+
return Err(anyhow!("{errors:#?}"));
129+
}
130+
} else {
131+
tracing::debug!("No subscription found for {:?}", channel);
132+
}
133+
Ok(())
134+
}
135+
136+
async fn run_listener(
137+
&self,
138+
address: String,
139+
channel_components: ChannelComponents,
140+
) -> Result<()> {
81141
tracing::info!("Connecting to Redis server at {}", address);
82142
let mut client = Client::open(address.to_string())?;
83143
let mut pubsub = client
@@ -88,55 +148,30 @@ impl TriggerExecutor for RedisTrigger {
88148

89149
println!("Active Channels on {address}:");
90150
// Subscribe to channels
91-
for (channel, component) in self.channel_components.iter() {
151+
for (channel, component) in channel_components.iter() {
92152
tracing::info!("Subscribing component {component:?} to channel {channel:?}");
93153
pubsub.subscribe(channel).await?;
94-
println!("\t{channel}: [{}]", component.join(","));
154+
println!("\t{address}:{channel}: [{}]", component.join(","));
95155
}
96156

97157
let mut stream = pubsub.on_message();
98158
loop {
99159
match stream.next().await {
100160
Some(msg) => {
101-
if let Err(err) = self.handle(msg).await {
161+
if let Err(err) = self.handle(&address, &channel_components, msg).await {
102162
tracing::warn!("Error handling message: {err}");
103163
}
104164
}
105165
None => {
106166
tracing::trace!("Empty message");
107167
if !client.check_connection() {
108168
tracing::info!("No Redis connection available");
109-
break Ok(());
169+
println!("Disconnected from {address}");
170+
break;
110171
}
111172
}
112173
};
113174
}
114-
}
115-
}
116-
117-
impl RedisTrigger {
118-
// Handle the message.
119-
async fn handle(&self, msg: redis::Msg) -> Result<()> {
120-
let channel = msg.get_channel_name();
121-
tracing::info!("Received message on channel {:?}", channel);
122-
123-
if let Some(component_ids) = self.channel_components.get(channel) {
124-
let futures = component_ids.iter().map(|id| {
125-
tracing::trace!("Executing Redis component {id:?}");
126-
SpinRedisExecutor.execute(&self.engine, id, channel, msg.get_payload_bytes())
127-
});
128-
let results: Vec<_> = join_all(futures).await.into_iter().collect();
129-
let errors = results
130-
.into_iter()
131-
.filter_map(|r| r.err())
132-
.collect::<Vec<_>>();
133-
if !errors.is_empty() {
134-
return Err(anyhow!("{errors:#?}"));
135-
}
136-
} else {
137-
tracing::debug!("No subscription found for {:?}", channel);
138-
}
139-
140175
Ok(())
141176
}
142177
}

crates/trigger-redis/src/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ async fn test_pubsub() -> Result<()> {
1818
.test_program("redis-rust.wasm")
1919
.build_trigger("messages")
2020
.await;
21-
21+
let test = HashMap::new();
2222
let msg = create_trigger_event("messages", "hello");
23-
trigger.handle(msg).await?;
23+
trigger.handle("", &test, msg).await?;
2424

2525
Ok(())
2626
}

0 commit comments

Comments
 (0)