@@ -67,8 +67,10 @@ class PruneLowMagnitude(Wrapper):
67
67
Custom keras layers:
68
68
The pruning wrapper can also be applied to a user-defined keras layer.
69
69
Such a layer may contain one or more weight tensors that may be pruned.
70
- To apply pruning wrapper to such layers, set prunable_weight_names to mark
71
- the weight tensors for pruning.
70
+ To apply pruning wrapper to such layers, the layer should be a `PrunableLayer`
71
+ instance or, more directly, user should define a `get_prunable_weights` method
72
+ for the layer (Check the pruning_wrapper_test.CustomLayerPrunable for more
73
+ details about how to define a user-defined prunable layer).
72
74
73
75
Sparsity function:
74
76
The target sparsity for the weight tensors are set through the
@@ -141,7 +143,8 @@ def __init__(self,
141
143
kwargs .update ({'name' : '{}_{}' .format (
142
144
generic_utils .to_snake_case (self .__class__ .__name__ ), layer .name )})
143
145
144
- if isinstance (layer , prunable_layer .PrunableLayer ):
146
+ if isinstance (layer , prunable_layer .PrunableLayer ) or hasattr (
147
+ layer , 'get_prunable_weights' ):
145
148
# Custom layer in client code which supports pruning.
146
149
super (PruneLowMagnitude , self ).__init__ (layer , ** kwargs )
147
150
elif prune_registry .PruneRegistry .supports (layer ):
@@ -151,8 +154,10 @@ def __init__(self,
151
154
else :
152
155
raise ValueError (
153
156
'Please initialize `Prune` with a supported layer. Layers should '
154
- 'either be a `PrunableLayer` instance, or should be supported by the '
155
- 'PruneRegistry. You passed: {input}' .format (input = layer .__class__ ))
157
+ 'either be supported by the PruneRegistry (built-in keras layers) or '
158
+ 'should be a `PrunableLayer` instance, or should has a customer '
159
+ 'defined `get_prunable_weights` method. You passed: '
160
+ '{input}' .format (input = layer .__class__ ))
156
161
157
162
self ._track_trackable (layer , name = 'layer' )
158
163
0 commit comments