Skip to content

Commit d4c1c8c

Browse files
authored
fix: embedding layer support in client server (#1149)
1 parent 9d9669d commit d4c1c8c

File tree

3 files changed

+72
-1
lines changed

3 files changed

+72
-1
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ filterwarnings = [
153153
"ignore:`np\\.object` is a deprecated alias for the builtin `object`\\. To silence this warning, use `object` by itself\\. Doing this will not modify any behavior and is safe\\.:DeprecationWarning",
154154
"ignore:Using or importing the ABCs from 'collections' instead of from 'collections\\.abc' is deprecated.*:DeprecationWarning",
155155
"ignore: distutils Version classes are deprecated. Use packaging\\.version instead.*:DeprecationWarning",
156+
"ignore:The distutils package is deprecated and slated for removal in Python 3\\.12.*:DeprecationWarning",
157+
"ignore:Distutils was imported before Setuptools.*:UserWarning",
158+
"ignore:Setuptools is replacing distutils.*:UserWarning",
159+
"ignore:The distutils\\.sysconfig module is deprecated.*:DeprecationWarning",
156160
"ignore: forcing n_jobs = 1 on mac for segfault issue",
157161
"ignore: allowzero=0 by default.*:UserWarning",
158162
"ignore:Implicitly cleaning up:ResourceWarning",

src/concrete/ml/deployment/fhe_client_server.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from ..common.serialization.loaders import load
1919
from ..common.utils import CiphertextFormat, to_tuple
2020
from ..quantization import QuantizedModule
21+
from ..torch.numpy_module import NumpyModule
2122
from ..version import __version__ as CML_VERSION
2223
from ._utils import deserialize_encrypted_values, serialize_encrypted_values
2324

@@ -270,8 +271,13 @@ def _export_model_to_json(self, is_training: bool = False) -> Path:
270271
"output_quantizers": module_to_export.output_quantizers,
271272
"is_training": is_training,
272273
"ciphertext_format": module_to_export.ciphertext_format,
274+
"onnx_preprocessing": None,
273275
}
274276

277+
preprocessing_module = getattr(module_to_export, "_preprocessing_module", None)
278+
if preprocessing_module is not None:
279+
serialized_processing["onnx_preprocessing"] = preprocessing_module.onnx_model
280+
275281
# Export the `is_fitted` attribute for built-in models
276282
if hasattr(self.model, "is_fitted"):
277283
serialized_processing["is_fitted"] = self.model.is_fitted
@@ -416,6 +422,13 @@ def load(self): # pylint: disable=no-value-for-parameter
416422

417423
self.model.ciphertext_format = serialized_processing["ciphertext_format"]
418424

425+
if hasattr(self.model, "_preprocessing_module"):
426+
onnx_preprocessing = serialized_processing.get("onnx_preprocessing")
427+
# pylint: disable-next=protected-access
428+
self.model._preprocessing_module = (
429+
NumpyModule(onnx_preprocessing) if onnx_preprocessing is not None else None
430+
)
431+
419432
# Load model parameters
420433
# Add some checks on post-processing-params
421434
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3131
@@ -487,6 +500,14 @@ def quantize_encrypt_serialize(
487500
Union[bytes, Tuple[bytes, ...]]: The quantized, encrypted and serialized values.
488501
"""
489502

503+
# Apply the same preprocessing as during standard forward passes when available so that
504+
# inputs expected by the FHE circuit (e.g., one-hot vectors for optimized embeddings) are
505+
# generated here too.
506+
if hasattr(self.model, "pre_processing"):
507+
x = to_tuple(self.model.pre_processing(*x))
508+
else:
509+
x = to_tuple(x)
510+
490511
# Quantize the values
491512
x_quant = to_tuple(self.model.quantize_input(*x))
492513

tests/deployment/test_client_server.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy
1212
import pytest
13+
import torch
1314
from torch import nn
1415

1516
from concrete import fhe
@@ -20,7 +21,7 @@
2021
FHEModelDev,
2122
FHEModelServer,
2223
)
23-
from concrete.ml.pytest.torch_models import FCSmall
24+
from concrete.ml.pytest.torch_models import EmbeddingModel, FCSmall
2425
from concrete.ml.pytest.utils import (
2526
MODELS_AND_DATASETS,
2627
_get_sklearn_tree_models,
@@ -283,6 +284,51 @@ def test_client_server_custom_model(
283284
)
284285

285286

287+
def test_client_server_torch_embedding_model(default_configuration):
288+
"""Ensure client/server flow works for a torch model containing an embedding."""
289+
290+
torch.manual_seed(0)
291+
num_embeddings = 6
292+
embedding_dim = 3
293+
seq_len = 2
294+
295+
torch_model = EmbeddingModel(num_embeddings, embedding_dim)
296+
torch_model.eval()
297+
298+
torch_inputset = torch.randint(0, num_embeddings, size=(8, seq_len)).long()
299+
sample = torch_inputset[:1].numpy()
300+
301+
quantized_module = compile_torch_model(
302+
torch_model,
303+
torch_inputset,
304+
configuration=default_configuration,
305+
n_bits=2,
306+
rounding_threshold_bits=2,
307+
)
308+
309+
network = OnDiskNetwork()
310+
fhe_model_dev = FHEModelDev(path_dir=network.dev_dir.name, model=quantized_module)
311+
fhe_model_dev.save()
312+
network.dev_send_clientspecs_and_modelspecs_to_client()
313+
network.dev_send_model_to_server()
314+
315+
key_dir = default_configuration.insecure_key_cache_location
316+
fhe_model_client = FHEModelClient(path_dir=network.client_dir.name, key_dir=key_dir)
317+
fhe_model_client.generate_private_and_evaluation_keys(force=True)
318+
evaluation_keys = fhe_model_client.get_serialized_evaluation_keys(include_tfhers_key=False)
319+
fhe_model_server = FHEModelServer(path_dir=network.server_dir.name)
320+
321+
q_x_encrypted_serialized = fhe_model_client.quantize_encrypt_serialize(sample)
322+
q_y_encrypted_serialized = fhe_model_server.run(q_x_encrypted_serialized, evaluation_keys)
323+
324+
client_result = fhe_model_client.deserialize_decrypt_dequantize(
325+
*to_tuple(q_y_encrypted_serialized)
326+
)
327+
simulate_result = quantized_module.forward(sample, fhe="simulate")
328+
329+
numpy.testing.assert_allclose(client_result, simulate_result, atol=1e-2)
330+
331+
286332
def check_client_server_files(model, mode="inference"):
287333
"""Test the client server interface API generates the expected file.
288334

0 commit comments

Comments
 (0)