diff --git a/src/uu/stty/src/flags.rs b/src/uu/stty/src/flags.rs index c10e7c04b39..17ea52da174 100644 --- a/src/uu/stty/src/flags.rs +++ b/src/uu/stty/src/flags.rs @@ -27,6 +27,12 @@ use nix::sys::termios::{ SpecialCharacterIndices as S, }; +pub enum BaudType { + Input, + Output, + Both, +} + pub enum AllFlags<'a> { #[cfg(any( target_os = "freebsd", @@ -36,7 +42,7 @@ pub enum AllFlags<'a> { target_os = "netbsd", target_os = "openbsd" ))] - Baud(u32), + Baud(u32, BaudType), #[cfg(not(any( target_os = "freebsd", target_os = "dragonfly", @@ -45,7 +51,7 @@ pub enum AllFlags<'a> { target_os = "netbsd", target_os = "openbsd" )))] - Baud(BaudRate), + Baud(BaudRate, BaudType), ControlFlags((&'a Flag, bool)), InputFlags((&'a Flag, bool)), LocalFlags((&'a Flag, bool)), diff --git a/src/uu/stty/src/stty.rs b/src/uu/stty/src/stty.rs index 8b8da5135a2..acd6cddfccf 100644 --- a/src/uu/stty/src/stty.rs +++ b/src/uu/stty/src/stty.rs @@ -10,7 +10,7 @@ // spell-checker:ignore isig icanon iexten echoe crterase echok echonl noflsh xcase tostop echoprt prterase echoctl ctlecho echoke crtkill flusho extproc // spell-checker:ignore lnext rprnt susp swtch vdiscard veof veol verase vintr vkill vlnext vquit vreprint vstart vstop vsusp vswtc vwerase werase // spell-checker:ignore sigquit sigtstp -// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS +// spell-checker:ignore cbreak decctlq evenp litout oddp tcsadrain exta extb NCCS cfsetispeed mod flags; @@ -20,7 +20,7 @@ use clap::{Arg, ArgAction, ArgMatches, Command}; use nix::libc::{O_NONBLOCK, TIOCGWINSZ, TIOCSWINSZ, c_ushort}; use nix::sys::termios::{ ControlFlags, InputFlags, LocalFlags, OutputFlags, SetArg, SpecialCharacterIndices as S, - Termios, cfgetospeed, cfsetospeed, tcgetattr, tcsetattr, + Termios, cfgetospeed, cfsetispeed, cfsetospeed, tcgetattr, tcsetattr, }; use nix::{ioctl_read_bad, ioctl_write_ptr_bad}; use std::cmp::Ordering; @@ -273,19 +273,24 @@ fn stty(opts: &Options) -> UResult<()> { let mut args_iter = args.iter(); while let Some(&arg) = args_iter.next() { match arg { - "ispeed" | "ospeed" => match args_iter.next() { + "ispeed" => match args_iter.next() { Some(speed) => { - if let Some(baud_flag) = string_to_baud(speed) { + if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Input) { valid_args.push(ArgOptions::Flags(baud_flag)); } else { - return Err(USimpleError::new( - 1, - translate!( - "stty-error-invalid-speed", - "arg" => *arg, - "speed" => *speed, - ), - )); + return invalid_speed(arg, speed); + } + } + None => { + return missing_arg(arg); + } + }, + "ospeed" => match args_iter.next() { + Some(speed) => { + if let Some(baud_flag) = string_to_baud(speed, flags::BaudType::Output) { + valid_args.push(ArgOptions::Flags(baud_flag)); + } else { + return invalid_speed(arg, speed); } } None => { @@ -382,12 +387,12 @@ fn stty(opts: &Options) -> UResult<()> { return missing_arg(arg); } // baud rate - } else if let Some(baud_flag) = string_to_baud(arg) { + } else if let Some(baud_flag) = string_to_baud(arg, flags::BaudType::Both) { valid_args.push(ArgOptions::Flags(baud_flag)); // non control char flag } else if let Some(flag) = string_to_flag(arg) { let remove_group = match flag { - AllFlags::Baud(_) => false, + AllFlags::Baud(_, _) => false, AllFlags::ControlFlags((flag, remove)) => { check_flag_group(flag, remove) } @@ -416,7 +421,7 @@ fn stty(opts: &Options) -> UResult<()> { for arg in &valid_args { match arg { ArgOptions::Mapping(mapping) => apply_char_mapping(&mut termios, mapping), - ArgOptions::Flags(flag) => apply_setting(&mut termios, flag), + ArgOptions::Flags(flag) => apply_setting(&mut termios, flag)?, ArgOptions::Special(setting) => { apply_special_setting(&mut termios, setting, opts.file.as_raw_fd())?; } @@ -468,6 +473,17 @@ fn invalid_integer_arg(arg: &str) -> Result> { )) } +fn invalid_speed(arg: &str, speed: &str) -> Result> { + Err(UUsageError::new( + 1, + translate!( + "stty-error-invalid-speed", + "arg" => arg, + "speed" => speed, + ), + )) +} + /// GNU uses different error messages if values overflow or underflow a u8, /// this function returns the appropriate error message in the case of overflow or underflow, or u8 on success fn parse_u8_or_err(arg: &str) -> Result { @@ -657,7 +673,7 @@ fn parse_baud_with_rounding(normalized: &str) -> Option { Some(value) } -fn string_to_baud(arg: &str) -> Option> { +fn string_to_baud(arg: &str, baud_type: flags::BaudType) -> Option> { // Reject invalid formats if arg != arg.trim_end() || arg.trim().starts_with('-') @@ -682,7 +698,7 @@ fn string_to_baud(arg: &str) -> Option> { target_os = "netbsd", target_os = "openbsd" ))] - return Some(AllFlags::Baud(value)); + return Some(AllFlags::Baud(value, baud_type)); #[cfg(not(any( target_os = "freebsd", @@ -695,7 +711,7 @@ fn string_to_baud(arg: &str) -> Option> { { for (text, baud_rate) in BAUD_RATES { if text.parse::().ok() == Some(value) { - return Some(AllFlags::Baud(*baud_rate)); + return Some(AllFlags::Baud(*baud_rate, baud_type)); } } None @@ -853,9 +869,9 @@ fn print_flags(termios: &Termios, opts: &Options, flags: &[Flag< } /// Apply a single setting -fn apply_setting(termios: &mut Termios, setting: &AllFlags) { +fn apply_setting(termios: &mut Termios, setting: &AllFlags) -> nix::Result<()> { match setting { - AllFlags::Baud(_) => apply_baud_rate_flag(termios, setting), + AllFlags::Baud(_, _) => apply_baud_rate_flag(termios, setting)?, AllFlags::ControlFlags((setting, disable)) => { setting.flag.apply(termios, !disable); } @@ -869,9 +885,10 @@ fn apply_setting(termios: &mut Termios, setting: &AllFlags) { setting.flag.apply(termios, !disable); } } + Ok(()) } -fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) { +fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) -> nix::Result<()> { // BSDs use a u32 for the baud rate, so any decimal number applies. #[cfg(any( target_os = "freebsd", @@ -881,8 +898,15 @@ fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) { target_os = "netbsd", target_os = "openbsd" ))] - if let AllFlags::Baud(n) = input { - cfsetospeed(termios, *n).expect("Failed to set baud rate"); + if let AllFlags::Baud(n, baud_type) = input { + match baud_type { + flags::BaudType::Input => cfsetispeed(termios, *n)?, + flags::BaudType::Output => cfsetospeed(termios, *n)?, + flags::BaudType::Both => { + cfsetispeed(termios, *n)?; + cfsetospeed(termios, *n)?; + } + } } // Other platforms use an enum. @@ -894,9 +918,17 @@ fn apply_baud_rate_flag(termios: &mut Termios, input: &AllFlags) { target_os = "netbsd", target_os = "openbsd" )))] - if let AllFlags::Baud(br) = input { - cfsetospeed(termios, *br).expect("Failed to set baud rate"); + if let AllFlags::Baud(br, baud_type) = input { + match baud_type { + flags::BaudType::Input => cfsetispeed(termios, *br)?, + flags::BaudType::Output => cfsetospeed(termios, *br)?, + flags::BaudType::Both => { + cfsetispeed(termios, *br)?; + cfsetospeed(termios, *br)?; + } + } } + Ok(()) } fn apply_char_mapping(termios: &mut Termios, mapping: &(S, u8)) { diff --git a/tests/by-util/test_stty.rs b/tests/by-util/test_stty.rs index f68de5daf5b..3d11cf4b25b 100644 --- a/tests/by-util/test_stty.rs +++ b/tests/by-util/test_stty.rs @@ -526,3 +526,67 @@ fn test_saved_state_with_control_chars() { .stderr_is(exp_result.stderr_str()) .code_is(exp_result.code()); } + +#[test] +#[cfg(unix)] +fn test_ispeed_ospeed_valid_speeds() { + let (path, _controller, _replica) = pty_path(); + let (_at, ts) = at_and_ts!(); + + // Test various valid baud rates for both ispeed and ospeed + let test_cases = [ + ("ispeed", "50"), + ("ispeed", "9600"), + ("ispeed", "19200"), + ("ospeed", "1200"), + ("ospeed", "9600"), + ("ospeed", "38400"), + ]; + + for (arg, speed) in test_cases { + let result = ts.ucmd().args(&["--file", &path, arg, speed]).run(); + let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed])); + let normalized_stderr = normalize_stderr(result.stderr_str()); + + result + .stdout_is(exp_result.stdout_str()) + .code_is(exp_result.code()); + assert_eq!(normalized_stderr, exp_result.stderr_str()); + } +} + +#[test] +#[cfg(all( + unix, + not(any( + target_os = "freebsd", + target_os = "dragonfly", + target_os = "ios", + target_os = "macos", + target_os = "netbsd", + target_os = "openbsd" + )) +))] +fn test_ispeed_ospeed_invalid_speeds() { + let (path, _controller, _replica) = pty_path(); + let (_at, ts) = at_and_ts!(); + + // Test invalid speed values (non-standard baud rates) + let test_cases = [ + ("ispeed", "12345"), + ("ospeed", "99999"), + ("ispeed", "abc"), + ("ospeed", "xyz"), + ]; + + for (arg, speed) in test_cases { + let result = ts.ucmd().args(&["--file", &path, arg, speed]).run(); + let exp_result = unwrap_or_return!(expected_result(&ts, &["--file", &path, arg, speed])); + let normalized_stderr = normalize_stderr(result.stderr_str()); + + result + .stdout_is(exp_result.stdout_str()) + .code_is(exp_result.code()); + assert_eq!(normalized_stderr, exp_result.stderr_str()); + } +}