Skip to content

Commit 98d58ad

Browse files
committed
fix: BlockDecomposer
The BlockDecomposer gave the possibility when the number of bits per block was not a multiple of the number of bits in the original integer to force the extra bits of the last block to a particular value. However, the way this was done could only work when setting these bits to 1, when wanting to set them to 0 it would not work. Good news is that we actually never wanted to set them to 0, but it should still be fixed for completeness, and allow other feature to be added without bugs
1 parent 8962d1f commit 98d58ad

File tree

2 files changed

+96
-52
lines changed

2 files changed

+96
-52
lines changed

tfhe/src/integer/block_decomposition.rs

Lines changed: 91 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::core_crypto::prelude::{CastFrom, CastInto, Numeric};
22
use crate::integer::bigint::static_signed::StaticSignedBigInt;
33
use crate::integer::bigint::static_unsigned::StaticUnsignedBigInt;
44
use core::ops::{AddAssign, BitAnd, ShlAssign, ShrAssign};
5-
use std::ops::{BitOrAssign, Shl, Sub};
5+
use std::ops::{BitOrAssign, Not, Shl, Shr, Sub};
66

77
// These work for signed number as rust uses 2-Complements
88
// And Arithmetic shift for signed number (logical for unsigned)
@@ -14,8 +14,10 @@ pub trait Decomposable:
1414
+ ShrAssign<u32>
1515
+ Eq
1616
+ CastFrom<u32>
17+
+ Shr<u32, Output = Self>
1718
+ Shl<u32, Output = Self>
1819
+ BitOrAssign<Self>
20+
+ Not<Output = Self>
1921
{
2022
}
2123
pub trait Recomposable:
@@ -86,33 +88,48 @@ impl<const N: usize> RecomposableFrom<u8> for StaticUnsignedBigInt<N> {}
8688
impl<const N: usize> DecomposableInto<u64> for StaticUnsignedBigInt<N> {}
8789
impl<const N: usize> DecomposableInto<u8> for StaticUnsignedBigInt<N> {}
8890

91+
#[derive(Copy, Clone)]
92+
#[repr(u32)]
93+
pub enum PaddingBitValue {
94+
Zero = 0,
95+
One = 1,
96+
}
97+
8998
#[derive(Clone)]
9099
pub struct BlockDecomposer<T> {
91100
data: T,
92101
bit_mask: T,
93102
num_bits_in_mask: u32,
94103
num_bits_valid: u32,
95-
padding_bit: T,
104+
padding_bit: Option<PaddingBitValue>,
96105
limit: Option<T>,
97106
}
98107

99108
impl<T> BlockDecomposer<T>
100109
where
101110
T: Decomposable,
102111
{
112+
/// Creates a block decomposer that will stop when the value reaches zero
103113
pub fn with_early_stop_at_zero(value: T, bits_per_block: u32) -> Self {
104-
Self::new_(value, bits_per_block, Some(T::ZERO), T::ZERO)
114+
Self::new_(value, bits_per_block, Some(T::ZERO), None)
105115
}
106116

107-
pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: T) -> Self {
108-
Self::new_(value, bits_per_block, None, padding_bit)
117+
/// Creates a block decomposer that will set the surplus bits to a specific value
118+
/// when bits_per_block is not a multiple of T::BITS
119+
pub fn with_padding_bit(value: T, bits_per_block: u32, padding_bit: PaddingBitValue) -> Self {
120+
Self::new_(value, bits_per_block, None, Some(padding_bit))
109121
}
110122

111123
pub fn new(value: T, bits_per_block: u32) -> Self {
112-
Self::new_(value, bits_per_block, None, T::ZERO)
124+
Self::new_(value, bits_per_block, None, None)
113125
}
114126

115-
fn new_(value: T, bits_per_block: u32, limit: Option<T>, padding_bit: T) -> Self {
127+
fn new_(
128+
value: T,
129+
bits_per_block: u32,
130+
limit: Option<T>,
131+
padding_bit: Option<PaddingBitValue>,
132+
) -> Self {
116133
assert!(bits_per_block <= T::BITS as u32);
117134
let num_bits_valid = T::BITS as u32;
118135

@@ -129,6 +146,31 @@ where
129146
padding_bit,
130147
}
131148
}
149+
150+
// We concretize the iterator type to allow usage of callbacks working on iterator for generic
151+
// integer encryption
152+
pub fn iter_as<V>(self) -> std::iter::Map<Self, fn(T) -> V>
153+
where
154+
V: Numeric,
155+
T: CastInto<V>,
156+
{
157+
assert!(self.num_bits_in_mask <= V::BITS as u32);
158+
self.map(CastInto::cast_into)
159+
}
160+
161+
pub fn next_as<V>(&mut self) -> Option<V>
162+
where
163+
V: CastFrom<T>,
164+
{
165+
self.next().map(|masked| V::cast_from(masked))
166+
}
167+
168+
pub fn checked_next_as<V>(&mut self) -> Option<V>
169+
where
170+
V: TryFrom<T>,
171+
{
172+
self.next().and_then(|masked| V::try_from(masked).ok())
173+
}
132174
}
133175

134176
impl<T> Iterator for BlockDecomposer<T>
@@ -159,11 +201,18 @@ where
159201

160202
if self.num_bits_valid < self.num_bits_in_mask {
161203
// This will be the case when self.num_bits_in_mask is not a multiple
162-
// of T::BITS. We replace bits that
163-
// do not come from the actual T but from the padding
164-
// intoduced by the shift, to a specific value.
165-
for i in self.num_bits_valid..self.num_bits_in_mask {
166-
masked |= self.padding_bit << i;
204+
// of T::BITS.
205+
//
206+
// We replace bits that do not come from the actual T but from the padding
207+
// introduced by the shift, to a specific value, if one was provided.
208+
if let Some(padding_bit) = self.padding_bit {
209+
let padding_mask = (self.bit_mask >> self.num_bits_valid) << self.num_bits_valid;
210+
masked = masked & !padding_mask;
211+
212+
let padding_bit = T::cast_from(padding_bit as u32);
213+
for i in self.num_bits_valid..self.num_bits_in_mask {
214+
masked |= padding_bit << i;
215+
}
167216
}
168217
}
169218

@@ -184,36 +233,6 @@ where
184233
}
185234
}
186235

187-
impl<T> BlockDecomposer<T>
188-
where
189-
T: Decomposable,
190-
{
191-
// We concretize the iterator type to allow usage of callbacks working on iterator for generic
192-
// integer encryption
193-
pub fn iter_as<V>(self) -> std::iter::Map<Self, fn(T) -> V>
194-
where
195-
V: Numeric,
196-
T: CastInto<V>,
197-
{
198-
assert!(self.num_bits_in_mask <= V::BITS as u32);
199-
self.map(CastInto::cast_into)
200-
}
201-
202-
pub fn next_as<V>(&mut self) -> Option<V>
203-
where
204-
V: CastFrom<T>,
205-
{
206-
self.next().map(|masked| V::cast_from(masked))
207-
}
208-
209-
pub fn checked_next_as<V>(&mut self) -> Option<V>
210-
where
211-
V: TryFrom<T>,
212-
{
213-
self.next().and_then(|masked| V::try_from(masked).ok())
214-
}
215-
}
216-
217236
pub struct BlockRecomposer<T> {
218237
data: T,
219238
bit_mask: T,
@@ -310,6 +329,36 @@ mod tests {
310329
assert_eq!(expected_blocks, blocks);
311330
}
312331

332+
#[test]
333+
fn test_bit_block_decomposer_3() {
334+
let bits_per_block = 3;
335+
336+
let value = -1i8;
337+
let blocks = BlockDecomposer::new(value, bits_per_block)
338+
.iter_as::<u64>()
339+
.collect::<Vec<_>>();
340+
// We expect the last block padded with 1s as a consequence of arithmetic shift
341+
let expected_blocks = vec![7, 7, 7];
342+
assert_eq!(expected_blocks, blocks);
343+
344+
let value = i8::MIN;
345+
let blocks = BlockDecomposer::new(value, bits_per_block)
346+
.iter_as::<u64>()
347+
.collect::<Vec<_>>();
348+
// We expect the last block padded with 1s as a consequence of arithmetic shift
349+
let expected_blocks = vec![0, 0, 6];
350+
assert_eq!(expected_blocks, blocks);
351+
352+
let value = -1i8;
353+
let blocks =
354+
BlockDecomposer::with_padding_bit(value, bits_per_block, PaddingBitValue::Zero)
355+
.iter_as::<u64>()
356+
.collect::<Vec<_>>();
357+
// We expect the last block padded with 0s as we force that
358+
let expected_blocks = vec![7, 7, 3];
359+
assert_eq!(expected_blocks, blocks);
360+
}
361+
313362
#[test]
314363
fn test_bit_block_decomposer_recomposer_carry_handling_in_between() {
315364
let value = u16::MAX as u32;

tfhe/src/integer/server_key/radix/scalar_sub.rs

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::core_crypto::prelude::Numeric;
2-
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
2+
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto, PaddingBitValue};
33
use crate::integer::ciphertext::{IntegerRadixCiphertext, RadixCiphertext};
44
use crate::integer::server_key::CheckError;
55
use crate::integer::ServerKey;
@@ -92,17 +92,12 @@ impl ServerKey {
9292
// The only case where these msb could become 0 after the addition
9393
// is if scalar == T::ZERO (=> !T::ZERO == T::MAX => T::MAX + 1 == overflow),
9494
// but this case has been handled earlier.
95-
let padding_bit = 1u32; // To handle when bits is not a multiple of T::BITS
96-
// All bits of message set to one
9795
let pad_block = (1 << bits_in_message as u8) - 1;
9896

99-
let decomposer = BlockDecomposer::with_padding_bit(
100-
neg_scalar,
101-
bits_in_message,
102-
Scalar::cast_from(padding_bit),
103-
)
104-
.iter_as::<u8>()
105-
.chain(std::iter::repeat(pad_block));
97+
let decomposer =
98+
BlockDecomposer::with_padding_bit(neg_scalar, bits_in_message, PaddingBitValue::One)
99+
.iter_as::<u8>()
100+
.chain(std::iter::repeat(pad_block));
106101
Some(decomposer)
107102
}
108103

0 commit comments

Comments
 (0)