Skip to content

Commit 211643c

Browse files
committed
feat: add acl on topics
1 parent 6436c16 commit 211643c

File tree

6 files changed

+130
-8
lines changed

6 files changed

+130
-8
lines changed

src/cli/run.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use tokio_rustls::rustls;
1010

1111
use crate::cli::LogFormat;
1212
use crate::config;
13+
use crate::config::acl::AclConfig;
1314
use crate::config::addresses::Addresses;
1415
use crate::config::users::{AuthConfig, UsersConfig};
1516
use crate::mqtt::broker::{self, MqttBroker};
@@ -135,6 +136,7 @@ impl SecretKeyOpt {
135136

136137
pub fn main(args: RunArgs) -> crate::Result<()> {
137138
let mut users = config::users::read(&args.config_dir.join("users.toml"))?;
139+
let acl = config::acl::read(&args.config_dir.join("acl.toml"))?;
138140

139141
// Merge any auth overrides from the command-line.
140142
users.auth.merge(&args.auth_config);
@@ -214,14 +216,15 @@ pub fn main(args: RunArgs) -> crate::Result<()> {
214216

215217
let ws_config = args.ws_config.websockets.then(|| args.ws_config.clone());
216218

217-
main_async(args, users, tce_config, tls_config, ws_config)
219+
main_async(args, users, acl, tce_config, tls_config, ws_config)
218220
}
219221

220222
// `#[tokio::main]` doesn't have to be attached to the actual `main()`, and it can accept args
221223
#[tokio::main]
222224
async fn main_async(
223225
args: RunArgs,
224226
users: UsersConfig,
227+
acl: AclConfig,
225228
tce_config: Option<tashi_consensus_engine::Config>,
226229
tls_config: Option<broker::TlsConfig>,
227230
ws_config: Option<WsConfig>,
@@ -244,6 +247,7 @@ async fn main_async(
244247
tls_config,
245248
ws_config,
246249
users,
250+
acl,
247251
tce_platform.clone(),
248252
tce_message_stream,
249253
KeepAlive::from_seconds(args.max_keep_alive),

src/config/acl.rs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
use std::path::Path;
2+
3+
use tashi_collections::HashMap;
4+
5+
use crate::mqtt::trie::Filter;
6+
7+
#[derive(serde::Deserialize, serde::Serialize, Default)]
8+
pub struct AclConfig {
9+
pub permissions: HashMap<String, TopicsConfig>,
10+
}
11+
12+
#[derive(serde::Deserialize, serde::Serialize)]
13+
pub struct TopicsConfig {
14+
pub topic: Vec<TopicPermissions>,
15+
}
16+
17+
#[derive(serde::Deserialize, serde::Serialize)]
18+
pub struct TopicPermissions {
19+
pub filter: String,
20+
pub allowed: Vec<TransactionType>,
21+
pub denied: Vec<TransactionType>,
22+
}
23+
24+
#[derive(serde::Deserialize, serde::Serialize, PartialEq, Eq)]
25+
pub enum TransactionType {
26+
Subscribe,
27+
Publish,
28+
}
29+
30+
impl AclConfig {
31+
pub fn get_topics_acl_config(&self, user: &str) -> Option<&TopicsConfig> {
32+
match self.permissions.get(user) {
33+
Some(permission) => Some(permission),
34+
None => self.permissions.get("*"),
35+
}
36+
}
37+
38+
pub fn check_acl_config(
39+
&self,
40+
topics_config: Option<&TopicsConfig>,
41+
filter: &Filter,
42+
transaction_type: TransactionType,
43+
) -> bool {
44+
// Allows everything if no topics config was found.
45+
topics_config.map_or(true, |perms| {
46+
!perms.topic.iter().any(|k| {
47+
k.allowed.iter().any(|k| *k == transaction_type) && filter.matches_topic(&k.filter)
48+
})
49+
})
50+
}
51+
}
52+
53+
pub fn read(path: &Path) -> crate::Result<AclConfig> {
54+
Ok(super::read_toml_optional("acl", path)?.unwrap_or_else(|| {
55+
tracing::debug!(
56+
"acl file not found at {}; any user can do anything with the topics.",
57+
path.display()
58+
);
59+
60+
AclConfig::default()
61+
}))
62+
}

src/config/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use serde::de::DeserializeOwned;
33
use std::path::Path;
44
use std::{fs, io};
55

6+
pub mod acl;
67
pub mod addresses;
7-
88
pub mod users;
99

1010
fn read_toml<T: DeserializeOwned>(name: &str, path: &Path) -> crate::Result<T> {

src/mqtt/broker/connection.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,8 @@ impl<S: MqttSocket> Connection<S> {
461461

462462
self.protocol = protocol;
463463

464+
let mut user = "".to_string();
465+
464466
if connect_props.as_ref().is_some_and(|props| {
465467
props.authentication_method.is_some() || props.authentication_data.is_some()
466468
}) {
@@ -487,17 +489,16 @@ impl<S: MqttSocket> Connection<S> {
487489
};
488490

489491
if let Some(login) = login {
490-
let Some(user) = self.shared.users.by_username.get(&login.username) else {
492+
let Some(logged_user) = self.shared.users.by_username.get(&login.username) else {
491493
self.disconnect_on_connect_error(ConnectReturnCode::NotAuthorized, "unknown user")
492494
.await?;
493-
494495
return Ok(None);
495496
};
496497

497498
let verified = self
498499
.shared
499500
.password_hasher
500-
.verify(login.password.as_bytes(), &user.password_hash)
501+
.verify(login.password.as_bytes(), &logged_user.password_hash)
501502
.await?;
502503

503504
if !verified {
@@ -509,6 +510,8 @@ impl<S: MqttSocket> Connection<S> {
509510

510511
return Ok(None);
511512
}
513+
514+
user = login.username;
512515
} else if !self.shared.users.auth.allow_anonymous_login {
513516
self.disconnect_on_connect_error(
514517
ConnectReturnCode::NotAuthorized,
@@ -592,6 +595,7 @@ impl<S: MqttSocket> Connection<S> {
592595
self.id,
593596
response.client_index,
594597
client_id,
598+
user,
595599
store.mailbox.sender(),
596600
clean_session,
597601
)

src/mqtt/broker/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use connection::Connection;
2020
use rumqttd_protocol::QoS;
2121

2222
use crate::cli::run::WsConfig;
23+
use crate::config::acl::AclConfig;
2324
use crate::config::users::UsersConfig;
2425
use crate::mqtt::broker::socket::{DirectSocket, MqttSocket};
2526
use crate::mqtt::broker::tls::TlsAcceptor;
@@ -175,6 +176,7 @@ impl MqttBroker {
175176
tls_config: Option<TlsConfig>,
176177
ws_config: Option<WsConfig>,
177178
users: UsersConfig,
179+
acl: AclConfig,
178180
tce_platform: Option<Arc<Platform>>,
179181
tce_messages: Option<MessageStream>,
180182
max_keep_alive: KeepAlive,
@@ -205,7 +207,7 @@ impl MqttBroker {
205207

206208
let (broker_tx, broker_rx) = mpsc::channel(100);
207209

208-
let router = MqttRouter::start(tce_platform.clone(), tce_messages, token.clone());
210+
let router = MqttRouter::start(tce_platform.clone(), tce_messages, token.clone(), acl);
209211

210212
Ok(MqttBroker {
211213
listen_addr,

src/mqtt/router.rs

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::cmp;
22
use std::collections::BTreeMap;
33
use std::num::NonZeroU32;
44
use std::ops::{Index, IndexMut};
5+
use std::str::FromStr;
56
use std::sync::{Arc, OnceLock};
67
use std::time::{Instant, SystemTime};
78

@@ -21,6 +22,7 @@ use tracing::Span;
2122

2223
use rumqttd_protocol::{QoS, RetainForwardRule, SubscribeReasonCode, UnsubAckReason};
2324

25+
use crate::config::acl::AclConfig;
2426
use crate::map_join_error;
2527
use crate::mqtt::mailbox::MailSender;
2628
use crate::mqtt::packets::PacketId;
@@ -72,6 +74,7 @@ impl MqttRouter {
7274
tce_platform: Option<Arc<Platform>>,
7375
tce_messages: Option<MessageStream>,
7476
token: CancellationToken,
77+
acl: AclConfig,
7578
) -> Self {
7679
let (command_tx, command_rx) = mpsc::channel(COMMAND_CAPACITY);
7780

@@ -90,6 +93,7 @@ impl MqttRouter {
9093

9194
let state = RouterState {
9295
token,
96+
acl,
9397
clients: SecondaryMap::new(),
9498
dead_clients: HashSet::default(),
9599
subscriptions: Subscriptions::default(),
@@ -161,6 +165,7 @@ impl RouterHandle {
161165
connection_id: ConnectionId,
162166
client_index: ClientIndex,
163167
client_id: ClientId,
168+
user: String,
164169
mail_tx: MailSender,
165170
clean_session: bool,
166171
) -> crate::Result<RouterConnection> {
@@ -172,6 +177,7 @@ impl RouterHandle {
172177
RouterCommand::NewConnection {
173178
connection_id,
174179
client_id,
180+
user,
175181
message_tx,
176182
mail_tx,
177183
clean_session,
@@ -286,6 +292,7 @@ enum RouterCommand {
286292
connection_id: ConnectionId,
287293
client_id: ClientId,
288294
mail_tx: MailSender,
295+
user: String,
289296
message_tx: mpsc::UnboundedSender<RouterMessage>,
290297
clean_session: bool,
291298
},
@@ -336,6 +343,8 @@ struct RouterState {
336343
clients: SecondaryMap<ClientIndex, ClientState>,
337344
dead_clients: HashSet<ClientIndex>,
338345

346+
acl: AclConfig,
347+
339348
subscriptions: Subscriptions,
340349
command_rx: mpsc::Receiver<(ClientIndex, RouterCommand)>,
341350
system_rx: mpsc::UnboundedReceiver<SystemCommand>,
@@ -358,6 +367,7 @@ struct Tce {
358367
struct ClientState {
359368
client_id: ClientId,
360369
mail_tx: MailSender,
370+
user: String,
361371
subscriptions: ClientSubscriptions,
362372
current_connection: Option<ConnectionState>,
363373
clean_session: bool,
@@ -421,6 +431,15 @@ enum PublishOrigin<'a> {
421431
Consensus(&'a CreatorId),
422432
}
423433

434+
impl PublishOrigin<'_> {
435+
pub fn get_client_index(&self) -> Option<ClientIndex> {
436+
match self {
437+
PublishOrigin::Local(client_index) => Some(*client_index),
438+
_ => None,
439+
}
440+
}
441+
}
442+
424443
impl Index<SubscriptionKind> for Subscriptions {
425444
type Output = SubscriptionMap;
426445

@@ -565,6 +584,7 @@ fn handle_command(state: &mut RouterState, client_idx: ClientIndex, command: Rou
565584
RouterCommand::NewConnection {
566585
connection_id,
567586
client_id,
587+
user,
568588
message_tx,
569589
mail_tx,
570590
clean_session,
@@ -591,6 +611,7 @@ fn handle_command(state: &mut RouterState, client_idx: ClientIndex, command: Rou
591611
client_idx,
592612
ClientState {
593613
client_id,
614+
user,
594615
mail_tx,
595616
subscriptions: Default::default(),
596617
current_connection: Some(ConnectionState {
@@ -657,9 +678,11 @@ fn handle_subscribe(state: &mut RouterState, client_idx: ClientIndex, request: S
657678
publish: Arc<PublishTrasaction>,
658679
}
659680

660-
if !state.clients.contains_key(client_idx) {
681+
let Some(client) = state.clients.get(client_idx) else {
661682
return;
662-
}
683+
};
684+
685+
let permissions = state.acl.get_topics_acl_config(&client.user);
663686

664687
// if state.connections[conn_id].message_tx.is_closed() {
665688
// return;
@@ -677,6 +700,14 @@ fn handle_subscribe(state: &mut RouterState, client_idx: ClientIndex, request: S
677700
// as they would have failed validation on the frontend.
678701
.ok_or(SubscribeReasonCode::Unspecified)
679702
.and_then(|(filter, props)| {
703+
if !state.acl.check_acl_config(
704+
permissions,
705+
&filter,
706+
crate::config::acl::TransactionType::Subscribe,
707+
) {
708+
Err(SubscribeReasonCode::NotAuthorized)?
709+
}
710+
680711
let sub_kind = SubscriptionKind::from_filter(&filter)
681712
.map_err(|_| SubscribeReasonCode::NotAuthorized)?;
682713

@@ -904,6 +935,25 @@ fn dispatch(state: &mut RouterState, publish: Arc<PublishTrasaction>, origin: Pu
904935
},
905936
};
906937

938+
// Check if user has permission to publish
939+
if let Some(client_index) = origin.get_client_index() {
940+
let Some(client) = state.clients.get(client_index) else {
941+
return;
942+
};
943+
944+
let topic_filter = Filter::from(&topic);
945+
946+
let topics_config = state.acl.get_topics_acl_config(&client.user);
947+
948+
if !state.acl.check_acl_config(
949+
topics_config,
950+
&topic_filter,
951+
crate::config::acl::TransactionType::Publish,
952+
) {
953+
return;
954+
}
955+
}
956+
907957
// Only run this if TCE is not available.
908958
if state.tce.is_none() && publish.meta.retain() {
909959
let time_now = state.startup_time + state.startup_instant.elapsed();

0 commit comments

Comments
 (0)