From 710c9a9ef42e580acfe8ffd20d436eee2061296d Mon Sep 17 00:00:00 2001 From: Tomasz Kalinowski Date: Thu, 18 Jun 2026 17:42:22 -0400 Subject: [PATCH] Extract Python session I/O modules --- src/python_session.rs | 998 +--------------------------- src/python_session/state.rs | 330 +++++++++ src/python_session/stdio.rs | 196 ++++++ src/python_session/unix_stdin.rs | 12 +- src/python_session/windows_stdin.rs | 462 +++++++++++++ 5 files changed, 1022 insertions(+), 976 deletions(-) create mode 100644 src/python_session/state.rs create mode 100644 src/python_session/stdio.rs create mode 100644 src/python_session/windows_stdin.rs diff --git a/src/python_session.rs b/src/python_session.rs index 015673fe..1ae8f5d8 100644 --- a/src/python_session.rs +++ b/src/python_session.rs @@ -1,7 +1,4 @@ -use std::collections::VecDeque; use std::ffi::{CStr, CString, c_char, c_int, c_long}; -#[cfg(target_family = "unix")] -use std::os::unix::io::RawFd; use std::path::Path; use std::ptr; use std::sync::atomic::{AtomicPtr, Ordering}; @@ -9,20 +6,28 @@ use std::sync::{Arc, Condvar, Mutex, OnceLock}; use crate::ipc; use crate::python_ffi::{GilGuard, ModuleMethod, PyObject, PyPtr, PyThreadState, PythonApi}; -#[cfg(target_family = "unix")] -use crate::python_turn_input::PythonTurnInput; use crate::worker_protocol::TextStream; + +use state::{ + ActiveRequest, PythonReadlineState, RawStdinReadError, SESSION_STATE, SessionState, + StdinReadAccounting, begin_repl_turn, clear_current_readline_prompt, remember_emitted_prompt, + repl_prompt_for, request_active, session_state, set_current_readline_prompt, + set_current_repl_readline_prompt, +}; +#[cfg(not(any(target_family = "unix", windows)))] +use state::{input_hook_prompt, mark_stdin_wait_prompt_completed_request}; #[cfg(windows)] -use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE; -#[cfg(windows)] -use windows_sys::Win32::Storage::FileSystem::ReadFile; -#[cfg(windows)] -use windows_sys::Win32::System::Console::{GetStdHandle, STD_INPUT_HANDLE}; -#[cfg(windows)] -use windows_sys::Win32::System::Pipes::PeekNamedPipe; +use stdio::StdioLineRead; +use stdio::{PYTHON_STDIN_FILE, PythonRuntime, open_python_runtime}; +#[cfg(not(windows))] +use stdio::{read_stdio_line_bytes, read_stdio_line_bytes_allowing_python_threads}; +mod state; +mod stdio; #[cfg(target_family = "unix")] mod unix_stdin; +#[cfg(windows)] +mod windows_stdin; const MCP_REPL_PYTHON: &str = include_str!("../python/embedded.py"); const PYTHON_EOF: c_int = 11; @@ -60,7 +65,7 @@ impl PythonSession { #[cfg(windows)] pub fn begin_turn(&self, turn_id: u64, input: String) -> Result<(), String> { self.wait_until_ready()?; - begin_tracked_turn(turn_id, input) + windows_stdin::begin_tracked_turn(turn_id, input) } } @@ -113,11 +118,6 @@ impl SessionInit { } } -struct PythonRuntime { - #[cfg_attr(windows, allow(dead_code))] - stdin: *mut libc::FILE, -} - fn request_exit() { let Some(state) = SESSION_STATE.get() else { return; @@ -143,27 +143,7 @@ pub(crate) fn interrupt() { #[cfg(windows)] pub(crate) fn interrupt_turn(turn_id: u64) { - let Some(state) = SESSION_STATE.get() else { - return; - }; - { - let mut guard = state.inner.lock().unwrap(); - let write_in_flight = guard.turn_write_in_flight; - let Some(active) = guard.active_request.as_mut() else { - return; - }; - if active.turn_id != Some(turn_id) { - return; - } - active.queued_lines.clear(); - if write_in_flight { - guard.turn_cleanup_uncertain = true; - } - guard.interrupt_requested = true; - guard.waiting_for_input = false; - state.cvar.notify_all(); - } - request_platform_interrupt(); + windows_stdin::interrupt_turn(turn_id); } fn interrupt_for_request_generation(request_generation: Option) { @@ -231,7 +211,7 @@ pub(crate) fn append_turn_input(turn_id: u64, input: String) -> Result<(), Strin #[cfg(windows)] pub(crate) fn append_turn_input(turn_id: u64, input: String) -> Result<(), String> { - append_tracked_turn_input(turn_id, input) + windows_stdin::append_tracked_turn_input(turn_id, input) } #[cfg(not(any(target_family = "unix", windows)))] @@ -273,59 +253,12 @@ fn emit_protocol_failure(message: &str) { #[cfg(windows)] fn discard_pending_stdin() { - let stdin = PYTHON_STDIN_FILE.load(Ordering::SeqCst); - if !stdin.is_null() { - unsafe { - libc::fflush(stdin); - } - } - drain_stdin_pipe(); + windows_stdin::discard_pending_stdin(); } #[cfg(not(any(target_family = "unix", windows)))] fn discard_pending_stdin() {} -#[cfg(windows)] -fn drain_stdin_pipe() { - let handle = unsafe { GetStdHandle(STD_INPUT_HANDLE) }; - if handle.is_null() || handle == INVALID_HANDLE_VALUE { - return; - } - - let mut buffer = [0u8; 8192]; - loop { - let mut available = 0u32; - let ok = unsafe { - PeekNamedPipe( - handle, - ptr::null_mut(), - 0, - ptr::null_mut(), - &mut available, - ptr::null_mut(), - ) - }; - if ok == 0 || available == 0 { - break; - } - - let to_read = available.min(buffer.len() as u32); - let mut read = 0u32; - let ok = unsafe { - ReadFile( - handle, - buffer.as_mut_ptr().cast(), - to_read, - &mut read, - ptr::null_mut(), - ) - }; - if ok == 0 || read == 0 { - break; - } - } -} - fn run_session_on_current_thread(init: Arc) -> Result<(), String> { crate::diagnostics::startup_log("python-session: init begin"); let state = Arc::new(SessionState::new()); @@ -388,119 +321,6 @@ fn run_session_on_current_thread(init: Arc) -> Result<(), String> { Ok(()) } -fn open_python_runtime() -> Result { - #[cfg(target_family = "unix")] - { - open_python_runtime_with_pty_stdio() - } - - #[cfg(not(target_family = "unix"))] - { - let stdin = open_stdio_file(0, c"r")?; - set_stdio_unbuffered(stdin, 0)?; - let stdout = open_stdio_file(1, c"w")?; - PYTHON_STDIN_FILE.store(stdin, Ordering::SeqCst); - PYTHON_STDOUT_FILE.store(stdout, Ordering::SeqCst); - Ok(PythonRuntime { stdin }) - } -} - -#[cfg(target_family = "unix")] -fn open_python_runtime_with_pty_stdio() -> Result { - ensure_python_pty_stdio()?; - set_fd_close_on_exec(libc::STDIN_FILENO)?; - - let runtime_read_fd = duplicate_stdio_fd(libc::STDIN_FILENO)?; - set_fd_close_on_exec(runtime_read_fd)?; - let stdin = open_stdio_fd(runtime_read_fd, c"r")?; - set_stdio_unbuffered(stdin, runtime_read_fd)?; - let stdout = open_stdio_file(1, c"w")?; - unix_stdin::set_runtime_stdin_fd(runtime_read_fd); - PYTHON_STDIN_FILE.store(stdin, Ordering::SeqCst); - PYTHON_STDOUT_FILE.store(stdout, Ordering::SeqCst); - Ok(PythonRuntime { stdin }) -} - -#[cfg(target_family = "unix")] -fn ensure_python_pty_stdio() -> Result<(), String> { - let missing = [ - (libc::STDIN_FILENO, "stdin"), - (libc::STDOUT_FILENO, "stdout"), - (libc::STDERR_FILENO, "stderr"), - ] - .into_iter() - .filter_map(|(fd, label)| (!stdio_fd_is_tty(fd)).then_some(label)) - .collect::>(); - if missing.is_empty() { - return Ok(()); - } - Err(format!( - "Python PTY stdin transport requires TTY-backed C stdio; non-TTY fds: {}", - missing.join(", ") - )) -} - -#[cfg(target_family = "unix")] -fn stdio_fd_is_tty(fd: libc::c_int) -> bool { - unsafe { libc::isatty(fd) == 1 } -} - -#[cfg(target_family = "unix")] -fn duplicate_stdio_fd(fd: libc::c_int) -> Result { - let duplicated = unsafe { libc::dup(fd) }; - if duplicated < 0 { - Err(format!( - "failed to duplicate worker fd {fd}: {}", - std::io::Error::last_os_error() - )) - } else { - Ok(duplicated) - } -} - -#[cfg(target_family = "unix")] -fn set_fd_close_on_exec(fd: RawFd) -> Result<(), String> { - let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) }; - if flags < 0 { - return Err(format!( - "failed to read fd {fd} close-on-exec flags: {}", - std::io::Error::last_os_error() - )); - } - if unsafe { libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) } < 0 { - return Err(format!( - "failed to set fd {fd} close-on-exec: {}", - std::io::Error::last_os_error() - )); - } - Ok(()) -} - -fn open_stdio_file(fd: libc::c_int, mode: &CStr) -> Result<*mut libc::FILE, String> { - open_stdio_fd(fd, mode) -} - -fn open_stdio_fd(fd: libc::c_int, mode: &CStr) -> Result<*mut libc::FILE, String> { - let file = unsafe { libc::fdopen(fd, mode.as_ptr()) }; - if file.is_null() { - Err(format!( - "failed to open worker fd {fd} as C stdio FILE: {}", - std::io::Error::last_os_error() - )) - } else { - Ok(file) - } -} - -fn set_stdio_unbuffered(file: *mut libc::FILE, fd: libc::c_int) -> Result<(), String> { - let rc = unsafe { libc::setvbuf(file, ptr::null_mut(), libc::_IONBF, 0) }; - if rc == 0 { - Ok(()) - } else { - Err(format!("failed to configure worker fd {fd} as unbuffered")) - } -} - fn initialize_python( api: &'static PythonApi, executable: &Path, @@ -593,113 +413,6 @@ fn finalize_python( } } -#[cfg(windows)] -fn begin_tracked_turn(turn_id: u64, input: String) -> Result<(), String> { - let state = session_state(); - let queued_lines = prepare_turn_input_lines(&input); - let byte_len = queued_lines - .iter() - .map(|line| line.bytes.len().saturating_sub(line.offset)) - .sum(); - let line_count = queued_lines.len(); - let mut guard = state.inner.lock().unwrap(); - if guard.shutdown { - return Err("Python session is shutting down".to_string()); - } - if guard.active_request.is_some() { - return Err("Python session already has an active turn".to_string()); - } - guard.interrupt_requested = false; - guard.request_completed_at_stdin_wait = false; - guard.request_active = true; - guard.plot_reset_pending = true; - guard.turn_write_in_flight = false; - guard.turn_cleanup_uncertain = false; - let started_after_continuation_prompt = guard.last_prompt_was_continuation; - guard.active_request = Some(ActiveRequest { - turn_id: Some(turn_id), - byte_len, - line_count, - fallback_prompt: None, - queued_lines, - consumed_lines: 0, - skip_next_hook: false, - stdin_write_complete: true, - repl_turn_finished: false, - started_after_continuation_prompt, - }); - state.cvar.notify_all(); - Ok(()) -} - -#[cfg(windows)] -fn append_tracked_turn_input(turn_id: u64, input: String) -> Result<(), String> { - let state = session_state(); - let queued_lines = prepare_turn_input_lines(&input); - let byte_len = queued_lines - .iter() - .map(|line| line.bytes.len().saturating_sub(line.offset)) - .sum(); - let line_count = queued_lines.len(); - let mut guard = state.inner.lock().unwrap(); - if guard.shutdown { - return Err("Python session is shutting down".to_string()); - } - if let Some(active) = guard.active_request.as_mut() { - if active.turn_id != Some(turn_id) { - return Err(format!( - "turn_input turn_id {turn_id} does not match active turn_id {:?}", - active.turn_id - )); - } - active.byte_len = active.byte_len.saturating_add(byte_len); - active.line_count = active.line_count.saturating_add(line_count); - active.queued_lines.extend(queued_lines); - } else { - guard.interrupt_requested = false; - guard.request_completed_at_stdin_wait = false; - guard.request_active = true; - guard.plot_reset_pending = true; - guard.turn_write_in_flight = false; - guard.turn_cleanup_uncertain = false; - let started_after_continuation_prompt = guard.last_prompt_was_continuation; - guard.active_request = Some(ActiveRequest { - turn_id: Some(turn_id), - byte_len, - line_count, - fallback_prompt: None, - queued_lines, - consumed_lines: 0, - skip_next_hook: false, - stdin_write_complete: true, - repl_turn_finished: false, - started_after_continuation_prompt, - }); - } - state.cvar.notify_all(); - Ok(()) -} - -#[cfg(windows)] -fn prepare_turn_input_lines(input: &str) -> VecDeque { - if input.is_empty() { - return VecDeque::new(); - } - let mut input = input.to_string(); - if !input.ends_with('\n') { - input.push('\n'); - } - input - .split_inclusive('\n') - .map(|line| TurnInputLine { - text: line.to_string(), - bytes: line.as_bytes().to_vec(), - offset: 0, - input_line_emitted: false, - }) - .collect() -} - fn set_python_prompts(primary: String, continuation: String) { let Some(state) = SESSION_STATE.get() else { return; @@ -709,162 +422,6 @@ fn set_python_prompts(primary: String, continuation: String) { guard.python_continuation_prompt = continuation; } -fn repl_prompt_for( - current_prompt: Option, - fallback_prompt: Option<&str>, - readline_state: Option, - primary_prompt: &str, - continuation_prompt: &str, -) -> String { - if let Some(prompt) = current_prompt { - return prompt; - } - if fallback_prompt.is_some() - || matches!(readline_state, Some(PythonReadlineState::Continuation)) - { - return continuation_prompt.to_string(); - } - primary_prompt.to_string() -} - -#[cfg_attr(target_family = "unix", allow(dead_code))] -fn input_hook_prompt(guard: &SessionStateInner, fallback_prompt: Option<&str>) -> String { - repl_prompt_for( - guard.current_prompt.clone(), - fallback_prompt, - guard.current_readline_state, - &guard.python_primary_prompt, - &guard.python_continuation_prompt, - ) -} - -#[cfg_attr(not(windows), allow(dead_code))] -fn emit_turn_input_line(turn_id: u64, prompt: &str, line: &mut TurnInputLine) { - if line.input_line_emitted { - return; - } - ipc::emit_input_line(turn_id, prompt, &line.text); - line.input_line_emitted = true; -} - -#[cfg(windows)] -fn read_windows_turn_line( - prompt: &str, - emit_prompt_to_stdout: bool, - release_gil_while_waiting: bool, -) -> Result { - let state = SESSION_STATE - .get() - .ok_or_else(|| "Python session state is not initialized".to_string())?; - let mut idle_repl_prompt_emitted = false; - - loop { - let action = { - let mut guard = state.inner.lock().unwrap(); - guard.waiting_for_input = true; - state.cvar.notify_all(); - - if guard.shutdown || guard.exit_requested { - return Ok(StdioLineRead { - bytes: Vec::new(), - interrupted: false, - }); - } - - if guard.interrupt_requested { - guard.interrupt_requested = false; - return Ok(StdioLineRead { - bytes: Vec::new(), - interrupted: true, - }); - } - - if guard.turn_cleanup_uncertain { - if release_gil_while_waiting { - let allow_threads = PythonThreadsAllowed::new(); - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - drop(allow_threads); - } else { - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - } - continue; - } - - match guard - .active_request - .as_mut() - .and_then(|active| active.turn_id) - { - Some(turn_id) => { - let line = { - let active = guard.active_request.as_mut().expect("active turn exists"); - let line = active.queued_lines.pop_front(); - if line.is_some() { - active.consumed_lines = active.consumed_lines.saturating_add(1); - } - line - }; - if let Some(line) = line { - let prompt_already_visible = - guard.visible_input_prompt.as_deref() == Some(prompt); - guard.visible_input_prompt = None; - guard.waiting_for_input = false; - guard.request_active = true; - Some((turn_id, Some(line), prompt_already_visible)) - } else { - guard.active_request.take(); - guard.visible_input_prompt = Some(prompt.to_string()); - Some((turn_id, None, false)) - } - } - None => { - let should_emit_idle_repl_prompt = !idle_repl_prompt_emitted - && (prompt == guard.python_primary_prompt - || prompt == guard.python_continuation_prompt); - if should_emit_idle_repl_prompt { - idle_repl_prompt_emitted = true; - guard.last_prompt_was_continuation = - prompt == guard.python_continuation_prompt; - drop(guard); - ipc::emit_readline_start(prompt); - continue; - } - if release_gil_while_waiting { - let allow_threads = PythonThreadsAllowed::new(); - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - drop(allow_threads); - } else { - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - } - continue; - } - } - }; - - if let Some((turn_id, Some(mut line), prompt_already_visible)) = action { - emit_turn_input_line(turn_id, prompt, &mut line); - if emit_prompt_to_stdout && !prompt.is_empty() && !prompt_already_visible { - emit_output_text(TextStream::Stdout, prompt.as_bytes()); - } - return Ok(StdioLineRead { - bytes: line.bytes[line.offset..].to_vec(), - interrupted: false, - }); - } - - if let Some((turn_id, None, _)) = action { - emit_plots(); - mark_stdin_wait_prompt_completed_request(); - remember_emitted_prompt(prompt); - ipc::emit_idle(turn_id, prompt); - } - } -} - fn handle_input_hook() { #[cfg(target_family = "unix")] { @@ -989,20 +546,6 @@ fn request_prompt_wait_should_complete( active.consumed_lines >= active.line_count } -#[cfg(target_family = "unix")] -#[cfg_attr(not(test), allow(dead_code))] -fn prompt_wait_can_complete( - active: &ActiveRequest, - current_readline_state: Option, -) -> bool { - active.consumed_lines >= active.line_count - || matches!( - current_readline_state, - Some(PythonReadlineState::ClientInput | PythonReadlineState::Continuation) - ) - || active.fallback_prompt.is_some() -} - fn request_repl_turn_should_complete(active: &ActiveRequest) -> bool { #[cfg(target_family = "unix")] { @@ -1023,7 +566,7 @@ fn request_input_drained(active: &ActiveRequest) -> bool { if !active.stdin_write_complete || active.byte_len == 0 { return false; } - stdin_pending_byte_count() == Some(0) + unix_stdin::stdin_pending_byte_count() == Some(0) } fn finish_repl_turn_request() { @@ -1086,36 +629,9 @@ fn finish_repl_turn_request() { } } -#[cfg(target_family = "unix")] -fn stdin_pending_byte_count() -> Option { - let mut count: libc::c_int = 0; - let rc = unsafe { libc::ioctl(libc::STDIN_FILENO, libc::FIONREAD, &mut count) }; - if rc == 0 && count >= 0 { - Some(count as usize) - } else { - None - } -} - #[cfg(windows)] fn stdin_pending_byte_count() -> Option { - let handle = unsafe { GetStdHandle(STD_INPUT_HANDLE) }; - if handle.is_null() || handle == INVALID_HANDLE_VALUE { - return None; - } - - let mut available = 0u32; - let ok = unsafe { - PeekNamedPipe( - handle, - ptr::null_mut(), - 0, - ptr::null_mut(), - &mut available, - ptr::null_mut(), - ) - }; - (ok != 0).then_some(available as usize) + windows_stdin::stdin_pending_byte_count() } #[cfg(not(any(target_family = "unix", windows)))] @@ -1167,7 +683,7 @@ unsafe extern "C" fn mcp_repl_readline( #[cfg(windows)] flush_original_stdio(); #[cfg(windows)] - let read = match read_windows_turn_line( + let read = match windows_stdin::read_windows_turn_line( &prompt_text, !prompt_text.is_empty() && !suppress_repl_prompt_echo, false, @@ -1251,163 +767,12 @@ fn note_cpython_readline_bytes_read( note_stdin_line_read("", bytes) } -struct StdioLineRead { - bytes: Vec, - interrupted: bool, -} - -#[cfg(not(windows))] -fn read_stdio_line_bytes(stdin: *mut libc::FILE) -> StdioLineRead { - let mut bytes = Vec::new(); - loop { - let ch = unsafe { libc::fgetc(stdin) }; - if ch == libc::EOF { - let interrupted = unsafe { libc::ferror(stdin) != 0 }; - if interrupted { - unsafe { clear_stdio_error(stdin) }; - } - return StdioLineRead { bytes, interrupted }; - } - bytes.push(ch as u8); - if ch == b'\n' as i32 { - return StdioLineRead { - bytes, - interrupted: false, - }; - } - } -} - -#[cfg(not(windows))] -unsafe fn clear_stdio_error(stdin: *mut libc::FILE) { - unsafe { libc::clearerr(stdin) }; -} - -#[cfg(not(windows))] -fn read_stdio_line_bytes_allowing_python_threads(stdin: *mut libc::FILE) -> StdioLineRead { - // _mcp_repl.readline is called from Python with the GIL held. Release it - // while stdin blocks so the IPC completion path can flush prompt-time plots. - let _allow_threads = PythonThreadsAllowed::new(); - read_stdio_line_bytes(stdin) -} - -struct PythonThreadsAllowed { - api: &'static PythonApi, - thread_state: *mut PyThreadState, -} - -impl PythonThreadsAllowed { - fn new() -> Self { - let api = PythonApi::global(); - let thread_state = unsafe { (api.py_eval_save_thread)() }; - assert!( - !thread_state.is_null(), - "PyEval_SaveThread returned a null thread state" - ); - Self { api, thread_state } - } -} - -impl Drop for PythonThreadsAllowed { - fn drop(&mut self) { - unsafe { (self.api.py_eval_restore_thread)(self.thread_state) }; - } -} - -#[derive(Clone, Copy, PartialEq, Eq)] -enum PythonReadlineState { - Primary, - Continuation, - ClientInput, -} - -fn begin_repl_turn() { - let Some(state) = SESSION_STATE.get() else { - return; - }; - let mut guard = state.inner.lock().unwrap(); - guard.repl_readline_count = 0; -} - -fn set_current_repl_readline_prompt(prompt: &str) -> PythonReadlineState { - let Some(state) = SESSION_STATE.get() else { - return PythonReadlineState::Primary; - }; - let mut guard = state.inner.lock().unwrap(); - let started_after_continuation_prompt = guard - .active_request - .as_ref() - .is_some_and(|active| active.started_after_continuation_prompt); - let readline_state = if prompt == guard.python_continuation_prompt - || (prompt.is_empty() && started_after_continuation_prompt) - { - PythonReadlineState::Continuation - } else if guard.repl_readline_count > 0 { - PythonReadlineState::ClientInput - } else { - PythonReadlineState::Primary - }; - guard.repl_readline_count = guard.repl_readline_count.saturating_add(1); - guard.current_prompt = if prompt.is_empty() { - None - } else { - Some(prompt.to_string()) - }; - guard.current_readline_state = Some(readline_state); - readline_state -} - -fn remember_emitted_prompt(prompt: &str) { - let Some(state) = SESSION_STATE.get() else { - return; - }; - let mut guard = state.inner.lock().unwrap(); - guard.last_prompt_was_continuation = prompt == guard.python_continuation_prompt; -} - -fn set_current_readline_prompt(prompt: &str, readline_state: PythonReadlineState) { - let Some(state) = SESSION_STATE.get() else { - return; - }; - let mut guard = state.inner.lock().unwrap(); - guard.current_prompt = Some(prompt.to_string()); - guard.current_readline_state = Some(readline_state); -} - -fn clear_current_readline_prompt() { - let Some(state) = SESSION_STATE.get() else { - return; - }; - let mut guard = state.inner.lock().unwrap(); - guard.current_prompt = None; - guard.current_readline_state = None; -} - enum CStdinLine { Line(String), Eof, Error, } -enum StdinReadAccounting { - Accounted, - #[cfg(target_family = "unix")] - DiscardedAfterInterrupt, -} - -impl StdinReadAccounting { - fn discarded_after_interrupt(&self) -> bool { - #[cfg(target_family = "unix")] - { - matches!(self, Self::DiscardedAfterInterrupt) - } - #[cfg(not(target_family = "unix"))] - { - false - } - } -} - fn read_c_stdin_line(prompt: &str) -> CStdinLine { #[cfg(target_family = "unix")] if ipc::worker_ipc_disabled_for_process() { @@ -1449,7 +814,7 @@ fn read_c_stdin_line(prompt: &str) -> CStdinLine { #[cfg(windows)] flush_original_stdio(); #[cfg(windows)] - let read = match read_windows_turn_line( + let read = match windows_stdin::read_windows_turn_line( prompt_for_sideband.to_str().unwrap_or(""), !prompt.is_empty(), true, @@ -1496,134 +861,9 @@ fn read_raw_stdin_bytes(size: usize) -> Result, RawStdinReadError> { unix_stdin::read_raw_stdin_bytes(size) } -#[cfg_attr(not(windows), allow(dead_code))] -enum RawStdinReadError { - Interrupted, - Runtime(String), -} - #[cfg(windows)] fn read_raw_stdin_bytes(size: usize) -> Result, RawStdinReadError> { - if size == 0 { - return Ok(Vec::new()); - } - take_raw_turn_input_bytes(size) -} - -#[cfg(windows)] -enum RawTurnInputEvent { - InputLine { - turn_id: u64, - prompt: String, - text: String, - }, - Idle { - turn_id: u64, - prompt: String, - }, - Consumed, -} - -#[cfg(windows)] -fn take_raw_turn_input_bytes(size: usize) -> Result, RawStdinReadError> { - let state = SESSION_STATE.get().ok_or_else(|| { - RawStdinReadError::Runtime("Python session state is not initialized".to_string()) - })?; - let mut output = Vec::new(); - while output.len() < size { - let event = { - let mut guard = state.inner.lock().unwrap(); - guard.waiting_for_input = true; - state.cvar.notify_all(); - - if guard.shutdown || guard.exit_requested { - return Ok(output); - } - - if guard.interrupt_requested { - guard.interrupt_requested = false; - return Err(RawStdinReadError::Interrupted); - } - - if guard.turn_cleanup_uncertain { - let allow_threads = PythonThreadsAllowed::new(); - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - drop(allow_threads); - continue; - } - - let prompt = input_hook_prompt(&guard, None); - let Some(turn_id) = guard - .active_request - .as_ref() - .and_then(|active| active.turn_id) - else { - if !output.is_empty() { - return Ok(output); - } - if guard.active_request.is_some() { - return Ok(output); - } - let allow_threads = PythonThreadsAllowed::new(); - guard = state.cvar.wait(guard).unwrap(); - drop(guard); - drop(allow_threads); - continue; - }; - - if guard - .active_request - .as_ref() - .is_some_and(|active| active.queued_lines.is_empty()) - { - if !output.is_empty() { - return Ok(output); - } - guard.active_request.take(); - RawTurnInputEvent::Idle { turn_id, prompt } - } else { - let active = guard.active_request.as_mut().expect("active turn exists"); - let line = active - .queued_lines - .front_mut() - .expect("active turn has queued input"); - let event = if line.input_line_emitted { - RawTurnInputEvent::Consumed - } else { - line.input_line_emitted = true; - RawTurnInputEvent::InputLine { - turn_id, - prompt, - text: line.text.clone(), - } - }; - let available = &line.bytes[line.offset..]; - let take = available.len().min(size - output.len()); - output.extend_from_slice(&available[..take]); - line.offset += take; - if line.offset >= line.bytes.len() { - active.queued_lines.pop_front(); - } - event - } - }; - match event { - RawTurnInputEvent::InputLine { - turn_id, - prompt, - text, - } => ipc::emit_input_line(turn_id, &prompt, &text), - RawTurnInputEvent::Idle { turn_id, prompt } => { - emit_plots(); - mark_stdin_wait_prompt_completed_request(); - remember_emitted_prompt(&prompt); - ipc::emit_idle(turn_id, &prompt); - } - RawTurnInputEvent::Consumed => {} - } - } - Ok(output) + windows_stdin::read_raw_stdin_bytes(size) } #[cfg(not(any(target_family = "unix", windows)))] @@ -1697,14 +937,6 @@ fn record_background_plots() { } } -fn request_active() -> bool { - let Some(state) = SESSION_STATE.get() else { - return false; - }; - let guard = state.inner.lock().unwrap(); - guard.request_active && !guard.request_completed_at_stdin_wait -} - fn flush_original_stdio() { { let _gil = GilGuard::acquire(); @@ -1735,99 +967,6 @@ fn flush_original_stdio() { } } -struct SessionState { - inner: Mutex, - cvar: Condvar, -} - -struct SessionStateInner { - active_request: Option, - request_active: bool, - request_completed_at_stdin_wait: bool, - current_prompt: Option, - current_readline_state: Option, - #[cfg(windows)] - visible_input_prompt: Option, - python_primary_prompt: String, - python_continuation_prompt: String, - repl_readline_count: usize, - last_prompt_was_continuation: bool, - waiting_for_input: bool, - exit_requested: bool, - shutdown: bool, - session_end_emitted: bool, - plot_reset_pending: bool, - interrupt_requested: bool, - #[cfg_attr(not(windows), allow(dead_code))] - turn_write_in_flight: bool, - #[cfg_attr(not(windows), allow(dead_code))] - turn_cleanup_uncertain: bool, - #[cfg(target_family = "unix")] - turn_input: PythonTurnInput, -} - -#[allow(dead_code)] -struct ActiveRequest { - turn_id: Option, - byte_len: usize, - line_count: usize, - fallback_prompt: Option, - queued_lines: VecDeque, - consumed_lines: usize, - skip_next_hook: bool, - stdin_write_complete: bool, - repl_turn_finished: bool, - started_after_continuation_prompt: bool, -} - -struct TurnInputLine { - #[cfg_attr(not(windows), allow(dead_code))] - text: String, - #[cfg_attr(not(windows), allow(dead_code))] - bytes: Vec, - #[cfg_attr(not(windows), allow(dead_code))] - offset: usize, - #[cfg_attr(not(windows), allow(dead_code))] - input_line_emitted: bool, -} - -impl SessionState { - fn new() -> Self { - Self { - inner: Mutex::new(SessionStateInner { - active_request: None, - request_active: false, - request_completed_at_stdin_wait: false, - current_prompt: None, - current_readline_state: None, - #[cfg(windows)] - visible_input_prompt: None, - python_primary_prompt: ">>> ".to_string(), - python_continuation_prompt: "... ".to_string(), - repl_readline_count: 0, - last_prompt_was_continuation: false, - waiting_for_input: false, - exit_requested: false, - shutdown: false, - session_end_emitted: false, - plot_reset_pending: false, - interrupt_requested: false, - turn_write_in_flight: false, - turn_cleanup_uncertain: false, - #[cfg(target_family = "unix")] - turn_input: PythonTurnInput::new(), - }), - cvar: Condvar::new(), - } - } -} - -fn session_state() -> &'static Arc { - SESSION_STATE - .get() - .expect("Python session state was not initialized") -} - fn complete_active_request_with_options( state: &Arc, active: Option, @@ -1875,21 +1014,6 @@ fn emit_output_text(stream: TextStream, bytes: &[u8]) { } } -fn mark_stdin_wait_prompt_completed_request() { - let Some(state) = SESSION_STATE.get() else { - return; - }; - let mut guard = state.inner.lock().unwrap(); - // An input()/sys.stdin.readline() prompt with no buffered answer is the - // response boundary for the current MCP request. The Python read can then - // block while background Python threads keep running. Clear the plot gate at - // this boundary to prevent those background updates from being attributed to - // the request that already completed. Callers flush prompt-time plots before - // closing this gate. - guard.request_active = false; - guard.request_completed_at_stdin_wait = true; -} - unsafe extern "C" fn initialize_mcp_repl_module() -> *mut PyObject { let api = PythonApi::global(); let methods = [ @@ -2137,75 +1261,5 @@ fn set_callback_error(message: &str) { PythonApi::global().set_runtime_error(exception, message); } -static SESSION_STATE: OnceLock> = OnceLock::new(); static SESSION: OnceLock = OnceLock::new(); static RUNTIME_ERROR: AtomicPtr = AtomicPtr::new(ptr::null_mut()); -static PYTHON_STDIN_FILE: AtomicPtr = AtomicPtr::new(ptr::null_mut()); -static PYTHON_STDOUT_FILE: AtomicPtr = AtomicPtr::new(ptr::null_mut()); - -#[cfg(test)] -mod tests { - #[cfg(target_family = "unix")] - use super::*; - - #[cfg(target_family = "unix")] - fn active_request_for_prompt_wait( - line_count: usize, - consumed_lines: usize, - fallback_prompt: Option<&str>, - ) -> ActiveRequest { - ActiveRequest { - turn_id: None, - byte_len: 1, - line_count, - fallback_prompt: fallback_prompt.map(str::to_string), - queued_lines: VecDeque::new(), - consumed_lines, - skip_next_hook: false, - stdin_write_complete: true, - repl_turn_finished: false, - started_after_continuation_prompt: false, - } - } - - #[cfg(target_family = "unix")] - #[test] - fn unix_prompt_wait_requires_progress_for_primary_prompt() { - let active = active_request_for_prompt_wait(3, 1, None); - - assert!(!prompt_wait_can_complete(&active, None)); - } - - #[cfg(target_family = "unix")] - #[test] - fn unix_prompt_wait_allows_client_input_prompt() { - let active = active_request_for_prompt_wait(1, 0, None); - - assert!(prompt_wait_can_complete( - &active, - Some(PythonReadlineState::ClientInput) - )); - } - - #[cfg(target_family = "unix")] - #[test] - fn unix_prompt_wait_allows_continuation_prompt() { - let active = active_request_for_prompt_wait(2, 1, None); - - assert!(prompt_wait_can_complete( - &active, - Some(PythonReadlineState::Continuation) - )); - } - - #[cfg(target_family = "unix")] - #[test] - fn unix_prompt_wait_requires_progress_for_custom_primary_prompt() { - let active = active_request_for_prompt_wait(1, 0, None); - - assert!(!prompt_wait_can_complete( - &active, - Some(PythonReadlineState::Primary) - )); - } -} diff --git a/src/python_session/state.rs b/src/python_session/state.rs new file mode 100644 index 00000000..01abb959 --- /dev/null +++ b/src/python_session/state.rs @@ -0,0 +1,330 @@ +use std::collections::VecDeque; +use std::sync::{Arc, Condvar, Mutex, OnceLock}; + +#[cfg(target_family = "unix")] +use crate::python_turn_input::PythonTurnInput; + +pub(super) static SESSION_STATE: OnceLock> = OnceLock::new(); + +pub(super) struct SessionState { + pub(super) inner: Mutex, + pub(super) cvar: Condvar, +} + +pub(super) struct SessionStateInner { + pub(super) active_request: Option, + pub(super) request_active: bool, + pub(super) request_completed_at_stdin_wait: bool, + pub(super) current_prompt: Option, + pub(super) current_readline_state: Option, + #[cfg(windows)] + pub(super) visible_input_prompt: Option, + pub(super) python_primary_prompt: String, + pub(super) python_continuation_prompt: String, + pub(super) repl_readline_count: usize, + pub(super) last_prompt_was_continuation: bool, + pub(super) waiting_for_input: bool, + pub(super) exit_requested: bool, + pub(super) shutdown: bool, + pub(super) session_end_emitted: bool, + pub(super) plot_reset_pending: bool, + pub(super) interrupt_requested: bool, + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) turn_write_in_flight: bool, + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) turn_cleanup_uncertain: bool, + #[cfg(target_family = "unix")] + pub(super) turn_input: PythonTurnInput, +} + +#[allow(dead_code)] +pub(super) struct ActiveRequest { + pub(super) turn_id: Option, + pub(super) byte_len: usize, + pub(super) line_count: usize, + pub(super) fallback_prompt: Option, + pub(super) queued_lines: VecDeque, + pub(super) consumed_lines: usize, + pub(super) skip_next_hook: bool, + pub(super) stdin_write_complete: bool, + pub(super) repl_turn_finished: bool, + pub(super) started_after_continuation_prompt: bool, +} + +pub(super) struct TurnInputLine { + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) text: String, + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) bytes: Vec, + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) offset: usize, + #[cfg_attr(not(windows), allow(dead_code))] + pub(super) input_line_emitted: bool, +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(super) enum PythonReadlineState { + Primary, + Continuation, + ClientInput, +} + +pub(super) enum StdinReadAccounting { + Accounted, + #[cfg(target_family = "unix")] + DiscardedAfterInterrupt, +} + +impl StdinReadAccounting { + pub(super) fn discarded_after_interrupt(&self) -> bool { + #[cfg(target_family = "unix")] + { + matches!(self, Self::DiscardedAfterInterrupt) + } + #[cfg(not(target_family = "unix"))] + { + false + } + } +} + +#[cfg_attr(not(windows), allow(dead_code))] +pub(super) enum RawStdinReadError { + Interrupted, + Runtime(String), +} + +impl SessionState { + pub(super) fn new() -> Self { + Self { + inner: Mutex::new(SessionStateInner { + active_request: None, + request_active: false, + request_completed_at_stdin_wait: false, + current_prompt: None, + current_readline_state: None, + #[cfg(windows)] + visible_input_prompt: None, + python_primary_prompt: ">>> ".to_string(), + python_continuation_prompt: "... ".to_string(), + repl_readline_count: 0, + last_prompt_was_continuation: false, + waiting_for_input: false, + exit_requested: false, + shutdown: false, + session_end_emitted: false, + plot_reset_pending: false, + interrupt_requested: false, + turn_write_in_flight: false, + turn_cleanup_uncertain: false, + #[cfg(target_family = "unix")] + turn_input: PythonTurnInput::new(), + }), + cvar: Condvar::new(), + } + } +} + +pub(super) fn session_state() -> &'static Arc { + SESSION_STATE + .get() + .expect("Python session state was not initialized") +} + +pub(super) fn repl_prompt_for( + current_prompt: Option, + fallback_prompt: Option<&str>, + readline_state: Option, + primary_prompt: &str, + continuation_prompt: &str, +) -> String { + if let Some(prompt) = current_prompt { + return prompt; + } + if fallback_prompt.is_some() + || matches!(readline_state, Some(PythonReadlineState::Continuation)) + { + return continuation_prompt.to_string(); + } + primary_prompt.to_string() +} + +#[cfg_attr(target_family = "unix", allow(dead_code))] +pub(super) fn input_hook_prompt( + guard: &SessionStateInner, + fallback_prompt: Option<&str>, +) -> String { + repl_prompt_for( + guard.current_prompt.clone(), + fallback_prompt, + guard.current_readline_state, + &guard.python_primary_prompt, + &guard.python_continuation_prompt, + ) +} + +pub(super) fn begin_repl_turn() { + let Some(state) = SESSION_STATE.get() else { + return; + }; + let mut guard = state.inner.lock().unwrap(); + guard.repl_readline_count = 0; +} + +pub(super) fn set_current_repl_readline_prompt(prompt: &str) -> PythonReadlineState { + let Some(state) = SESSION_STATE.get() else { + return PythonReadlineState::Primary; + }; + let mut guard = state.inner.lock().unwrap(); + let started_after_continuation_prompt = guard + .active_request + .as_ref() + .is_some_and(|active| active.started_after_continuation_prompt); + let readline_state = if prompt == guard.python_continuation_prompt + || (prompt.is_empty() && started_after_continuation_prompt) + { + PythonReadlineState::Continuation + } else if guard.repl_readline_count > 0 { + PythonReadlineState::ClientInput + } else { + PythonReadlineState::Primary + }; + guard.repl_readline_count = guard.repl_readline_count.saturating_add(1); + guard.current_prompt = if prompt.is_empty() { + None + } else { + Some(prompt.to_string()) + }; + guard.current_readline_state = Some(readline_state); + readline_state +} + +pub(super) fn set_current_readline_prompt(prompt: &str, readline_state: PythonReadlineState) { + let Some(state) = SESSION_STATE.get() else { + return; + }; + let mut guard = state.inner.lock().unwrap(); + guard.current_prompt = Some(prompt.to_string()); + guard.current_readline_state = Some(readline_state); +} + +pub(super) fn clear_current_readline_prompt() { + let Some(state) = SESSION_STATE.get() else { + return; + }; + let mut guard = state.inner.lock().unwrap(); + guard.current_prompt = None; + guard.current_readline_state = None; +} + +pub(super) fn remember_emitted_prompt(prompt: &str) { + let Some(state) = SESSION_STATE.get() else { + return; + }; + let mut guard = state.inner.lock().unwrap(); + guard.last_prompt_was_continuation = prompt == guard.python_continuation_prompt; +} + +pub(super) fn mark_stdin_wait_prompt_completed_request() { + let Some(state) = SESSION_STATE.get() else { + return; + }; + let mut guard = state.inner.lock().unwrap(); + // An input()/sys.stdin.readline() prompt with no buffered answer is the + // response boundary for the current MCP request. The Python read can then + // block while background Python threads keep running. Clear the plot gate at + // this boundary to prevent those background updates from being attributed to + // the request that already completed. Callers flush prompt-time plots before + // closing this gate. + guard.request_active = false; + guard.request_completed_at_stdin_wait = true; +} + +pub(super) fn request_active() -> bool { + let Some(state) = SESSION_STATE.get() else { + return false; + }; + let guard = state.inner.lock().unwrap(); + guard.request_active && !guard.request_completed_at_stdin_wait +} + +#[cfg(target_family = "unix")] +#[cfg_attr(not(test), allow(dead_code))] +fn prompt_wait_can_complete( + active: &ActiveRequest, + current_readline_state: Option, +) -> bool { + active.consumed_lines >= active.line_count + || matches!( + current_readline_state, + Some(PythonReadlineState::ClientInput | PythonReadlineState::Continuation) + ) + || active.fallback_prompt.is_some() +} + +#[cfg(test)] +mod tests { + #[cfg(target_family = "unix")] + use super::*; + + #[cfg(target_family = "unix")] + fn active_request_for_prompt_wait( + line_count: usize, + consumed_lines: usize, + fallback_prompt: Option<&str>, + ) -> ActiveRequest { + ActiveRequest { + turn_id: None, + byte_len: 1, + line_count, + fallback_prompt: fallback_prompt.map(str::to_string), + queued_lines: VecDeque::new(), + consumed_lines, + skip_next_hook: false, + stdin_write_complete: true, + repl_turn_finished: false, + started_after_continuation_prompt: false, + } + } + + #[cfg(target_family = "unix")] + #[test] + fn unix_prompt_wait_requires_progress_for_primary_prompt() { + let active = active_request_for_prompt_wait(3, 1, None); + + assert!(!prompt_wait_can_complete(&active, None)); + } + + #[cfg(target_family = "unix")] + #[test] + fn unix_prompt_wait_allows_client_input_prompt() { + let active = active_request_for_prompt_wait(1, 0, None); + + assert!(prompt_wait_can_complete( + &active, + Some(PythonReadlineState::ClientInput) + )); + } + + #[cfg(target_family = "unix")] + #[test] + fn unix_prompt_wait_allows_continuation_prompt() { + let active = active_request_for_prompt_wait(2, 1, None); + + assert!(prompt_wait_can_complete( + &active, + Some(PythonReadlineState::Continuation) + )); + } + + #[cfg(target_family = "unix")] + #[test] + fn unix_prompt_wait_requires_progress_for_custom_primary_prompt() { + let active = active_request_for_prompt_wait(1, 0, None); + + assert!(!prompt_wait_can_complete( + &active, + Some(PythonReadlineState::Primary) + )); + } +} diff --git a/src/python_session/stdio.rs b/src/python_session/stdio.rs new file mode 100644 index 00000000..bc89c2b4 --- /dev/null +++ b/src/python_session/stdio.rs @@ -0,0 +1,196 @@ +use std::ffi::CStr; +#[cfg(target_family = "unix")] +use std::os::unix::io::RawFd; +use std::ptr; +use std::sync::atomic::{AtomicPtr, Ordering}; + +use crate::python_ffi::{PyThreadState, PythonApi}; + +#[cfg(target_family = "unix")] +use super::unix_stdin; + +pub(super) static PYTHON_STDIN_FILE: AtomicPtr = AtomicPtr::new(ptr::null_mut()); +pub(super) static PYTHON_STDOUT_FILE: AtomicPtr = AtomicPtr::new(ptr::null_mut()); + +pub(super) struct PythonRuntime { + #[cfg_attr(windows, allow(dead_code))] + pub(super) stdin: *mut libc::FILE, +} + +pub(super) fn open_python_runtime() -> Result { + #[cfg(target_family = "unix")] + { + open_python_runtime_with_pty_stdio() + } + + #[cfg(not(target_family = "unix"))] + { + let stdin = open_stdio_file(0, c"r")?; + set_stdio_unbuffered(stdin, 0)?; + let stdout = open_stdio_file(1, c"w")?; + PYTHON_STDIN_FILE.store(stdin, Ordering::SeqCst); + PYTHON_STDOUT_FILE.store(stdout, Ordering::SeqCst); + Ok(PythonRuntime { stdin }) + } +} + +#[cfg(target_family = "unix")] +fn open_python_runtime_with_pty_stdio() -> Result { + ensure_python_pty_stdio()?; + set_fd_close_on_exec(libc::STDIN_FILENO)?; + + let runtime_read_fd = duplicate_stdio_fd(libc::STDIN_FILENO)?; + set_fd_close_on_exec(runtime_read_fd)?; + let stdin = open_stdio_fd(runtime_read_fd, c"r")?; + set_stdio_unbuffered(stdin, runtime_read_fd)?; + let stdout = open_stdio_file(1, c"w")?; + unix_stdin::set_runtime_stdin_fd(runtime_read_fd); + PYTHON_STDIN_FILE.store(stdin, Ordering::SeqCst); + PYTHON_STDOUT_FILE.store(stdout, Ordering::SeqCst); + Ok(PythonRuntime { stdin }) +} + +#[cfg(target_family = "unix")] +fn ensure_python_pty_stdio() -> Result<(), String> { + let missing = [ + (libc::STDIN_FILENO, "stdin"), + (libc::STDOUT_FILENO, "stdout"), + (libc::STDERR_FILENO, "stderr"), + ] + .into_iter() + .filter_map(|(fd, label)| (!stdio_fd_is_tty(fd)).then_some(label)) + .collect::>(); + if missing.is_empty() { + return Ok(()); + } + Err(format!( + "Python PTY stdin transport requires TTY-backed C stdio; non-TTY fds: {}", + missing.join(", ") + )) +} + +#[cfg(target_family = "unix")] +fn stdio_fd_is_tty(fd: libc::c_int) -> bool { + unsafe { libc::isatty(fd) == 1 } +} + +#[cfg(target_family = "unix")] +fn duplicate_stdio_fd(fd: libc::c_int) -> Result { + let duplicated = unsafe { libc::dup(fd) }; + if duplicated < 0 { + Err(format!( + "failed to duplicate worker fd {fd}: {}", + std::io::Error::last_os_error() + )) + } else { + Ok(duplicated) + } +} + +#[cfg(target_family = "unix")] +fn set_fd_close_on_exec(fd: RawFd) -> Result<(), String> { + let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) }; + if flags < 0 { + return Err(format!( + "failed to read fd {fd} close-on-exec flags: {}", + std::io::Error::last_os_error() + )); + } + if unsafe { libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) } < 0 { + return Err(format!( + "failed to set fd {fd} close-on-exec: {}", + std::io::Error::last_os_error() + )); + } + Ok(()) +} + +fn open_stdio_file(fd: libc::c_int, mode: &CStr) -> Result<*mut libc::FILE, String> { + open_stdio_fd(fd, mode) +} + +fn open_stdio_fd(fd: libc::c_int, mode: &CStr) -> Result<*mut libc::FILE, String> { + let file = unsafe { libc::fdopen(fd, mode.as_ptr()) }; + if file.is_null() { + Err(format!( + "failed to open worker fd {fd} as C stdio FILE: {}", + std::io::Error::last_os_error() + )) + } else { + Ok(file) + } +} + +fn set_stdio_unbuffered(file: *mut libc::FILE, fd: libc::c_int) -> Result<(), String> { + let rc = unsafe { libc::setvbuf(file, ptr::null_mut(), libc::_IONBF, 0) }; + if rc == 0 { + Ok(()) + } else { + Err(format!("failed to configure worker fd {fd} as unbuffered")) + } +} + +pub(super) struct StdioLineRead { + pub(super) bytes: Vec, + pub(super) interrupted: bool, +} + +#[cfg(not(windows))] +pub(super) fn read_stdio_line_bytes(stdin: *mut libc::FILE) -> StdioLineRead { + let mut bytes = Vec::new(); + loop { + let ch = unsafe { libc::fgetc(stdin) }; + if ch == libc::EOF { + let interrupted = unsafe { libc::ferror(stdin) != 0 }; + if interrupted { + unsafe { clear_stdio_error(stdin) }; + } + return StdioLineRead { bytes, interrupted }; + } + bytes.push(ch as u8); + if ch == b'\n' as i32 { + return StdioLineRead { + bytes, + interrupted: false, + }; + } + } +} + +#[cfg(not(windows))] +unsafe fn clear_stdio_error(stdin: *mut libc::FILE) { + unsafe { libc::clearerr(stdin) }; +} + +#[cfg(not(windows))] +pub(super) fn read_stdio_line_bytes_allowing_python_threads( + stdin: *mut libc::FILE, +) -> StdioLineRead { + // _mcp_repl.readline is called from Python with the GIL held. Release it + // while stdin blocks so the IPC completion path can flush prompt-time plots. + let _allow_threads = PythonThreadsAllowed::new(); + read_stdio_line_bytes(stdin) +} + +pub(super) struct PythonThreadsAllowed { + api: &'static PythonApi, + thread_state: *mut PyThreadState, +} + +impl PythonThreadsAllowed { + pub(super) fn new() -> Self { + let api = PythonApi::global(); + let thread_state = unsafe { (api.py_eval_save_thread)() }; + assert!( + !thread_state.is_null(), + "PyEval_SaveThread returned a null thread state" + ); + Self { api, thread_state } + } +} + +impl Drop for PythonThreadsAllowed { + fn drop(&mut self) { + unsafe { (self.api.py_eval_restore_thread)(self.thread_state) }; + } +} diff --git a/src/python_session/unix_stdin.rs b/src/python_session/unix_stdin.rs index 43369d1a..ac661332 100644 --- a/src/python_session/unix_stdin.rs +++ b/src/python_session/unix_stdin.rs @@ -5,10 +5,14 @@ use crate::python_turn_input::{PtyFeed, normalize_pty_turn_payload}; use crate::stdin_payload::prepare_worker_stdin_payload; use crate::worker_protocol::TextStream; +use super::state::{ + PythonReadlineState, RawStdinReadError, SESSION_STATE, SessionStateInner, StdinReadAccounting, + mark_stdin_wait_prompt_completed_request, +}; +use super::stdio::PythonThreadsAllowed; use super::{ - CStdinLine, PythonReadlineState, PythonThreadsAllowed, RawStdinReadError, SESSION_STATE, - SessionStateInner, StdinReadAccounting, emit_output_text, emit_plots, flush_original_stdio, - mark_stdin_wait_prompt_completed_request, record_background_plots, set_callback_error, + CStdinLine, emit_output_text, emit_plots, flush_original_stdio, record_background_plots, + set_callback_error, }; static PYTHON_RUNTIME_STDIN_FD: AtomicI32 = AtomicI32::new(-1); @@ -461,7 +465,7 @@ fn mark_request_input_delivered() { guard.waiting_for_input = false; } -fn stdin_pending_byte_count() -> Option { +pub(super) fn stdin_pending_byte_count() -> Option { let mut count: libc::c_int = 0; let rc = unsafe { libc::ioctl(libc::STDIN_FILENO, libc::FIONREAD, &mut count) }; if rc == 0 && count >= 0 { diff --git a/src/python_session/windows_stdin.rs b/src/python_session/windows_stdin.rs new file mode 100644 index 00000000..77fd2482 --- /dev/null +++ b/src/python_session/windows_stdin.rs @@ -0,0 +1,462 @@ +use std::collections::VecDeque; +use std::ptr; +use std::sync::atomic::Ordering; + +use windows_sys::Win32::Foundation::INVALID_HANDLE_VALUE; +use windows_sys::Win32::Storage::FileSystem::ReadFile; +use windows_sys::Win32::System::Console::{GetStdHandle, STD_INPUT_HANDLE}; +use windows_sys::Win32::System::Pipes::PeekNamedPipe; + +use crate::ipc; +use crate::worker_protocol::TextStream; + +use super::state::{ + ActiveRequest, RawStdinReadError, SESSION_STATE, TurnInputLine, input_hook_prompt, + mark_stdin_wait_prompt_completed_request, remember_emitted_prompt, session_state, +}; +use super::stdio::{PYTHON_STDIN_FILE, PythonThreadsAllowed, StdioLineRead}; +use super::{emit_output_text, emit_plots, request_platform_interrupt}; + +pub(super) fn interrupt_turn(turn_id: u64) { + let Some(state) = SESSION_STATE.get() else { + return; + }; + { + let mut guard = state.inner.lock().unwrap(); + let write_in_flight = guard.turn_write_in_flight; + let Some(active) = guard.active_request.as_mut() else { + return; + }; + if active.turn_id != Some(turn_id) { + return; + } + active.queued_lines.clear(); + if write_in_flight { + guard.turn_cleanup_uncertain = true; + } + guard.interrupt_requested = true; + guard.waiting_for_input = false; + state.cvar.notify_all(); + } + request_platform_interrupt(); +} + +pub(super) fn begin_tracked_turn(turn_id: u64, input: String) -> Result<(), String> { + let state = session_state(); + let queued_lines = prepare_turn_input_lines(&input); + let byte_len = queued_lines + .iter() + .map(|line| line.bytes.len().saturating_sub(line.offset)) + .sum(); + let line_count = queued_lines.len(); + let mut guard = state.inner.lock().unwrap(); + if guard.shutdown { + return Err("Python session is shutting down".to_string()); + } + if guard.active_request.is_some() { + return Err("Python session already has an active turn".to_string()); + } + guard.interrupt_requested = false; + guard.request_completed_at_stdin_wait = false; + guard.request_active = true; + guard.plot_reset_pending = true; + guard.turn_write_in_flight = false; + guard.turn_cleanup_uncertain = false; + let started_after_continuation_prompt = guard.last_prompt_was_continuation; + guard.active_request = Some(ActiveRequest { + turn_id: Some(turn_id), + byte_len, + line_count, + fallback_prompt: None, + queued_lines, + consumed_lines: 0, + skip_next_hook: false, + stdin_write_complete: true, + repl_turn_finished: false, + started_after_continuation_prompt, + }); + state.cvar.notify_all(); + Ok(()) +} + +pub(super) fn append_tracked_turn_input(turn_id: u64, input: String) -> Result<(), String> { + let state = session_state(); + let queued_lines = prepare_turn_input_lines(&input); + let byte_len = queued_lines + .iter() + .map(|line| line.bytes.len().saturating_sub(line.offset)) + .sum(); + let line_count = queued_lines.len(); + let mut guard = state.inner.lock().unwrap(); + if guard.shutdown { + return Err("Python session is shutting down".to_string()); + } + if let Some(active) = guard.active_request.as_mut() { + if active.turn_id != Some(turn_id) { + return Err(format!( + "turn_input turn_id {turn_id} does not match active turn_id {:?}", + active.turn_id + )); + } + active.byte_len = active.byte_len.saturating_add(byte_len); + active.line_count = active.line_count.saturating_add(line_count); + active.queued_lines.extend(queued_lines); + } else { + guard.interrupt_requested = false; + guard.request_completed_at_stdin_wait = false; + guard.request_active = true; + guard.plot_reset_pending = true; + guard.turn_write_in_flight = false; + guard.turn_cleanup_uncertain = false; + let started_after_continuation_prompt = guard.last_prompt_was_continuation; + guard.active_request = Some(ActiveRequest { + turn_id: Some(turn_id), + byte_len, + line_count, + fallback_prompt: None, + queued_lines, + consumed_lines: 0, + skip_next_hook: false, + stdin_write_complete: true, + repl_turn_finished: false, + started_after_continuation_prompt, + }); + } + state.cvar.notify_all(); + Ok(()) +} + +fn prepare_turn_input_lines(input: &str) -> VecDeque { + if input.is_empty() { + return VecDeque::new(); + } + let mut input = input.to_string(); + if !input.ends_with('\n') { + input.push('\n'); + } + input + .split_inclusive('\n') + .map(|line| TurnInputLine { + text: line.to_string(), + bytes: line.as_bytes().to_vec(), + offset: 0, + input_line_emitted: false, + }) + .collect() +} + +fn emit_turn_input_line(turn_id: u64, prompt: &str, line: &mut TurnInputLine) { + if line.input_line_emitted { + return; + } + ipc::emit_input_line(turn_id, prompt, &line.text); + line.input_line_emitted = true; +} + +pub(super) fn read_windows_turn_line( + prompt: &str, + emit_prompt_to_stdout: bool, + release_gil_while_waiting: bool, +) -> Result { + let state = SESSION_STATE + .get() + .ok_or_else(|| "Python session state is not initialized".to_string())?; + let mut idle_repl_prompt_emitted = false; + + loop { + let action = { + let mut guard = state.inner.lock().unwrap(); + guard.waiting_for_input = true; + state.cvar.notify_all(); + + if guard.shutdown || guard.exit_requested { + return Ok(StdioLineRead { + bytes: Vec::new(), + interrupted: false, + }); + } + + if guard.interrupt_requested { + guard.interrupt_requested = false; + return Ok(StdioLineRead { + bytes: Vec::new(), + interrupted: true, + }); + } + + if guard.turn_cleanup_uncertain { + if release_gil_while_waiting { + let allow_threads = PythonThreadsAllowed::new(); + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + drop(allow_threads); + } else { + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + } + continue; + } + + match guard + .active_request + .as_mut() + .and_then(|active| active.turn_id) + { + Some(turn_id) => { + let line = { + let active = guard.active_request.as_mut().expect("active turn exists"); + let line = active.queued_lines.pop_front(); + if line.is_some() { + active.consumed_lines = active.consumed_lines.saturating_add(1); + } + line + }; + if let Some(line) = line { + let prompt_already_visible = + guard.visible_input_prompt.as_deref() == Some(prompt); + guard.visible_input_prompt = None; + guard.waiting_for_input = false; + guard.request_active = true; + Some((turn_id, Some(line), prompt_already_visible)) + } else { + guard.active_request.take(); + guard.visible_input_prompt = Some(prompt.to_string()); + Some((turn_id, None, false)) + } + } + None => { + let should_emit_idle_repl_prompt = !idle_repl_prompt_emitted + && (prompt == guard.python_primary_prompt + || prompt == guard.python_continuation_prompt); + if should_emit_idle_repl_prompt { + idle_repl_prompt_emitted = true; + guard.last_prompt_was_continuation = + prompt == guard.python_continuation_prompt; + drop(guard); + ipc::emit_readline_start(prompt); + continue; + } + if release_gil_while_waiting { + let allow_threads = PythonThreadsAllowed::new(); + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + drop(allow_threads); + } else { + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + } + continue; + } + } + }; + + if let Some((turn_id, Some(mut line), prompt_already_visible)) = action { + emit_turn_input_line(turn_id, prompt, &mut line); + if emit_prompt_to_stdout && !prompt.is_empty() && !prompt_already_visible { + emit_output_text(TextStream::Stdout, prompt.as_bytes()); + } + return Ok(StdioLineRead { + bytes: line.bytes[line.offset..].to_vec(), + interrupted: false, + }); + } + + if let Some((turn_id, None, _)) = action { + emit_plots(); + mark_stdin_wait_prompt_completed_request(); + remember_emitted_prompt(prompt); + ipc::emit_idle(turn_id, prompt); + } + } +} + +pub(super) fn discard_pending_stdin() { + let stdin = PYTHON_STDIN_FILE.load(Ordering::SeqCst); + if !stdin.is_null() { + unsafe { + libc::fflush(stdin); + } + } + drain_stdin_pipe(); +} + +fn drain_stdin_pipe() { + let handle = unsafe { GetStdHandle(STD_INPUT_HANDLE) }; + if handle.is_null() || handle == INVALID_HANDLE_VALUE { + return; + } + + let mut buffer = [0u8; 8192]; + loop { + let mut available = 0u32; + let ok = unsafe { + PeekNamedPipe( + handle, + ptr::null_mut(), + 0, + ptr::null_mut(), + &mut available, + ptr::null_mut(), + ) + }; + if ok == 0 || available == 0 { + break; + } + + let to_read = available.min(buffer.len() as u32); + let mut read = 0u32; + let ok = unsafe { + ReadFile( + handle, + buffer.as_mut_ptr().cast(), + to_read, + &mut read, + ptr::null_mut(), + ) + }; + if ok == 0 || read == 0 { + break; + } + } +} + +pub(super) fn stdin_pending_byte_count() -> Option { + let handle = unsafe { GetStdHandle(STD_INPUT_HANDLE) }; + if handle.is_null() || handle == INVALID_HANDLE_VALUE { + return None; + } + + let mut available = 0u32; + let ok = unsafe { + PeekNamedPipe( + handle, + ptr::null_mut(), + 0, + ptr::null_mut(), + &mut available, + ptr::null_mut(), + ) + }; + (ok != 0).then_some(available as usize) +} + +pub(super) fn read_raw_stdin_bytes(size: usize) -> Result, RawStdinReadError> { + if size == 0 { + return Ok(Vec::new()); + } + take_raw_turn_input_bytes(size) +} + +enum RawTurnInputEvent { + InputLine { + turn_id: u64, + prompt: String, + text: String, + }, + Idle { + turn_id: u64, + prompt: String, + }, + Consumed, +} + +fn take_raw_turn_input_bytes(size: usize) -> Result, RawStdinReadError> { + let state = SESSION_STATE.get().ok_or_else(|| { + RawStdinReadError::Runtime("Python session state is not initialized".to_string()) + })?; + let mut output = Vec::new(); + while output.len() < size { + let event = { + let mut guard = state.inner.lock().unwrap(); + guard.waiting_for_input = true; + state.cvar.notify_all(); + + if guard.shutdown || guard.exit_requested { + return Ok(output); + } + + if guard.interrupt_requested { + guard.interrupt_requested = false; + return Err(RawStdinReadError::Interrupted); + } + + if guard.turn_cleanup_uncertain { + let allow_threads = PythonThreadsAllowed::new(); + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + drop(allow_threads); + continue; + } + + let prompt = input_hook_prompt(&guard, None); + let Some(turn_id) = guard + .active_request + .as_ref() + .and_then(|active| active.turn_id) + else { + if !output.is_empty() { + return Ok(output); + } + if guard.active_request.is_some() { + return Ok(output); + } + let allow_threads = PythonThreadsAllowed::new(); + guard = state.cvar.wait(guard).unwrap(); + drop(guard); + drop(allow_threads); + continue; + }; + + if guard + .active_request + .as_ref() + .is_some_and(|active| active.queued_lines.is_empty()) + { + if !output.is_empty() { + return Ok(output); + } + guard.active_request.take(); + RawTurnInputEvent::Idle { turn_id, prompt } + } else { + let active = guard.active_request.as_mut().expect("active turn exists"); + let line = active + .queued_lines + .front_mut() + .expect("active turn has queued input"); + let event = if line.input_line_emitted { + RawTurnInputEvent::Consumed + } else { + line.input_line_emitted = true; + RawTurnInputEvent::InputLine { + turn_id, + prompt, + text: line.text.clone(), + } + }; + let available = &line.bytes[line.offset..]; + let take = available.len().min(size - output.len()); + output.extend_from_slice(&available[..take]); + line.offset += take; + if line.offset >= line.bytes.len() { + active.queued_lines.pop_front(); + } + event + } + }; + match event { + RawTurnInputEvent::InputLine { + turn_id, + prompt, + text, + } => ipc::emit_input_line(turn_id, &prompt, &text), + RawTurnInputEvent::Idle { turn_id, prompt } => { + emit_plots(); + mark_stdin_wait_prompt_completed_request(); + remember_emitted_prompt(&prompt); + ipc::emit_idle(turn_id, &prompt); + } + RawTurnInputEvent::Consumed => {} + } + } + Ok(output) +}