Skip to content

Commit be87daa

Browse files
authored
Merge pull request #57 from rust-lang/protection
add protected commands and protected help menu items
2 parents 39b4e3f + 089759b commit be87daa

File tree

5 files changed

+176
-95
lines changed

5 files changed

+176
-95
lines changed

src/ban.rs

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -79,43 +79,41 @@ pub(crate) fn start_unban_thread(cx: Context) {
7979
///
8080
/// Requires the ban members permission
8181
pub(crate) fn temp_ban(args: Args) -> Result<()> {
82-
if api::is_mod(&args)? {
83-
let user_id = parse_username(
84-
&args
85-
.params
86-
.get("user")
87-
.ok_or("unable to retrieve user param")?,
88-
)
89-
.ok_or("unable to retrieve user id")?;
82+
let user_id = parse_username(
83+
&args
84+
.params
85+
.get("user")
86+
.ok_or("unable to retrieve user param")?,
87+
)
88+
.ok_or("unable to retrieve user id")?;
9089

91-
use std::str::FromStr;
90+
use std::str::FromStr;
9291

93-
let hours = u64::from_str(
94-
args.params
95-
.get("hours")
96-
.ok_or("unable to retrieve hours param")?,
97-
)?;
92+
let hours = u64::from_str(
93+
args.params
94+
.get("hours")
95+
.ok_or("unable to retrieve hours param")?,
96+
)?;
9897

99-
let reason = args
100-
.params
101-
.get("reason")
102-
.ok_or("unable to retrieve reason param")?;
98+
let reason = args
99+
.params
100+
.get("reason")
101+
.ok_or("unable to retrieve reason param")?;
103102

104-
if let Some(guild) = args.msg.guild(&args.cx) {
105-
info!("Banning user from guild");
106-
let user = UserId::from(user_id);
103+
if let Some(guild) = args.msg.guild(&args.cx) {
104+
info!("Banning user from guild");
105+
let user = UserId::from(user_id);
107106

108-
user.create_dm_channel(args.cx)?
109-
.say(args.cx, ban_message(reason, hours))?;
107+
user.create_dm_channel(args.cx)?
108+
.say(args.cx, ban_message(reason, hours))?;
110109

111-
guild.read().ban(args.cx, &user, &"all")?;
110+
guild.read().ban(args.cx, &user, &"all")?;
112111

113-
save_ban(
114-
format!("{}", user_id),
115-
format!("{}", guild.read().id),
116-
hours,
117-
)?;
118-
}
112+
save_ban(
113+
format!("{}", user_id),
114+
format!("{}", guild.read().id),
115+
hours,
116+
)?;
119117
}
120118
Ok(())
121119
}

src/commands.rs

Lines changed: 94 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
1-
use crate::state_machine::{CharacterSet, StateMachine};
1+
use crate::{
2+
api,
3+
state_machine::{CharacterSet, StateMachine},
4+
};
25
use reqwest::blocking::Client as HttpClient;
36
use serenity::{model::channel::Message, prelude::Context};
47
use std::{collections::HashMap, sync::Arc};
58

69
const PREFIX: &'static str = "?";
710
pub(crate) type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
8-
pub(crate) type CmdPtr = Arc<dyn for<'m> Fn(Args<'m>) -> Result<()> + Send + Sync>;
11+
pub(crate) type GuardFn = fn(&Args) -> Result<bool>;
12+
13+
struct Command {
14+
guard: GuardFn,
15+
ptr: Box<dyn for<'m> Fn(Args<'m>) -> Result<()> + Send + Sync>,
16+
}
17+
18+
impl Command {
19+
fn authorize(&self, args: &Args) -> Result<bool> {
20+
(self.guard)(&args)
21+
}
22+
23+
fn call(&self, args: Args) -> Result<()> {
24+
(self.ptr)(args)
25+
}
26+
}
927

1028
pub struct Args<'m> {
1129
pub http: &'m HttpClient,
@@ -15,24 +33,33 @@ pub struct Args<'m> {
1533
}
1634

1735
pub(crate) struct Commands {
18-
state_machine: StateMachine,
36+
state_machine: StateMachine<Arc<Command>>,
1937
client: HttpClient,
20-
menu: HashMap<&'static str, &'static str>,
38+
menu: Option<HashMap<&'static str, (&'static str, GuardFn)>>,
2139
}
2240

2341
impl Commands {
2442
pub(crate) fn new() -> Self {
2543
Self {
2644
state_machine: StateMachine::new(),
2745
client: HttpClient::new(),
28-
menu: HashMap::new(),
46+
menu: Some(HashMap::new()),
2947
}
3048
}
3149

3250
pub(crate) fn add(
3351
&mut self,
3452
command: &'static str,
3553
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
54+
) {
55+
self.add_protected(command, handler, |_| Ok(true));
56+
}
57+
58+
pub(crate) fn add_protected(
59+
&mut self,
60+
command: &'static str,
61+
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
62+
guard: GuardFn,
3663
) {
3764
info!("Adding command {}", &command);
3865
let mut state = 0;
@@ -89,7 +116,10 @@ impl Commands {
89116
}
90117
});
91118

92-
let handler = Arc::new(handler);
119+
let handler = Arc::new(Command {
120+
guard,
121+
ptr: Box::new(handler),
122+
});
93123

94124
if opt_lambda_state.is_some() {
95125
opt_final_states.iter().for_each(|state| {
@@ -107,35 +137,69 @@ impl Commands {
107137
cmd: &'static str,
108138
desc: &'static str,
109139
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
140+
) {
141+
self.help_protected(cmd, desc, handler, |_| Ok(true));
142+
}
143+
144+
pub(crate) fn help_protected(
145+
&mut self,
146+
cmd: &'static str,
147+
desc: &'static str,
148+
handler: impl Fn(Args) -> Result<()> + Send + Sync + 'static,
149+
guard: GuardFn,
110150
) {
111151
let base_cmd = &cmd[1..];
112152
info!("Adding command ?help {}", &base_cmd);
113153
let mut state = 0;
114154

115-
self.menu.insert(cmd, desc);
116-
state = add_help_menu(&mut self.state_machine, base_cmd, state);
155+
self.menu.as_mut().map(|menu| {
156+
menu.insert(cmd, (desc, guard));
157+
menu
158+
});
117159

160+
state = add_help_menu(&mut self.state_machine, base_cmd, state);
118161
self.state_machine.set_final_state(state);
119-
self.state_machine.set_handler(state, Arc::new(handler));
162+
self.state_machine.set_handler(
163+
state,
164+
Arc::new(Command {
165+
guard,
166+
ptr: Box::new(handler),
167+
}),
168+
);
120169
}
121170

122-
pub(crate) fn menu(&mut self) -> &HashMap<&'static str, &'static str> {
123-
&self.menu
171+
pub(crate) fn menu(&mut self) -> Option<HashMap<&'static str, (&'static str, GuardFn)>> {
172+
self.menu.take()
124173
}
125174

126175
pub(crate) fn execute<'m>(&'m self, cx: Context, msg: Message) {
127176
let message = &msg.content;
128177
if !msg.is_own(&cx) && message.starts_with(PREFIX) {
129178
self.state_machine.process(message).map(|matched| {
130-
info!("Executing command {}", message);
179+
info!("Processing command: {}", message);
131180
let args = Args {
132181
http: &self.client,
133182
cx: &cx,
134183
msg: &msg,
135184
params: matched.params,
136185
};
137-
if let Err(e) = (matched.handler)(args) {
138-
println!("{}", e);
186+
info!("Checking permissions");
187+
match matched.handler.authorize(&args) {
188+
Ok(true) => {
189+
info!("Executing command");
190+
if let Err(e) = matched.handler.call(args) {
191+
error!("{}", e);
192+
}
193+
}
194+
Ok(false) => {
195+
info!("Not executing command, unauthorized");
196+
if let Err(e) =
197+
api::send_reply(&args, "You do not have permission to run this command")
198+
{
199+
error!("{}", e);
200+
}
201+
}
202+
Err(e) => error!("{}", e),
139203
}
140204
});
141205
}
@@ -156,7 +220,7 @@ fn key_value_pair(s: &'static str) -> Option<&'static str> {
156220
.flatten()
157221
}
158222

159-
fn add_space(state_machine: &mut StateMachine, mut state: usize, i: usize) -> usize {
223+
fn add_space<T>(state_machine: &mut StateMachine<T>, mut state: usize, i: usize) -> usize {
160224
if i > 0 {
161225
let mut char_set = CharacterSet::from_char(' ');
162226
char_set.insert('\n');
@@ -167,8 +231,8 @@ fn add_space(state_machine: &mut StateMachine, mut state: usize, i: usize) -> us
167231
state
168232
}
169233

170-
fn add_help_menu(
171-
mut state_machine: &mut StateMachine,
234+
fn add_help_menu<T>(
235+
mut state_machine: &mut StateMachine<T>,
172236
cmd: &'static str,
173237
mut state: usize,
174238
) -> usize {
@@ -183,8 +247,8 @@ fn add_help_menu(
183247
state
184248
}
185249

186-
fn add_dynamic_segment(
187-
state_machine: &mut StateMachine,
250+
fn add_dynamic_segment<T>(
251+
state_machine: &mut StateMachine<T>,
188252
name: &'static str,
189253
mut state: usize,
190254
) -> usize {
@@ -198,8 +262,8 @@ fn add_dynamic_segment(
198262
state
199263
}
200264

201-
fn add_remaining_segment(
202-
state_machine: &mut StateMachine,
265+
fn add_remaining_segment<T>(
266+
state_machine: &mut StateMachine<T>,
203267
name: &'static str,
204268
mut state: usize,
205269
) -> usize {
@@ -212,8 +276,8 @@ fn add_remaining_segment(
212276
state
213277
}
214278

215-
fn add_code_segment_multi_line(
216-
state_machine: &mut StateMachine,
279+
fn add_code_segment_multi_line<T>(
280+
state_machine: &mut StateMachine<T>,
217281
name: &'static str,
218282
mut state: usize,
219283
) -> usize {
@@ -246,8 +310,8 @@ fn add_code_segment_multi_line(
246310
state
247311
}
248312

249-
fn add_code_segment_single_line(
250-
state_machine: &mut StateMachine,
313+
fn add_code_segment_single_line<T>(
314+
state_machine: &mut StateMachine<T>,
251315
name: &'static str,
252316
mut state: usize,
253317
n_backticks: usize,
@@ -266,7 +330,11 @@ fn add_code_segment_single_line(
266330
state
267331
}
268332

269-
fn add_key_value(state_machine: &mut StateMachine, name: &'static str, mut state: usize) -> usize {
333+
fn add_key_value<T>(
334+
state_machine: &mut StateMachine<T>,
335+
name: &'static str,
336+
mut state: usize,
337+
) -> usize {
270338
name.chars().for_each(|c| {
271339
state = state_machine.add(state, CharacterSet::from_char(c));
272340
});

0 commit comments

Comments
 (0)