@@ -101,14 +101,15 @@ def _same(x):
   return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
 
 
-def shard_tensors(axis: int, block_size: int,
-                  *tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
+def shard_tensors(
+    axis: int, block_size: int,
+    tensors: Sequence[tf.Tensor]) -> Iterable[Sequence[tf.Tensor]]:
   """Consistently splits multiple tensors sharding-style.
 
   Args:
     axis: axis to be used to split tensors.
     block_size: block size to split tensors.
-    *tensors: list of tensors.
+    tensors: sequence of tensors.
 
   Returns:
     List of shards; each shard has exactly one piece of each input tensor.
@@ -211,7 +212,7 @@ def non_max_suppression_padded(boxes: tf.Tensor,
   scores = tf.reshape(scores, [batch_size, boxes_size])
   block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
   indices = []
-  for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
+  for boxes_i, scores_i in shard_tensors(0, block, (boxes, scores)):
     indices.append(
         _non_max_suppression_as_is(boxes_i, scores_i, output_size,
                                    iou_threshold, refinements))
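
For context, this commit changes `shard_tensors` from a varargs parameter (`*tensors: tf.Tensor`) to a single `tensors: Sequence[tf.Tensor]` parameter, so call sites now pass one tuple instead of unpacked arguments. Below is a minimal sketch of the new calling convention; the slicing logic inside `shard_tensors` is an illustrative assumption, not the actual implementation from this file:

```python
# Sketch only: approximates shard_tensors semantics described in the docstring
# (consistent splitting of several tensors along one axis); the real function
# in this commit may be implemented differently.
from typing import Iterable, Sequence

import tensorflow as tf


def shard_tensors(
    axis: int, block_size: int,
    tensors: Sequence[tf.Tensor]) -> Iterable[Sequence[tf.Tensor]]:
  """Yields aligned slices of size `block_size` along `axis` from each tensor."""
  size = tensors[0].shape[axis]
  for start in range(0, size, block_size):
    # Each shard holds exactly one piece of every input tensor,
    # all sliced over the same index range.
    yield [
        tf.slice(
            t,
            [start if d == axis else 0 for d in range(t.shape.rank)],
            [min(block_size, size - start) if d == axis else -1
             for d in range(t.shape.rank)])
        for t in tensors
    ]


boxes = tf.random.uniform([8, 100, 4])
scores = tf.random.uniform([8, 100])
# Note the tuple: tensors are passed as one sequence argument,
# matching the updated call site in the second hunk above.
for boxes_i, scores_i in shard_tensors(0, 3, (boxes, scores)):
  print(boxes_i.shape, scores_i.shape)
```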