Skip to content

Commit 2d17498

Browse files
committed
chore: add deserialization tests with corrupted inputs
1 parent 195ca69 commit 2d17498

21 files changed

+339
-7
lines changed

.github/workflows/aws_tfhe_backward_compat_tests.yml renamed to .github/workflows/aws_data_tests.yml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
# Run backward compatibility tests
2-
name: aws_tfhe_backward_compat_tests
1+
# Run data related tests
2+
name: aws_data_tests
33

44
env:
55
CARGO_TERM_COLOR: always
@@ -30,8 +30,8 @@ permissions:
3030
# zizmor: ignore[concurrency-limits] concurrency is managed after instance setup to ensure safe provisioning
3131

3232
jobs:
33-
backward-compat-tests:
34-
name: aws_tfhe_backward_compat_tests/backward-compat-tests (bpr)
33+
data-tests:
34+
name: aws_data_tests/data-tests (bpr)
3535
if: (github.event_name == 'push' && github.repository == 'zama-ai/tfhe-rs') ||
3636
github.event_name != 'push'
3737
runs-on: "runs-on=${{ github.run_id }}/runner=cpu-small"
@@ -49,7 +49,7 @@ jobs:
4949
- name: Get LFS data sha
5050
id: hash-lfs-data
5151
run: |
52-
SHA=$(git lfs ls-files -l -I utils/tfhe-backward-compat-data | sha256sum | cut -d' ' -f1)
52+
SHA=$(git lfs ls-files -l -I utils/tfhe-backward-compat-data,tests/corrupted_inputs_deserialization | sha256sum | cut -d' ' -f1)
5353
echo "sha=${SHA}" >> "${GITHUB_OUTPUT}"
5454
5555
- name: Retrieve data from cache
@@ -59,12 +59,14 @@ jobs:
5959
path: |
6060
utils/tfhe-backward-compat-data/**/*.cbor
6161
utils/tfhe-backward-compat-data/**/*.bcode
62+
tests/corrupted_inputs_deserialization/**/*.bcode
6263
key: ${{ steps.hash-lfs-data.outputs.sha }}
6364

6465
- name: Pull test data
6566
if: steps.retrieve-data-cache.outputs.cache-hit != 'true'
6667
run: |
6768
make pull_backward_compat_data
69+
make pull_corrupted_inputs_data
6870
6971
# Pull token was stored by action/checkout to be used by lfs, we don't need it anymore
7072
- name: Remove git credentials
@@ -80,6 +82,10 @@ jobs:
8082
run: |
8183
make test_backward_compatibility_ci
8284
85+
- name: Run corrupted inputs deserialization tests
86+
run: |
87+
make test_corrupted_inputs_ci
88+
8389
- name: Store data in cache
8490
if: steps.retrieve-data-cache.outputs.cache-hit != 'true'
8591
continue-on-error: true
@@ -88,6 +94,7 @@ jobs:
8894
path: |
8995
utils/tfhe-backward-compat-data/**/*.cbor
9096
utils/tfhe-backward-compat-data/**/*.bcode
97+
tests/corrupted_inputs_deserialization/**/*.bcode
9198
key: ${{ steps.hash-lfs-data.outputs.sha }}
9299

93100
- name: Set pull-request URL

.linelint.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ ignore:
1111
- coverage
1212
- utils/tfhe-lints/tests/*/main.stderr
1313
- utils/tfhe-backward-compat-data/**/*.ron # ron files are autogenerated
14+
- tests/corrupted_inputs_deserialization/data/proven_compact_list/**/metadata.txt
1415

1516
rules:
1617
# checks if file ends in a newline character

Makefile

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ BENCH_CUSTOM_COMMAND:=
2626
NODE_VERSION=24.12
2727
BACKWARD_COMPAT_DATA_DIR=utils/tfhe-backward-compat-data
2828
BACKWARD_COMPAT_DATA_GEN_VERSION:=$(TFHE_VERSION)
29+
CORRUPTED_INPUTS_TEST=tests/corrupted_inputs_deserialization
2930
TEST_VECTORS_DIR=apps/test-vectors
3031
CURRENT_TFHE_VERSION:=$(shell grep '^version[[:space:]]*=' tfhe/Cargo.toml | cut -d '=' -f 2 | xargs)
3132
WASM_PACK_VERSION="0.13.1"
@@ -498,7 +499,7 @@ clippy_trivium: install_rs_check_toolchain
498499
.PHONY: clippy_ws_tests # Run clippy on the workspace level tests
499500
clippy_ws_tests: install_rs_check_toolchain
500501
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --tests \
501-
-p tests --features=shortint,integer,zk-pok -- --no-deps -D warnings
502+
-p tests --features=shortint,integer,zk-pok,strings -- --no-deps -D warnings
502503

503504
.PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.)
504505
clippy_all_targets: install_rs_check_toolchain
@@ -1253,11 +1254,19 @@ new_backward_compat_crate:
12531254
.PHONY: test_backward_compatibility_ci
12541255
test_backward_compatibility_ci:
12551256
TFHE_BACKWARD_COMPAT_DATA_DIR="../$(BACKWARD_COMPAT_DATA_DIR)" RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
1256-
--features=shortint,integer,zk-pok -p tests test_backward_compatibility -- --nocapture
1257+
--features=shortint,integer,zk-pok,strings -p tests test_backward_compatibility -- --nocapture
12571258

12581259
.PHONY: test_backward_compatibility # Same as test_backward_compatibility_ci but tries to clone the data repo first if needed
12591260
test_backward_compatibility: pull_backward_compat_data test_backward_compatibility_ci
12601261

1262+
.PHONY: test_corrupted_inputs_ci
1263+
test_corrupted_inputs_ci:
1264+
RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \
1265+
--features=integer,zk-pok,strings -p tests test_corrupted_inputs_deserialization -- --nocapture
1266+
1267+
.PHONY: test_corrupted_inputs # Same as test_corrupted_inputs_ci but pulls data first
1268+
test_corrupted_inputs: pull_corrupted_inputs_data test_corrupted_inputs_ci
1269+
12611270
# Generate the test vectors and update the hash file
12621271
.PHONY: gen_test_vectors
12631272
gen_test_vectors:
@@ -2041,6 +2050,10 @@ write_params_to_file: install_rs_check_toolchain
20412050
pull_backward_compat_data:
20422051
./scripts/pull_lfs_data.sh $(BACKWARD_COMPAT_DATA_DIR)
20432052

2053+
.PHONY: pull_corrupted_inputs_data # Pull the data files needed for corrupted inputs deserialization tests
2054+
pull_corrupted_inputs_data:
2055+
./scripts/pull_lfs_data.sh $(CORRUPTED_INPUTS_TEST)
2056+
20442057
.PHONY: pull_hpu_files # Pull the hpu files
20452058
pull_hpu_files:
20462059
./scripts/pull_lfs_data.sh backends/tfhe-hpu-backend/

tests/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ cargo_toml = "0.22"
1616
name = "backward_compatibility_tests"
1717
path = "backward_compatibility_tests.rs"
1818

19+
[[test]]
20+
name = "corrupted_inputs_deserialization"
21+
path = "corrupted_inputs_deserialization.rs"
22+
1923
[features]
2024
shortint = ["tfhe/shortint"]
2125
integer = ["shortint", "tfhe/integer"]
2226
zk-pok = ["tfhe/zk-pok"]
27+
strings = ["integer", "tfhe/strings"]
Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
//! This test tries to load various kind of corrupted serialized inputs and see that they are
2+
//! handled without crashes.
3+
4+
use std::fs::File;
5+
use std::io::Read;
6+
use std::ops::{Add, Mul, Sub};
7+
use std::path::Path;
8+
9+
use tfhe::conformance::ListSizeConstraint;
10+
use tfhe::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
11+
use tfhe::prelude::{CiphertextList, Tagged};
12+
use tfhe::safe_serialization::{safe_deserialize, safe_deserialize_conformant};
13+
use tfhe::zk::CompactPkeCrs;
14+
use tfhe::{
15+
set_server_key, CompactCiphertextList, CompactCiphertextListConformanceParams,
16+
CompactCiphertextListExpander, CompactPublicKey, FheAsciiString, HlExpandable,
17+
ProvenCompactCiphertextList, ServerKey,
18+
};
19+
20+
const DATA_DIR: &str = "./corrupted_inputs_deserialization/data";
21+
const AUX_DIR: &str = "aux_data";
22+
const GIGA: u64 = 1024 * 1024 * 1024;
23+
24+
fn load_server_key(aux_dir: &Path) -> ServerKey {
25+
let path = aux_dir.join("server_key.bcode");
26+
let f = File::open(&path).unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()));
27+
safe_deserialize(f, 4 * GIGA).unwrap()
28+
}
29+
30+
fn load_crs(aux_dir: &Path) -> CompactPkeCrs {
31+
let path = aux_dir.join("crs.bcode");
32+
let f = File::open(&path).unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()));
33+
safe_deserialize(f, 4 * GIGA).unwrap()
34+
}
35+
36+
fn load_public_key(aux_dir: &Path) -> CompactPublicKey {
37+
let path = aux_dir.join("pubkey.bcode");
38+
let f = File::open(&path).unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()));
39+
safe_deserialize(f, 4 * GIGA).unwrap()
40+
}
41+
42+
fn load_metadata(aux_dir: &Path) -> Vec<u8> {
43+
let path = aux_dir.join("metadata.txt");
44+
std::fs::read(&path).unwrap_or_else(|e| panic!("failed to open {}: {e}", path.display()))
45+
}
46+
47+
fn list_subdirs(dir: &Path) -> Vec<std::path::PathBuf> {
48+
std::fs::read_dir(dir)
49+
.unwrap_or_else(|e| panic!("failed to read {}: {e}", dir.display()))
50+
.filter_map(|entry| {
51+
let entry = entry.unwrap();
52+
entry.file_type().unwrap().is_dir().then(|| entry.path())
53+
})
54+
.collect()
55+
}
56+
57+
/// Read all .bcode files in `dir` and call `handler` on each file.
58+
fn process_inputs(dir: &Path, mut handler: impl FnMut(&[u8])) -> u64 {
59+
let mut total_tests = 0;
60+
let entries =
61+
std::fs::read_dir(dir).unwrap_or_else(|e| panic!("failed to read {}: {e}", dir.display()));
62+
63+
for entry in entries {
64+
let path = entry.unwrap().path();
65+
if path.extension().and_then(|e| e.to_str()) == Some("bcode") {
66+
println!("Processing {}", path.display());
67+
total_tests += 1;
68+
let input = std::fs::read(&path)
69+
.unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
70+
handler(&input);
71+
}
72+
}
73+
74+
total_tests
75+
}
76+
77+
fn test_integer<FheType>(exp: &CompactCiphertextListExpander, i: usize)
78+
where
79+
FheType: HlExpandable
80+
+ Tagged
81+
+ Clone
82+
+ Add<FheType, Output = FheType>
83+
+ Sub<FheType, Output = FheType>
84+
+ Mul<FheType, Output = FheType>,
85+
{
86+
let ct = match exp.get::<FheType>(i) {
87+
Ok(Some(ct)) => ct,
88+
Ok(None) => {
89+
println!("No ct found at idx {}\n", i);
90+
return;
91+
}
92+
Err(e) => {
93+
println!("Error caught while trying to get ct at idx {}:\n{e}\n", i);
94+
return;
95+
}
96+
};
97+
let res = ct.clone() + ct.clone();
98+
let res = res * ct.clone();
99+
let _ = std::hint::black_box(res - ct);
100+
}
101+
102+
fn test_string(exp: &CompactCiphertextListExpander, i: usize) {
103+
let ct = match exp.get::<FheAsciiString>(i) {
104+
Ok(Some(ct)) => ct,
105+
Ok(None) => {
106+
println!("No ct found at idx {i}\n");
107+
return;
108+
}
109+
Err(e) => {
110+
println!("Error caught while trying to get ct at idx {i}:\n{e}\n");
111+
return;
112+
}
113+
};
114+
115+
let _ = std::hint::black_box(ct.len());
116+
}
117+
118+
/// Try to expand each element and perform some operations, dispatching on the
119+
/// actual type reported by the expander.
120+
fn use_list(exp: &CompactCiphertextListExpander) {
121+
for i in 0..exp.len() {
122+
let Some(kind) = exp.get_kind_of(i) else {
123+
println!("No metadata for ct at idx {i}\n");
124+
return;
125+
};
126+
127+
match kind {
128+
FheTypes::Uint8 => test_integer(exp, i),
129+
FheTypes::Uint16 => test_integer(exp, i),
130+
FheTypes::Uint32 => test_integer(exp, i),
131+
FheTypes::Uint64 => test_integer(exp, i),
132+
FheTypes::Int8 => test_integer(exp, i),
133+
FheTypes::Int16 => test_integer(exp, i),
134+
FheTypes::Int32 => test_integer(exp, i),
135+
FheTypes::Int64 => test_integer(exp, i),
136+
FheTypes::AsciiString => test_string(exp, i),
137+
other => panic!(
138+
"unsupported FheTypes variant {other:?} at index {i}, \
139+
this test should be updated to handle it"
140+
),
141+
}
142+
}
143+
}
144+
145+
fn handle_ct_list(input: &[u8], conformance_params: &CompactCiphertextListConformanceParams) {
146+
let ct_list: CompactCiphertextList = match safe_deserialize_conformant::<CompactCiphertextList>(
147+
input,
148+
4 * GIGA,
149+
conformance_params,
150+
) {
151+
Ok(ct_list) => ct_list,
152+
Err(e) => {
153+
println!("Error caught during deserialization:\n{e}\n");
154+
return;
155+
}
156+
};
157+
158+
let exp = match ct_list.expand() {
159+
Ok(exp) => exp,
160+
Err(e) => {
161+
println!("Error caught during expand:\n{e}\n");
162+
return;
163+
}
164+
};
165+
166+
use_list(&exp);
167+
println!("List used without error\n")
168+
}
169+
170+
fn handle_proven_ct_list(
171+
input: &[u8],
172+
conformance_params: &IntegerProvenCompactCiphertextListConformanceParams,
173+
crs: &CompactPkeCrs,
174+
public_key: &CompactPublicKey,
175+
metadata: &[u8],
176+
) {
177+
let ct_list: ProvenCompactCiphertextList = match safe_deserialize_conformant::<
178+
ProvenCompactCiphertextList,
179+
>(input, 4 * GIGA, conformance_params)
180+
{
181+
Ok(ct_list) => ct_list,
182+
Err(e) => {
183+
println!("Error caught during deserialization:\n{e}\n");
184+
return;
185+
}
186+
};
187+
188+
let exp = match ct_list.verify_and_expand(crs, public_key, metadata) {
189+
Ok(exp) => exp,
190+
Err(e) => {
191+
println!("Error caught during verify_and_expand:\n{e}\n");
192+
return;
193+
}
194+
};
195+
196+
use_list(&exp);
197+
println!("List used without error\n")
198+
}
199+
200+
#[test]
201+
fn test_corrupted_inputs_deserialization() {
202+
let mut total_tests = 0;
203+
let data_dir = Path::new(DATA_DIR);
204+
205+
let compact_list_dir = data_dir.join("compact_list");
206+
for group_dir in list_subdirs(&compact_list_dir) {
207+
println!("compact_list group: {}", group_dir.display());
208+
let aux_dir = group_dir.join(AUX_DIR);
209+
210+
let server_key = load_server_key(&aux_dir);
211+
let pubkey = load_public_key(&aux_dir);
212+
213+
let cpk_conformance_params =
214+
CompactCiphertextListConformanceParams::from_parameters_and_size_constraint(
215+
pubkey.parameters(),
216+
ListSizeConstraint::try_size_in_range(4, usize::MAX).unwrap(),
217+
)
218+
.allow_unpacked();
219+
220+
set_server_key(server_key);
221+
222+
total_tests += process_inputs(&group_dir, |input| {
223+
handle_ct_list(input, &cpk_conformance_params);
224+
});
225+
}
226+
227+
let proven_compact_list_dir = data_dir.join("proven_compact_list");
228+
for group_dir in list_subdirs(&proven_compact_list_dir) {
229+
println!("proven_compact_list group: {}", group_dir.display());
230+
let aux_dir = group_dir.join(AUX_DIR);
231+
232+
let server_key = load_server_key(&aux_dir);
233+
let pubkey = load_public_key(&aux_dir);
234+
let crs = load_crs(&aux_dir);
235+
let metadata = load_metadata(&aux_dir);
236+
237+
let proven_cpk_conformance_params =
238+
IntegerProvenCompactCiphertextListConformanceParams::from_public_key_encryption_parameters_and_crs_parameters(
239+
pubkey.parameters(),
240+
&crs,
241+
)
242+
.allow_unpacked();
243+
244+
set_server_key(server_key);
245+
246+
total_tests += process_inputs(&group_dir, |input| {
247+
handle_proven_ct_list(
248+
input,
249+
&proven_cpk_conformance_params,
250+
&crs,
251+
&pubkey,
252+
&metadata,
253+
);
254+
});
255+
}
256+
257+
println!("Executed {} tests", total_tests);
258+
// If we ran 0 test, it is likely that something wrong happened
259+
assert!(total_tests != 0);
260+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:59a6aeba2a010e1b7f6aa57d9a1dd5b40716e01ec4c5a2b0aa35710657ba77c5
3+
size 762

0 commit comments

Comments
 (0)