Skip to content

Commit 9a27909

Browse files
committed
Fix security and performance audit findings
Security: - Validate Binance query params to prevent URL parameter injection - Use saturating_abs() in risk checks to fix negative price bypass and i64::MIN panic - Fail all risk checks when equity <= 0 (was silently passing) - Guard CostModel u128→i64 cast with try_from Performance: - O(N) rolling metrics via running sum/sum_sq (was O(N*K)) - Eliminate 3 Vec allocations in RSI/MACD hot paths - Compute CVaR tail mean on iterator (no intermediate Vec)
1 parent ebaa3ea commit 9a27909

File tree

5 files changed

+118
-49
lines changed

5 files changed

+118
-49
lines changed

broker/src/binance/client.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,20 @@ use super::auth;
99
use super::types::{AccountInfo, BookTicker, OrderResponse};
1010
use crate::error::BrokerError;
1111

12+
/// Validate that a parameter value is safe for URL query strings.
13+
///
14+
/// Rejects any value containing characters that could inject additional
15+
/// query parameters (e.g., `&`, `=`, `?`, `#`, space).
16+
fn validate_query_param(value: &str, name: &str) -> Result<(), BrokerError> {
17+
if value.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'_' || b == b'.' || b == b'-') {
18+
Ok(())
19+
} else {
20+
Err(BrokerError::Order(format!(
21+
"invalid characters in {name}: {value:?}"
22+
)))
23+
}
24+
}
25+
1226
/// Blocking Binance REST client.
1327
pub struct BinanceClient {
1428
client: Client,
@@ -98,6 +112,17 @@ impl BinanceClient {
98112
price: Option<&str>,
99113
time_in_force: Option<&str>,
100114
) -> Result<OrderResponse, BrokerError> {
115+
validate_query_param(symbol, "symbol")?;
116+
validate_query_param(side, "side")?;
117+
validate_query_param(order_type, "order_type")?;
118+
validate_query_param(quantity, "quantity")?;
119+
if let Some(p) = price {
120+
validate_query_param(p, "price")?;
121+
}
122+
if let Some(tif) = time_in_force {
123+
validate_query_param(tif, "timeInForce")?;
124+
}
125+
101126
let timestamp = current_timestamp_ms();
102127
let mut query = format!(
103128
"symbol={symbol}&side={side}&type={order_type}&quantity={quantity}&timestamp={timestamp}"
@@ -137,6 +162,7 @@ impl BinanceClient {
137162

138163
/// Get order status (GET /api/v3/order).
139164
pub fn order_status(&self, symbol: &str, order_id: u64) -> Result<OrderResponse, BrokerError> {
165+
validate_query_param(symbol, "symbol")?;
140166
let timestamp = current_timestamp_ms();
141167
let query = format!("symbol={symbol}&orderId={order_id}&timestamp={timestamp}");
142168
let signature = auth::sign(&query, &self.secret_key);
@@ -166,6 +192,7 @@ impl BinanceClient {
166192

167193
/// Cancel an order (DELETE /api/v3/order).
168194
pub fn cancel_order(&self, symbol: &str, order_id: u64) -> Result<(), BrokerError> {
195+
validate_query_param(symbol, "symbol")?;
169196
let timestamp = current_timestamp_ms();
170197
let query = format!("symbol={symbol}&orderId={order_id}&timestamp={timestamp}");
171198
let signature = auth::sign(&query, &self.secret_key);
@@ -194,6 +221,7 @@ impl BinanceClient {
194221

195222
/// Get book ticker (best bid/ask) for a symbol (GET /api/v3/ticker/bookTicker).
196223
pub fn book_ticker(&self, symbol: &str) -> Result<BookTicker, BrokerError> {
224+
validate_query_param(symbol, "symbol")?;
197225
let url = format!("{}/api/v3/ticker/bookTicker?symbol={symbol}", self.base_url);
198226

199227
let resp = self

risk/src/checks.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ pub fn check_batch(
6868
let gross_exposure: i64 = post_qty
6969
.iter()
7070
.map(|(sym, qty)| {
71-
let price = price_map.get(sym).copied().unwrap_or(0);
72-
qty.abs() * price
71+
let price = price_map.get(sym).copied().unwrap_or(0).saturating_abs();
72+
qty.saturating_abs().saturating_mul(price)
7373
})
74-
.sum();
74+
.fold(0_i64, |acc, v| acc.saturating_add(v));
7575
let leverage = if equity > 0 {
7676
gross_exposure as f64 / equity as f64
7777
} else {
78-
0.0
78+
f64::INFINITY
7979
};
8080
let lev_status = if leverage > config.max_leverage {
8181
RiskStatus::Fail
@@ -102,14 +102,14 @@ pub fn check_batch(
102102
.iter()
103103
.filter(|(_, qty)| **qty < 0)
104104
.map(|(sym, qty)| {
105-
let price = price_map.get(sym).copied().unwrap_or(0);
106-
qty.abs() * price
105+
let price = price_map.get(sym).copied().unwrap_or(0).saturating_abs();
106+
qty.saturating_abs().saturating_mul(price)
107107
})
108-
.sum();
108+
.fold(0_i64, |acc, v| acc.saturating_add(v));
109109
let short_pct = if equity > 0 {
110110
short_exposure as f64 / equity as f64
111111
} else {
112-
0.0
112+
f64::INFINITY
113113
};
114114

115115
let has_shorts = orders
@@ -149,7 +149,7 @@ pub fn check_batch(
149149
let max_cents = (config.max_trade_usd * 100.0) as i64;
150150
for &(sym, _side, qty, price) in orders {
151151
let qty_i64 = i64::try_from(qty).unwrap_or(i64::MAX);
152-
let notional = qty_i64.saturating_mul(price);
152+
let notional = qty_i64.saturating_mul(price.saturating_abs());
153153
if notional > max_cents {
154154
checks.push(RiskCheck {
155155
name: "Max trade size",

src/indicators.rs

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ fn sma(values: &[f64], period: usize) -> Vec<f64> {
6262

6363
/// Population standard deviation over a rolling window.
6464
///
65+
/// Uses O(N) running sum/sum-of-squares instead of O(N*K) re-summation.
6566
/// Returns NaN for the lookback period.
6667
fn rolling_std_pop(values: &[f64], period: usize) -> Vec<f64> {
6768
let n = values.len();
@@ -70,11 +71,24 @@ fn rolling_std_pop(values: &[f64], period: usize) -> Vec<f64> {
7071
return out;
7172
}
7273

73-
for i in (period - 1)..n {
74-
let window = &values[i + 1 - period..=i];
75-
let mean = window.iter().sum::<f64>() / period as f64;
76-
let var = window.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / period as f64;
77-
out[i] = var.sqrt();
74+
let k = period as f64;
75+
76+
// Seed: first window
77+
let mut sum: f64 = values[..period].iter().sum();
78+
let mut sum_sq: f64 = values[..period].iter().map(|v| v * v).sum();
79+
80+
let mean = sum / k;
81+
out[period - 1] = (sum_sq / k - mean * mean).max(0.0).sqrt();
82+
83+
// Slide window: add new, remove old
84+
for i in period..n {
85+
let old = values[i - period];
86+
let new = values[i];
87+
sum += new - old;
88+
sum_sq += new * new - old * old;
89+
90+
let mean = sum / k;
91+
out[i] = (sum_sq / k - mean * mean).max(0.0).sqrt();
7892
}
7993
out
8094
}
@@ -114,21 +128,19 @@ pub fn rsi(close: &[f64], period: usize) -> Vec<f64> {
114128
return out;
115129
}
116130

117-
// Compute gains and losses
118-
let mut gains = vec![0.0_f64; n];
119-
let mut losses = vec![0.0_f64; n];
120-
for i in 1..n {
131+
// Seed with simple average over first `period` changes (indices 1..=period)
132+
let mut avg_gain = 0.0_f64;
133+
let mut avg_loss = 0.0_f64;
134+
for i in 1..=period {
121135
let diff = close[i] - close[i - 1];
122136
if diff > 0.0 {
123-
gains[i] = diff;
137+
avg_gain += diff;
124138
} else {
125-
losses[i] = -diff;
139+
avg_loss -= diff;
126140
}
127141
}
128-
129-
// Seed with simple average over first `period` changes (indices 1..=period)
130-
let mut avg_gain: f64 = gains[1..=period].iter().sum::<f64>() / period as f64;
131-
let mut avg_loss: f64 = losses[1..=period].iter().sum::<f64>() / period as f64;
142+
avg_gain /= period as f64;
143+
avg_loss /= period as f64;
132144

133145
// First RSI value
134146
out[period] = if avg_gain == 0.0 && avg_loss == 0.0 {
@@ -142,8 +154,11 @@ pub fn rsi(close: &[f64], period: usize) -> Vec<f64> {
142154

143155
// Subsequent values with Wilder's smoothing
144156
for i in (period + 1)..n {
145-
avg_gain = (avg_gain * (period as f64 - 1.0) + gains[i]) / period as f64;
146-
avg_loss = (avg_loss * (period as f64 - 1.0) + losses[i]) / period as f64;
157+
let diff = close[i] - close[i - 1];
158+
let gain = if diff > 0.0 { diff } else { 0.0 };
159+
let loss = if diff < 0.0 { -diff } else { 0.0 };
160+
avg_gain = (avg_gain * (period as f64 - 1.0) + gain) / period as f64;
161+
avg_loss = (avg_loss * (period as f64 - 1.0) + loss) / period as f64;
147162

148163
out[i] = if avg_gain == 0.0 && avg_loss == 0.0 {
149164
0.0
@@ -210,9 +225,8 @@ pub fn macd(
210225
}
211226
}
212227

213-
// Signal line = EMA of valid MACD values
214-
let valid_macd: Vec<f64> = macd_line[first_valid..].to_vec();
215-
let signal_raw = ema(&valid_macd, signal_period);
228+
// Signal line = EMA of valid MACD values (pass slice directly — no copy)
229+
let signal_raw = ema(&macd_line[first_valid..], signal_period);
216230

217231
let mut signal_line = vec![f64::NAN; n];
218232
for (j, &val) in signal_raw.iter().enumerate() {

src/portfolio/cost_model.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ impl CostModel {
4141
let notional = notional.unsigned_abs() as u128;
4242
let total_bps = self.commission_bps as u128 + self.slippage_bps as u128;
4343
// notional * bps / 10_000 — use u128 to prevent overflow
44-
let bps_cost = (notional * total_bps / 10_000) as i64;
44+
let raw = notional * total_bps / 10_000;
45+
let bps_cost = i64::try_from(raw).unwrap_or(i64::MAX);
4546
bps_cost.max(self.min_trade_fee)
4647
}
4748
}

src/portfolio/metrics.rs

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,18 @@ fn compute_cvar(returns: &[f64], alpha: f64) -> f64 {
255255
let z = norm_ppf(alpha);
256256
let var_threshold = mu + sigma * z;
257257

258-
// CVaR: mean of returns strictly below VaR
259-
let tail: Vec<f64> = returns.iter().copied().filter(|&r| r < var_threshold).collect();
260-
if tail.is_empty() {
258+
// CVaR: mean of returns strictly below VaR (computed on iterator — no allocation)
259+
let (tail_sum, tail_count) = returns
260+
.iter()
261+
.filter(|&&r| r < var_threshold)
262+
.fold((0.0_f64, 0_usize), |(sum, cnt), &r| (sum + r, cnt + 1));
263+
if tail_count == 0 {
261264
return *returns
262265
.iter()
263266
.min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
264267
.unwrap_or(&0.0);
265268
}
266-
tail.iter().sum::<f64>() / tail.len() as f64
269+
tail_sum / tail_count as f64
267270
}
268271

269272
/// Inverse of the standard normal CDF (probit function).
@@ -351,17 +354,28 @@ pub fn rolling_sharpe(returns: &[f64], window: usize, periods_per_year: usize) -
351354
}
352355

353356
let ppy = periods_per_year as f64;
357+
let k = window as f64;
354358

355-
for i in (window - 1)..n {
356-
let w = &returns[i + 1 - window..=i];
357-
let mean = w.iter().sum::<f64>() / window as f64;
358-
let var = w.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / (window - 1) as f64;
359-
let std = var.sqrt();
360-
out[i] = if std > 0.0 {
361-
mean * ppy.sqrt() / std
362-
} else {
363-
0.0
364-
};
359+
// Seed first window
360+
let mut sum: f64 = returns[..window].iter().sum();
361+
let mut sum_sq: f64 = returns[..window].iter().map(|r| r * r).sum();
362+
363+
let mean = sum / k;
364+
let var = (sum_sq - sum * sum / k) / (k - 1.0);
365+
let std = var.max(0.0).sqrt();
366+
out[window - 1] = if std > 0.0 { mean * ppy.sqrt() / std } else { 0.0 };
367+
368+
// Slide window
369+
for i in window..n {
370+
let old = returns[i - window];
371+
let new = returns[i];
372+
sum += new - old;
373+
sum_sq += new * new - old * old;
374+
375+
let mean = sum / k;
376+
let var = (sum_sq - sum * sum / k) / (k - 1.0);
377+
let std = var.max(0.0).sqrt();
378+
out[i] = if std > 0.0 { mean * ppy.sqrt() / std } else { 0.0 };
365379
}
366380

367381
out
@@ -384,12 +398,24 @@ pub fn rolling_volatility(returns: &[f64], window: usize, periods_per_year: usiz
384398
}
385399

386400
let ppy = periods_per_year as f64;
401+
let k = window as f64;
402+
403+
// Seed first window
404+
let mut sum: f64 = returns[..window].iter().sum();
405+
let mut sum_sq: f64 = returns[..window].iter().map(|r| r * r).sum();
406+
407+
let var = (sum_sq - sum * sum / k) / (k - 1.0);
408+
out[window - 1] = var.max(0.0).sqrt() * ppy.sqrt();
409+
410+
// Slide window
411+
for i in window..n {
412+
let old = returns[i - window];
413+
let new = returns[i];
414+
sum += new - old;
415+
sum_sq += new * new - old * old;
387416

388-
for i in (window - 1)..n {
389-
let w = &returns[i + 1 - window..=i];
390-
let mean = w.iter().sum::<f64>() / window as f64;
391-
let var = w.iter().map(|&r| (r - mean).powi(2)).sum::<f64>() / (window - 1) as f64;
392-
out[i] = var.sqrt() * ppy.sqrt();
417+
let var = (sum_sq - sum * sum / k) / (k - 1.0);
418+
out[i] = var.max(0.0).sqrt() * ppy.sqrt();
393419
}
394420

395421
out

0 commit comments

Comments
 (0)