@@ -305,6 +305,34 @@ def get_activations_and_quantizers(self, layer):
305
305
return [(getattr (layer , activation_attr ), self .activation_quantizer )
306
306
for activation_attr in self .activation_attrs ]
307
307
308
+ def set_quantize_weights (self , layer , quantize_weights ):
309
+ if len (self .weight_attrs ) != len (quantize_weights ):
310
+ raise ValueError (
311
+ '`set_quantize_weights` called on layer {} with {} '
312
+ 'weight parameters, but layer expects {} values.' .format (
313
+ layer .name , len (quantize_weights ), len (self .weight_attrs )))
314
+
315
+ for weight_attr , weight in zip (self .weight_attrs , quantize_weights ):
316
+ current_weight = getattr (layer , weight_attr )
317
+ if current_weight .shape != weight .shape :
318
+ raise ValueError ('Existing layer weight shape {} is incompatible with'
319
+ 'provided weight shape {}' .format (
320
+ current_weight .shape , weight .shape ))
321
+
322
+ setattr (layer , weight_attr , weight )
323
+
324
+ def set_quantize_activations (self , layer , quantize_activations ):
325
+ if len (self .activation_attrs ) != len (quantize_activations ):
326
+ raise ValueError (
327
+ '`set_quantize_activations` called on layer {} with {} '
328
+ 'activation parameters, but layer expects {} values.' .format (
329
+ layer .name , len (quantize_activations ),
330
+ len (self .activation_attrs )))
331
+
332
+ for activation_attr , activation in \
333
+ zip (self .activation_attrs , quantize_activations ):
334
+ setattr (layer , activation_attr , activation )
335
+
308
336
309
337
class TFLiteQuantizeProviderRNN (TFLiteQuantizeProvider , _RNNHelper ):
310
338
"""QuantizeProvider for RNN layers."""
@@ -328,3 +356,49 @@ def get_activations_and_quantizers(self, layer):
328
356
(getattr (rnn_cell , activation_attr ), self .activation_quantizer ))
329
357
330
358
return activations_quantizers
359
+
360
+ def _flatten (self , list_of_lists ):
361
+ flat_list = []
362
+ for sublist in list_of_lists :
363
+ for item in sublist :
364
+ flat_list .append (item )
365
+ return flat_list
366
+
367
+ def set_quantize_weights (self , layer , quantize_weights ):
368
+ flattened_weight_attrs = self ._flatten (self .weight_attrs )
369
+ if len (flattened_weight_attrs ) != len (quantize_weights ):
370
+ raise ValueError (
371
+ '`set_quantize_weights` called on layer {} with {} '
372
+ 'weight parameters, but layer expects {} values.' .format (
373
+ layer .name , len (quantize_weights ), len (flattened_weight_attrs )))
374
+
375
+ i = 0
376
+ for weight_attrs_cell , rnn_cell in \
377
+ zip (self .weight_attrs , self ._get_rnn_cells (layer )):
378
+ for weight_attr in weight_attrs_cell :
379
+ current_weight = getattr (rnn_cell , weight_attr )
380
+ quantize_weight = quantize_weights [i ]
381
+
382
+ if current_weight .shape != quantize_weight .shape :
383
+ raise ValueError ('Existing layer weight shape {} is incompatible with'
384
+ 'provided weight shape {}' .format (
385
+ current_weight .shape , quantize_weight .shape ))
386
+
387
+ setattr (rnn_cell , weight_attr , quantize_weight )
388
+ i += 1
389
+
390
+ def set_quantize_activations (self , layer , quantize_activations ):
391
+ flattened_activation_attrs = self ._flatten (self .activation_attrs )
392
+ if len (flattened_activation_attrs ) != len (quantize_activations ):
393
+ raise ValueError (
394
+ '`set_quantize_activations` called on layer {} with {} '
395
+ 'activation parameters, but layer expects {} values.' .format (
396
+ layer .name , len (quantize_activations ),
397
+ len (flattened_activation_attrs )))
398
+
399
+ i = 0
400
+ for activation_attrs_cell , rnn_cell in \
401
+ zip (self .activation_attrs , self ._get_rnn_cells (layer )):
402
+ for activation_attr in activation_attrs_cell :
403
+ setattr (rnn_cell , activation_attr , quantize_activations [i ])
404
+ i += 1
0 commit comments