diff --git a/src/arguments/add.rs b/src/arguments/add.rs index 20f0a57..875de4c 100644 --- a/src/arguments/add.rs +++ b/src/arguments/add.rs @@ -40,7 +40,7 @@ pub struct AddArgs { long, default_value_t = 6, default_value_if("type", "STEAM", "5"), - value_parser=value_parser!(u64).range(0..=9) + value_parser=value_parser!(u64).range(1..=10) )] pub digits: u64, diff --git a/src/arguments/edit.rs b/src/arguments/edit.rs index d1081b6..398b57f 100644 --- a/src/arguments/edit.rs +++ b/src/arguments/edit.rs @@ -24,7 +24,7 @@ pub struct EditArgs { pub algorithm: Option, /// Code digits - #[arg(short, long, value_parser=value_parser!(u64).range(0..=9))] + #[arg(short, long, value_parser=value_parser!(u64).range(1..=10))] pub digits: Option, /// Code period diff --git a/src/otp/otp_element.rs b/src/otp/otp_element.rs index 3ad12ad..5bd3e35 100644 --- a/src/otp/otp_element.rs +++ b/src/otp/otp_element.rs @@ -154,6 +154,8 @@ pub struct OTPElement { pub pin: Option, } +static ALLOWED_DIGITS_RANGE: std::ops::RangeInclusive = 1..=10; + impl OTPElement { pub fn get_otpauth_uri(&self) -> String { let otp_type = self.type_.to_string().to_lowercase(); @@ -183,6 +185,10 @@ impl OTPElement { } pub fn get_otp_code(&self) -> Result { + if !ALLOWED_DIGITS_RANGE.contains(&self.digits) { + return Err(OtpError::InvalidDigits); + } + match self.type_ { OTPType::Totp => { let code = totp(&self.secret, self.algorithm)?; @@ -222,11 +228,11 @@ impl OTPElement { fn format_code(&self, value: u32) -> Result { // Get the formatted code - let exponential = 10_u32 + let exponential = 10_u64 .checked_pow(self.digits as u32) .ok_or(OtpError::InvalidDigits)?; - let s = (value % exponential).to_string(); - Ok("0".repeat(self.digits as usize - s.chars().count()) + s.as_str()) + let s = (value as u64 % exponential).to_string(); + Ok("0".repeat((self.digits as usize).saturating_sub(s.chars().count())) + s.as_str()) } } @@ -357,7 +363,7 @@ mod test { #[test] fn test_invalid_digits_should_not_overflow() { // Arrange - let invalid_digits_value = 10; + let invalid_digits_value = 11; let element = OTPElement { secret: "xr5gh44x7bprcqgrdtulafeevt5rxqlbh5wvked22re43dh2d4mapv5g".to_uppercase(), @@ -378,6 +384,30 @@ mod test { assert_eq!(Err(OtpError::InvalidDigits), result); } + #[test] + fn test_10_digits_should_be_allowed() { + // Arrange + let invalid_digits_value = 10; + + let element = OTPElement { + secret: "xr5gh44x7bprcqgrdtulafeevt5rxqlbh5wvked22re43dh2d4mapv5g".to_uppercase(), + issuer: String::from("IssuerText"), + label: String::from("LabelText"), + digits: invalid_digits_value, + type_: Totp, + algorithm: Sha1, + period: 30, + counter: None, + pin: None, + }; + + // Act + let result = element.get_otp_code(); + + // Assert + assert!(result.is_ok()); + } + #[test] fn test_lowercase_secret() { // Arrange / Act