diff --git a/sh/src/builtin/read.rs b/sh/src/builtin/read.rs index f6c15c4fb..6935ab748 100644 --- a/sh/src/builtin/read.rs +++ b/sh/src/builtin/read.rs @@ -81,7 +81,7 @@ fn read_until_from_non_blocking_fd( } } // might receive signals while reading - shell.update_global_state(); + shell.handle_async_events(); std::thread::sleep(Duration::from_millis(16)); } if !buffer.is_empty() { diff --git a/sh/src/builtin/trap.rs b/sh/src/builtin/trap.rs index a05bf580a..08996477b 100644 --- a/sh/src/builtin/trap.rs +++ b/sh/src/builtin/trap.rs @@ -14,7 +14,7 @@ use crate::signals::Signal; use std::fmt::Display; use std::str::FromStr; -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub enum TrapAction { Default, Ignore, diff --git a/sh/src/cli/vi/mod.rs b/sh/src/cli/vi/mod.rs index bde7cf34b..37ced06e3 100644 --- a/sh/src/cli/vi/mod.rs +++ b/sh/src/cli/vi/mod.rs @@ -821,6 +821,13 @@ impl ViEditor { } Ok(Action::None) } + + pub fn reset_current_line(&mut self) { + self.edit_line.clear(); + self.cursor.position = 0; + self.current_history_command = 0; + self.mode = EditorMode::Insert; + } } impl Default for ViEditor { diff --git a/sh/src/main.rs b/sh/src/main.rs index 29fa0e36e..9240c5d5b 100644 --- a/sh/src/main.rs +++ b/sh/src/main.rs @@ -11,13 +11,13 @@ use crate::cli::args::{parse_args, ExecutionMode}; use crate::cli::terminal::is_attached_to_terminal; use crate::cli::{clear_line, set_cursor_pos}; use crate::shell::Shell; -use crate::signals::setup_signal_handling; +use crate::signals::{ + handle_signal_ignore, handle_signal_write_to_signal_buffer, setup_signal_handling, Signal, +}; use crate::utils::is_process_in_foreground; use cli::terminal::read_nonblocking_char; use cli::vi::{Action, ViEditor}; use gettextrs::{bind_textdomain_codeset, setlocale, textdomain, LocaleCategory}; -use nix::sys::signal::{sigaction, SaFlags, SigAction, SigSet}; -use nix::sys::signal::{SigHandler, Signal as NixSignal}; use std::error::Error; use std::io; use std::io::Write; @@ -140,7 +140,14 @@ fn standard_repl(shell: &mut Shell) { flush_stdout(); } std::thread::sleep(Duration::from_millis(16)); - shell.update_global_state(); + shell.signal_manager.reset_sigint_count(); + shell.handle_async_events(); + if shell.signal_manager.get_sigint_count() > 0 { + program_buffer.clear(); + line_buffer.clear(); + println!(); + eprint!("{}", shell.get_ps1()); + } if shell.set_options.vi { return; } @@ -149,7 +156,7 @@ fn standard_repl(shell: &mut Shell) { fn vi_repl(shell: &mut Shell) { let mut editor = ViEditor::default(); - let mut buffer = Vec::new(); + let mut program_buffer = Vec::new(); let mut print_ps2 = false; clear_line(); flush_stdout(); @@ -158,15 +165,15 @@ fn vi_repl(shell: &mut Shell) { while let Some(c) = read_nonblocking_char() { match editor.process_new_input(c, shell) { Ok(Action::Execute(command)) => { - buffer.extend(command.iter()); - if buffer.ends_with(b"\\\n") { + program_buffer.extend(command.iter()); + if program_buffer.ends_with(b"\\\n") { continue; } - let program_string = match std::str::from_utf8(&buffer) { + let program_string = match std::str::from_utf8(&program_buffer) { Ok(buf) => buf, Err(_) => { eprintln!("sh: invalid utf-8 sequence"); - buffer.clear(); + program_buffer.clear(); continue; } }; @@ -174,13 +181,13 @@ fn vi_repl(shell: &mut Shell) { shell.terminal.reset(); match shell.execute_program(program_string) { Ok(_) => { - buffer.clear(); + program_buffer.clear(); print_ps2 = false; } Err(syntax_err) => { if !syntax_err.could_be_resolved_with_more_input { eprintln!("sh: syntax error: {}", syntax_err.message); - buffer.clear(); + program_buffer.clear(); } else { print_ps2 = true; } @@ -205,7 +212,14 @@ fn vi_repl(shell: &mut Shell) { flush_stdout() } std::thread::sleep(Duration::from_millis(16)); - shell.update_global_state(); + shell.signal_manager.reset_sigint_count(); + shell.handle_async_events(); + if shell.signal_manager.get_sigint_count() > 0 { + program_buffer.clear(); + editor.reset_current_line(); + println!(); + eprint!("{}", shell.get_ps1()); + } if !shell.set_options.vi { return; } @@ -218,24 +232,14 @@ fn interactive_shell(shell: &mut Shell) { nix::unistd::tcsetpgrp(io::stdin().as_fd(), pgid).unwrap(); } shell.terminal.set_nonblocking_no_echo(); - let ignore_action = SigAction::new(SigHandler::SigIgn, SaFlags::empty(), SigSet::empty()); - unsafe { - sigaction(NixSignal::SIGQUIT, &ignore_action).unwrap(); - } - unsafe { - sigaction(NixSignal::SIGTERM, &ignore_action).unwrap(); - } + unsafe { handle_signal_ignore(Signal::SigQuit) } + unsafe { handle_signal_ignore(Signal::SigTerm) } + unsafe { handle_signal_write_to_signal_buffer(Signal::SigInt) } if shell.set_options.monitor { // job control signals - unsafe { - sigaction(NixSignal::SIGTTIN, &ignore_action).unwrap(); - } - unsafe { - sigaction(NixSignal::SIGTTOU, &ignore_action).unwrap(); - } - unsafe { - sigaction(NixSignal::SIGTSTP, &ignore_action).unwrap(); - } + unsafe { handle_signal_ignore(Signal::SigTtin) } + unsafe { handle_signal_ignore(Signal::SigTtou) } + unsafe { handle_signal_ignore(Signal::SigTstp) } } loop { if shell.set_options.vi { diff --git a/sh/src/shell/mod.rs b/sh/src/shell/mod.rs index 870511067..a1184f2c6 100644 --- a/sh/src/shell/mod.rs +++ b/sh/src/shell/mod.rs @@ -34,8 +34,10 @@ use crate::utils::{ use crate::wordexp::{expand_word, expand_word_to_string, word_to_pattern}; use nix::errno::Errno; use nix::libc; +use nix::sys::signal::kill; +use nix::sys::signal::Signal as NixSignal; use nix::sys::wait::{WaitPidFlag, WaitStatus}; -use nix::unistd::{getcwd, getpgrp, getpid, getppid, setpgid, ForkResult, Pid}; +use nix::unistd::{getcwd, getpgid, getpgrp, getpid, getppid, setpgid, tcsetpgrp, ForkResult, Pid}; use std::collections::HashMap; use std::ffi::{CString, OsString}; use std::fmt::{Display, Formatter}; @@ -217,7 +219,7 @@ impl Shell { return Ok(signal_to_exit_status(signal)); } WaitStatus::StillAlive => { - self.update_global_state(); + self.handle_async_events(); std::thread::sleep(Duration::from_millis(16)); } _ => unreachable!(), @@ -225,7 +227,7 @@ impl Shell { } } - pub fn update_global_state(&mut self) { + pub fn handle_async_events(&mut self) { self.process_signals(); if self.set_options.monitor { if let Err(err) = self.background_jobs.update_jobs() { @@ -262,8 +264,8 @@ impl Shell { } pub fn process_signals(&mut self) { - while let Some(action) = self.signal_manager.get_pending_action() { - self.execute_action(action.clone()) + while let Some(action) = self.signal_manager.get_pending_action().cloned() { + self.execute_action(action) } } @@ -729,17 +731,22 @@ impl Shell { match fork()? { ForkResult::Child => { self.become_subshell(); - setpgid(Pid::from_raw(0), Pid::from_raw(0)) - .expect("failed to create new process group for pipeline"); + // this should never fail as both arguments are valid + setpgid(Pid::from_raw(0), Pid::from_raw(0)).unwrap(); let pipeline_pgid = getpgrp(); + // wait for the parent process to put the subshell in the foreground + if let Err(err) = kill(Pid::from_raw(0), NixSignal::SIGTSTP) { + self.eprint(&format!("sh: internal call to kill failed ({err})")); + self.exit(1); + } let mut current_stdin = libc::STDIN_FILENO; for command in pipeline.commands.head() { let (read_pipe, write_pipe) = pipe()?; match fork()? { ForkResult::Child => { - setpgid(Pid::from_raw(0), pipeline_pgid) - .expect("failed to set pipeline pgid"); + // should never fail as `pipeline_pgid` is a valid process group + setpgid(Pid::from_raw(0), pipeline_pgid).unwrap(); drop(read_pipe); dup2(current_stdin, libc::STDIN_FILENO)?; dup2(write_pipe.as_raw_fd(), libc::STDOUT_FILENO)?; @@ -763,12 +770,46 @@ impl Shell { self.exit(return_status); } ForkResult::Parent { child } => { - if is_process_in_foreground() { - nix::unistd::tcsetpgrp(io::stdin().as_fd(), child).unwrap(); - pipeline_exit_status = self.wait_child_process(child)?; - nix::unistd::tcsetpgrp(io::stdin().as_fd(), getpgrp()).unwrap(); - } else { - pipeline_exit_status = self.wait_child_process(child)?; + loop { + match waitpid(child, Some(WaitPidFlag::WNOHANG | WaitPidFlag::WUNTRACED))? { + WaitStatus::Exited(_, _) => { + // the only way this happened is if there was an error before going + // the child went to sleep + return Ok(1); + } + WaitStatus::Signaled(_, _, _) => { + self.eprint("sh: unsynchronised pipeline was terminated by another process\n"); + return Ok(1); + } + WaitStatus::Continued(_) => { + self.eprint("sh: unsynchronised pipeline was restarted by another process\n"); + return Ok(1); + } + WaitStatus::Stopped(_, _) => { + if is_process_in_foreground() { + // should never fail as child is a valid process id and + // in the same session as the current shell + let child_gpid = getpgid(Some(child)).unwrap(); + // should never fail as stdin is a valid file descriptor and + // child gpid is valid and in the same session + tcsetpgrp(io::stdin().as_fd(), child_gpid).unwrap(); + kill(child, NixSignal::SIGCONT).unwrap(); + pipeline_exit_status = self.wait_child_process(child)?; + // should never fail + tcsetpgrp(io::stdin().as_fd(), getpgrp()).unwrap(); + break; + } else { + kill(child, NixSignal::SIGCONT).unwrap(); + pipeline_exit_status = self.wait_child_process(child)?; + break; + } + } + WaitStatus::StillAlive => { + self.handle_async_events(); + std::thread::sleep(Duration::from_millis(16)); + } + _ => unreachable!(), + } } } } @@ -958,6 +999,7 @@ impl Shell { history, set_options, is_interactive, + signal_manager: SignalManager::new(is_interactive), ..Default::default() } } @@ -1019,7 +1061,7 @@ impl Default for Shell { is_interactive: false, last_lineno: 0, exit_action: TrapAction::Default, - signal_manager: SignalManager::default(), + signal_manager: SignalManager::new(false), background_jobs: JobManager::default(), history: History::new(32767), umask: !0o022 & 0o777, diff --git a/sh/src/signals.rs b/sh/src/signals.rs index dc67cfbee..26dc86081 100644 --- a/sh/src/signals.rs +++ b/sh/src/signals.rs @@ -260,7 +260,7 @@ pub const SIGNALS: &[Signal] = &[ static mut SIGNAL_WRITE: Option = None; static mut SIGNAL_READ: Option = None; -extern "C" fn handle_signals(signal: libc::c_int) { +extern "C" fn write_signal_to_buffer(signal: libc::c_int) { // SIGNAL_WRITE is never modified after the initial // setup, and is a valid file descriptor, so this is safe let fd = unsafe { BorrowedFd::borrow_raw(SIGNAL_WRITE.unwrap()) }; @@ -304,12 +304,50 @@ fn get_pending_signal() -> Option { } } +pub unsafe fn handle_signal_ignore(signal: Signal) { + sigaction( + signal.into(), + &SigAction::new(SigHandler::SigIgn, SaFlags::empty(), SigSet::empty()), + ) + .unwrap(); +} + +pub unsafe fn handle_signal_default(signal: Signal) { + sigaction( + signal.into(), + &SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty()), + ) + .unwrap(); +} + +pub unsafe fn handle_signal_write_to_signal_buffer(signal: Signal) { + sigaction( + signal.into(), + &SigAction::new( + SigHandler::Handler(write_signal_to_buffer), + SaFlags::empty(), + SigSet::empty(), + ), + ) + .unwrap(); +} + #[derive(Clone)] pub struct SignalManager { actions: [TrapAction; Signal::Count as usize], + is_interactive: bool, + sigint_count: u32, } impl SignalManager { + pub fn new(is_interactive: bool) -> Self { + Self { + actions: [const { TrapAction::Default }; Signal::Count as usize], + is_interactive, + sigint_count: 0, + } + } + pub fn reset(&mut self) { for signal in SIGNALS { let signal = *signal; @@ -336,44 +374,35 @@ impl SignalManager { pub fn set_action(&mut self, signal: Signal, action: TrapAction) { assert!(signal != Signal::SigKill && signal != Signal::SigStop); + + if self.is_interactive + && signal == Signal::SigInt + && (action == TrapAction::Ignore || action == TrapAction::Default) + { + // in interactive mode we always want catch sigint + unsafe { handle_signal_write_to_signal_buffer(Signal::SigInt) }; + self.actions[signal as usize] = action; + return; + } match action { TrapAction::Default => { - unsafe { - sigaction( - signal.into(), - &SigAction::new(SigHandler::SigDfl, SaFlags::empty(), SigSet::empty()), - ) - .unwrap() - }; + unsafe { handle_signal_default(signal) }; } TrapAction::Ignore => { - unsafe { - sigaction( - signal.into(), - &SigAction::new(SigHandler::SigIgn, SaFlags::empty(), SigSet::empty()), - ) - .unwrap() - }; + unsafe { handle_signal_ignore(signal) }; } TrapAction::Commands(_) => { - unsafe { - sigaction( - signal.into(), - &SigAction::new( - SigHandler::Handler(handle_signals), - SaFlags::empty(), - SigSet::empty(), - ), - ) - .unwrap() - }; + unsafe { handle_signal_write_to_signal_buffer(signal) }; } } self.actions[signal as usize] = action; } - pub fn get_pending_action(&self) -> Option<&TrapAction> { + pub fn get_pending_action(&mut self) -> Option<&TrapAction> { if let Some(signal) = get_pending_signal() { + if signal == Signal::SigInt { + self.sigint_count += 1; + } Some(&self.actions[signal as usize]) } else { None @@ -385,12 +414,12 @@ impl SignalManager { .iter() .map(move |&signal| (signal, &self.actions[signal as usize])) } -} -impl Default for SignalManager { - fn default() -> Self { - Self { - actions: [const { TrapAction::Default }; Signal::Count as usize], - } + pub fn reset_sigint_count(&mut self) { + self.sigint_count = 0; + } + + pub fn get_sigint_count(&self) -> u32 { + self.sigint_count } }