diff --git a/src/uint/boxed.rs b/src/uint/boxed.rs index 73ab2142..5a3d147c 100644 --- a/src/uint/boxed.rs +++ b/src/uint/boxed.rs @@ -263,6 +263,19 @@ impl BoxedUint { (NonZero::new_unchecked(nz), !is_zero) } + /// Convert to an [`Odd`], defaulting to one. + /// + /// Returns a pair consisting of an [`Odd`], and a [`Choice`] + /// indicating whether the original value was odd (and preserved). + #[inline(always)] + #[must_use] + pub(crate) fn to_odd_or_one(&self) -> (Odd, Choice) { + let is_odd = self.is_odd(); + let mut odd = self.clone(); + odd.as_mut_uint_ref().conditional_set_one(is_odd.not()); + (Odd::new_unchecked(odd), is_odd) + } + /// Construct an [`Odd`] reference, returning [`None`] in the event `self` is even. #[must_use] pub fn as_odd_vartime(&self) -> Option<&Odd> { diff --git a/src/uint/boxed/invert_mod.rs b/src/uint/boxed/invert_mod.rs index e92d5df8..0112f805 100644 --- a/src/uint/boxed/invert_mod.rs +++ b/src/uint/boxed/invert_mod.rs @@ -1,7 +1,7 @@ //! [`BoxedUint`] modular inverse (i.e. reciprocal) operations. use crate::{ - BoxedUint, Choice, CtEq, CtLt, CtOption, CtSelect, Integer, InvertMod, Limb, NonZero, Odd, U64, + BoxedUint, Choice, CtEq, CtLt, CtOption, CtSelect, InvertMod, Limb, NonZero, Odd, U64, modular::safegcd, uint::invert_mod::expand_invert_mod2k, }; @@ -50,13 +50,8 @@ impl BoxedUint { } else if k > bits { (Self::zero_with_precision(bits), Choice::FALSE) } else { - let is_some = self.is_odd(); - let inv = Odd::new_unchecked(Self::ct_select( - &Self::one_with_precision(bits), - self, - is_some, - )) - .invert_mod2k_vartime(k); + let (odd, is_some) = self.to_odd_or_one(); + let inv = odd.invert_mod2k_vartime(k); (inv, is_some) } } @@ -78,13 +73,9 @@ impl BoxedUint { #[must_use] pub fn invert_mod2k(&self, k: u32) -> (Self, Choice) { let bits = self.bits_precision(); - let is_some = k.ct_lt(&(bits + 1)) & (k.ct_eq(&0) | self.is_odd()); - let mut inv = Odd::new_unchecked(Self::ct_select( - &Self::one_with_precision(bits), - self, - is_some, - )) - .invert_mod_precision(); + let (odd, is_odd) = self.to_odd_or_one(); + let is_some = k.ct_lt(&(bits + 1)) & (k.ct_eq(&0) | is_odd); + let mut inv = odd.invert_mod_precision(); inv.restrict_bits(k); (inv, is_some) }