Skip to content

Commit 1e68317

Browse files
committed
Internal change.
PiperOrigin-RevId: 492381485
1 parent c0c87ec commit 1e68317

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

official/vision/modeling/layers/edgetpu.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,14 +101,15 @@ def _same(x):
101101
return tf.constant(1, dtype=x.dtype) + tf.math.floor(-x_clip)
102102

103103

104-
def shard_tensors(axis: int, block_size: int,
105-
*tensors: tf.Tensor) -> Iterable[Sequence[tf.Tensor]]:
104+
def shard_tensors(
105+
axis: int, block_size: int,
106+
tensors: Sequence[tf.Tensor]) -> Iterable[Sequence[tf.Tensor]]:
106107
"""Consistently splits multiple tensors sharding-style.
107108
108109
Args:
109110
axis: axis to be used to split tensors
110111
block_size: block size to split tensors.
111-
*tensors: list of tensors.
112+
tensors: list of tensors.
112113
113114
Returns:
114115
List of shards, each shard has exactly one peace of each input tesnor.
@@ -211,7 +212,7 @@ def non_max_suppression_padded(boxes: tf.Tensor,
211212
scores = tf.reshape(scores, [batch_size, boxes_size])
212213
block = max(1, _RECOMMENDED_NMS_MEMORY // (boxes_size * boxes_size))
213214
indices = []
214-
for boxes_i, scores_i in shard_tensors(0, block, boxes, scores):
215+
for boxes_i, scores_i in shard_tensors(0, block, (boxes, scores)):
215216
indices.append(
216217
_non_max_suppression_as_is(boxes_i, scores_i, output_size,
217218
iou_threshold, refinements))

official/vision/modeling/layers/edgetpu_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_shard_tensors(self):
191191
[15, 16, 17, 18, 19],
192192
[20, 21, 22, 23, 24],
193193
]])
194-
for i, (a_i, b_i) in enumerate(edgetpu.shard_tensors(1, 3, a, b)):
194+
for i, (a_i, b_i) in enumerate(edgetpu.shard_tensors(1, 3, (a, b))):
195195
self.assertAllEqual(a_i, a[:, i * 3:i * 3 + 3])
196196
self.assertAllEqual(b_i, b[:, i * 3:i * 3 + 3, :])
197197

@@ -233,7 +233,7 @@ def test_top_k_sharded_fusion_vs_top_k_unsharded(self, axis: int):
233233
shape=axis * [1] + [10000], dtype=tf.float32)
234234
top_1000_direct: tf.Tensor = tf.math.top_k(sample, 1000).values
235235
top_1000_sharded: Optional[tf.Tensor] = None
236-
for (piece,) in edgetpu.shard_tensors(axis, 1500, sample):
236+
for (piece,) in edgetpu.shard_tensors(axis, 1500, (sample,)):
237237
(top_1000_sharded,) = edgetpu.concat_and_top_k(
238238
1000, (top_1000_sharded, piece))
239239
self.assertAllEqual(top_1000_direct, top_1000_sharded)

0 commit comments

Comments
 (0)