diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py index 1d0c4154a..27bdb2057 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py @@ -285,6 +285,13 @@ def quantize_annotate_layer(to_annotate, quantize_config=None): layer=to_annotate, quantize_config=quantize_config) +def _clone_model_with_weights(model_to_clone): + cloned_model = keras.models.clone_model(model_to_clone) + cloned_model.set_weights(model_to_clone.get_weights()) + + return cloned_model + + @metrics.MonitorBoolGauge('quantize_apply_usage') def quantize_apply( model, @@ -361,12 +368,6 @@ def quantize_apply( 'been built yet. Please call `model.build(input_shape)` ' 'before quantizing your model.') - def _clone_model_with_weights(model_to_clone): - cloned_model = keras.models.clone_model(model_to_clone) - cloned_model.set_weights(model_to_clone.get_weights()) - - return cloned_model - def _extract_original_model(model_to_unwrap): """Extracts original model by removing wrappers.""" layer_quantize_map = {}