Skip to content

Commit 5aeebb6

Browse files
committed
fix: align config, JSON-RPC, and GPU process handling
Ensure overrides, IDs, and GPU usage reporting follow documented behavior and security expectations.
1 parent 5abf471 commit 5aeebb6

File tree

12 files changed

+349
-103
lines changed

12 files changed

+349
-103
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ members = [
66

77
[package]
88
name = "gpukill"
9-
version = "0.1.10"
9+
version = "0.1.11"
1010
edition = "2021"
1111
authors = ["Kage <info@treadie.com>"]
1212
description = "A CLI tool for GPU management and monitoring supporting NVIDIA, AMD, Intel, and Apple Silicon GPUs"

mcp/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "gpukill-mcp"
3-
version = "0.1.1"
3+
version = "0.1.2"
44
edition = "2021"
55
authors = ["GPU Kill Team"]
66
description = "MCP server for GPU Kill - AI-accessible GPU management"

mcp/src/server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,14 @@ impl GpuKillMCPServer {
153153
move |request: axum::extract::Json<JsonRpcRequest>| {
154154
let server = server.clone();
155155
async move {
156+
let request_id = request.0.id.clone();
156157
match server.handle_request(request.0).await {
157158
Ok(response) => axum::response::Json(response),
158159
Err(e) => {
159160
error!("Failed to handle HTTP request: {}", e);
160161
axum::response::Json(JsonRpcResponse {
161162
jsonrpc: "2.0".to_string(),
162-
id: crate::types::RequestId::Null, // Per JSON-RPC 2.0: use null when id cannot be determined
163+
id: request_id,
163164
result: None,
164165
error: Some(JsonRpcError {
165166
code: -32603,

mcp/src/types.rs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,25 @@ impl From<i32> for RequestId {
6161
}
6262
}
6363

64+
fn validate_jsonrpc_version<'de, D>(deserializer: D) -> Result<String, D::Error>
65+
where
66+
D: serde::Deserializer<'de>,
67+
{
68+
let version = String::deserialize(deserializer)?;
69+
if version != "2.0" {
70+
return Err(serde::de::Error::custom(format!(
71+
"jsonrpc must be '2.0', got '{}'",
72+
version
73+
)));
74+
}
75+
Ok(version)
76+
}
77+
6478
/// MCP Request/Response types
6579
#[derive(Debug, Serialize, Deserialize)]
66-
#[serde(tag = "jsonrpc", rename = "2.0")]
6780
pub struct JsonRpcRequest {
81+
#[serde(deserialize_with = "validate_jsonrpc_version")]
82+
pub jsonrpc: String,
6883
/// Request identifier - can be String, Number, or Null per JSON-RPC 2.0
6984
pub id: RequestId,
7085
pub method: String,
@@ -73,6 +88,7 @@ pub struct JsonRpcRequest {
7388

7489
#[derive(Debug, Serialize, Deserialize)]
7590
pub struct JsonRpcResponse {
91+
#[serde(deserialize_with = "validate_jsonrpc_version")]
7692
pub jsonrpc: String,
7793
/// Response identifier - must match the request id per JSON-RPC 2.0
7894
pub id: RequestId,
@@ -298,6 +314,34 @@ mod tests {
298314
assert_eq!(parsed.method, "initialize");
299315
}
300316

317+
#[test]
318+
fn test_jsonrpc_request_rejects_wrong_version() {
319+
let request = json!({
320+
"jsonrpc": "1.0",
321+
"method": "initialize",
322+
"params": {},
323+
"id": 1
324+
});
325+
326+
let parsed: Result<JsonRpcRequest, _> = from_value(request);
327+
assert!(parsed.is_err(), "expected jsonrpc version to be rejected");
328+
}
329+
330+
#[test]
331+
fn test_jsonrpc_request_rejects_missing_version() {
332+
let request = json!({
333+
"method": "initialize",
334+
"params": {},
335+
"id": 1
336+
});
337+
338+
let parsed: Result<JsonRpcRequest, _> = from_value(request);
339+
assert!(
340+
parsed.is_err(),
341+
"expected missing jsonrpc field to be rejected"
342+
);
343+
}
344+
301345
#[test]
302346
fn test_jsonrpc_request_with_null_id() {
303347
// Null IDs are valid in JSON-RPC 2.0 (but not MCP)

src/args.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ pub struct Cli {
8686
#[arg(long)]
8787
pub filter: Option<String>,
8888

89-
/// Kill multiple processes matching the filter
90-
#[arg(long, requires = "filter")]
89+
/// Kill multiple processes matching the filter or GPU
90+
#[arg(long)]
9191
pub batch: bool,
9292

9393
/// Show container information for processes
@@ -628,6 +628,14 @@ mod tests {
628628
assert!(cli.force);
629629
}
630630

631+
#[test]
632+
fn test_kill_batch_with_gpu() {
633+
let cli = Cli::try_parse_from(["gpukill", "--kill", "--batch", "--gpu", "0"]).unwrap();
634+
assert!(cli.kill);
635+
assert!(cli.batch);
636+
assert_eq!(cli.gpu, Some(0));
637+
}
638+
631639
#[test]
632640
fn test_reset_single_gpu() {
633641
let cli = Cli::try_parse_from(["gpukill", "--reset", "--gpu", "0"]).unwrap();

src/config.rs

Lines changed: 67 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -88,42 +88,7 @@ impl ConfigManager {
8888
/// Load configuration from environment variables
8989
pub fn load_from_env() -> Self {
9090
let mut config = Config::default();
91-
92-
// Override with environment variables if present
93-
if let Ok(log_level) = std::env::var("GPUKILL_LOG_LEVEL") {
94-
config.log_level = log_level;
95-
}
96-
97-
if let Ok(output_format) = std::env::var("GPUKILL_OUTPUT_FORMAT") {
98-
config.output_format = output_format;
99-
}
100-
101-
if let Ok(timeout) = std::env::var("GPUKILL_DEFAULT_TIMEOUT") {
102-
if let Ok(timeout_secs) = timeout.parse::<u16>() {
103-
config.default_timeout_secs = timeout_secs;
104-
}
105-
}
106-
107-
if let Ok(show_details) = std::env::var("GPUKILL_SHOW_DETAILS") {
108-
config.show_details = show_details.parse().unwrap_or(false);
109-
}
110-
111-
if let Ok(watch_interval) = std::env::var("GPUKILL_WATCH_INTERVAL") {
112-
if let Ok(interval_secs) = watch_interval.parse::<u64>() {
113-
config.watch_interval_secs = interval_secs;
114-
}
115-
}
116-
117-
if let Ok(table_width) = std::env::var("GPUKILL_TABLE_WIDTH") {
118-
if let Ok(width) = table_width.parse::<usize>() {
119-
config.table_width = width;
120-
}
121-
}
122-
123-
if let Ok(use_colors) = std::env::var("GPUKILL_USE_COLORS") {
124-
config.use_colors = use_colors.parse().unwrap_or(true);
125-
}
126-
91+
apply_env_overrides(&mut config);
12792
Self { config }
12893
}
12994

@@ -181,20 +146,55 @@ impl ConfigManager {
181146
}
182147
}
183148

184-
/// Get configuration with fallback chain
185-
pub fn get_config(config_path: Option<String>) -> Result<ConfigManager> {
186-
// 1. Try to load from specified path
187-
if let Some(path) = config_path {
188-
return ConfigManager::load_from_file(path);
149+
fn apply_env_overrides(config: &mut Config) {
150+
// Override with environment variables if present
151+
if let Ok(log_level) = std::env::var("GPUKILL_LOG_LEVEL") {
152+
config.log_level = log_level;
153+
}
154+
155+
if let Ok(output_format) = std::env::var("GPUKILL_OUTPUT_FORMAT") {
156+
config.output_format = output_format;
157+
}
158+
159+
if let Ok(timeout) = std::env::var("GPUKILL_DEFAULT_TIMEOUT") {
160+
if let Ok(timeout_secs) = timeout.parse::<u16>() {
161+
config.default_timeout_secs = timeout_secs;
162+
}
189163
}
190164

191-
// 2. Try to load from default location
192-
if let Ok(config_manager) = ConfigManager::load_default() {
193-
return Ok(config_manager);
165+
if let Ok(show_details) = std::env::var("GPUKILL_SHOW_DETAILS") {
166+
config.show_details = show_details.parse().unwrap_or(false);
194167
}
195168

196-
// 3. Load from environment variables
197-
Ok(ConfigManager::load_from_env())
169+
if let Ok(watch_interval) = std::env::var("GPUKILL_WATCH_INTERVAL") {
170+
if let Ok(interval_secs) = watch_interval.parse::<u64>() {
171+
config.watch_interval_secs = interval_secs;
172+
}
173+
}
174+
175+
if let Ok(table_width) = std::env::var("GPUKILL_TABLE_WIDTH") {
176+
if let Ok(width) = table_width.parse::<usize>() {
177+
config.table_width = width;
178+
}
179+
}
180+
181+
if let Ok(use_colors) = std::env::var("GPUKILL_USE_COLORS") {
182+
config.use_colors = use_colors.parse().unwrap_or(true);
183+
}
184+
}
185+
186+
/// Get configuration with fallback chain
187+
pub fn get_config(config_path: Option<String>) -> Result<ConfigManager> {
188+
let mut config = if let Some(path) = config_path {
189+
ConfigManager::load_from_file(path)?.config
190+
} else if let Ok(config_manager) = ConfigManager::load_default() {
191+
config_manager.config
192+
} else {
193+
Config::default()
194+
};
195+
196+
apply_env_overrides(&mut config);
197+
Ok(ConfigManager { config })
198198
}
199199

200200
#[cfg(test)]
@@ -239,4 +239,25 @@ mod tests {
239239
let manager = ConfigManager::new();
240240
assert_eq!(manager.config().log_level, "info");
241241
}
242+
243+
#[test]
244+
fn test_env_overrides_config_file() {
245+
let mut config = Config::default();
246+
config.log_level = "warn".to_string();
247+
config.watch_interval_secs = 2;
248+
let toml_str = toml::to_string_pretty(&config).unwrap();
249+
250+
let temp_file = NamedTempFile::new().unwrap();
251+
std::fs::write(temp_file.path(), toml_str).unwrap();
252+
253+
std::env::set_var("GPUKILL_LOG_LEVEL", "debug");
254+
std::env::set_var("GPUKILL_WATCH_INTERVAL", "10");
255+
256+
let loaded = get_config(Some(temp_file.path().to_string_lossy().to_string())).unwrap();
257+
assert_eq!(loaded.config().log_level, "debug");
258+
assert_eq!(loaded.config().watch_interval_secs, 10);
259+
260+
std::env::remove_var("GPUKILL_LOG_LEVEL");
261+
std::env::remove_var("GPUKILL_WATCH_INTERVAL");
262+
}
242263
}

src/coordinator.rs

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ impl CoordinatorState {
239239
let snapshots = self.snapshots.read().await;
240240
let mut blocked_gpus = Vec::new();
241241
// Track unique (node_id, gpu_index) pairs per user to correctly count GPUs
242-
// Tuple: (unique_gpus, memory, utilization, process_count)
242+
// Tuple: (unique_gpus, memory, utilization_sum, process_count)
243243
#[allow(clippy::type_complexity)]
244244
let mut user_stats: HashMap<String, (HashSet<(String, u16)>, u32, f32, u32)> =
245245
HashMap::new();
@@ -279,9 +279,10 @@ impl CoordinatorState {
279279
0,
280280
));
281281
// Track unique (node_id, gpu_index) pairs to correctly count GPUs
282-
entry.0.insert((node_id.clone(), gpu.gpu_index));
282+
if entry.0.insert((node_id.clone(), gpu.gpu_index)) {
283+
entry.2 += gpu.util_pct; // utilization (sum per unique GPU)
284+
}
283285
entry.1 += process.used_mem_mb; // memory
284-
entry.2 += gpu.util_pct; // utilization (will average later)
285286
entry.3 += 1; // process_count
286287
}
287288
}
@@ -295,7 +296,11 @@ impl CoordinatorState {
295296
user,
296297
gpu_count: gpu_set.len() as u32, // Count unique GPUs
297298
total_memory_mb,
298-
avg_utilization: total_util / process_count as f32,
299+
avg_utilization: if gpu_set.is_empty() {
300+
0.0
301+
} else {
302+
total_util / gpu_set.len() as f32
303+
},
299304
process_count,
300305
},
301306
)
@@ -961,4 +966,96 @@ mod tests {
961966
"Bob uses 4 unique GPUs across 2 nodes"
962967
);
963968
}
969+
970+
#[tokio::test]
971+
async fn test_contention_analysis_avg_utilization_unique_gpus() {
972+
let state = CoordinatorState::new();
973+
974+
let snapshot = NodeSnapshot {
975+
node_id: "test-node".to_string(),
976+
hostname: "test-host".to_string(),
977+
timestamp: Utc::now(),
978+
gpus: vec![
979+
GpuSnapshot {
980+
gpu_index: 0,
981+
name: "GPU 0".to_string(),
982+
vendor: GpuVendor::Nvidia,
983+
mem_used_mb: 8000,
984+
mem_total_mb: 10000,
985+
util_pct: 90.0,
986+
temp_c: 75,
987+
power_w: 200.0,
988+
ecc_volatile: None,
989+
pids: 2,
990+
top_proc: None,
991+
},
992+
GpuSnapshot {
993+
gpu_index: 1,
994+
name: "GPU 1".to_string(),
995+
vendor: GpuVendor::Nvidia,
996+
mem_used_mb: 3000,
997+
mem_total_mb: 10000,
998+
util_pct: 30.0,
999+
temp_c: 65,
1000+
power_w: 100.0,
1001+
ecc_volatile: None,
1002+
pids: 1,
1003+
top_proc: None,
1004+
},
1005+
],
1006+
processes: vec![
1007+
GpuProc {
1008+
gpu_index: 0,
1009+
pid: 1001,
1010+
user: "charlie".to_string(),
1011+
proc_name: "train1".to_string(),
1012+
used_mem_mb: 4000,
1013+
start_time: "2025-09-20T01:00:00Z".to_string(),
1014+
container: None,
1015+
},
1016+
GpuProc {
1017+
gpu_index: 0,
1018+
pid: 1002,
1019+
user: "charlie".to_string(),
1020+
proc_name: "train2".to_string(),
1021+
used_mem_mb: 4000,
1022+
start_time: "2025-09-20T01:00:00Z".to_string(),
1023+
container: None,
1024+
},
1025+
GpuProc {
1026+
gpu_index: 1,
1027+
pid: 1003,
1028+
user: "charlie".to_string(),
1029+
proc_name: "train3".to_string(),
1030+
used_mem_mb: 3000,
1031+
start_time: "2025-09-20T01:00:00Z".to_string(),
1032+
container: None,
1033+
},
1034+
],
1035+
status: NodeStatus::Online,
1036+
};
1037+
1038+
state
1039+
.update_snapshot("test-node".to_string(), snapshot)
1040+
.await
1041+
.unwrap();
1042+
1043+
let analysis = state.get_contention_analysis().await.unwrap();
1044+
1045+
let charlie_stats = analysis
1046+
.top_users
1047+
.iter()
1048+
.find(|u| u.user == "charlie")
1049+
.expect("Charlie should be in top users");
1050+
1051+
assert_eq!(charlie_stats.gpu_count, 2, "Charlie uses 2 unique GPUs");
1052+
assert_eq!(charlie_stats.process_count, 3, "Charlie has 3 processes");
1053+
1054+
let expected_avg = 60.0;
1055+
let diff = (charlie_stats.avg_utilization - expected_avg).abs();
1056+
assert!(
1057+
diff < 0.01,
1058+
"Average utilization should be calculated per unique GPU"
1059+
);
1060+
}
9641061
}

0 commit comments

Comments (0)