@@ -62,8 +62,8 @@ ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) {
62
62
63
63
xla::XlaOp GetPromotedMask (xla::XlaOp mask, const xla::Shape& input_shape) {
64
64
const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp (mask);
65
- xla::Shape promoted_mask_shape =
66
- GetValueOrThrow ( XlaHelpers::GetPromotedShape (mask_shape, input_shape));
65
+ XLA_ASSIGN_OR_THROW ( xla::Shape promoted_mask_shape,
66
+ XlaHelpers::GetPromotedShape (mask_shape, input_shape));
67
67
return XlaHelpers::ImplicitBroadcast (mask, mask_shape, promoted_mask_shape);
68
68
}
69
69
@@ -150,7 +150,9 @@ xla::XlaComputation MakeScatterComputation(
150
150
if (combiner != nullptr ) {
151
151
result = combiner (p0, result);
152
152
}
153
- return GetValueOrThrow (cb.Build (result));
153
+ XLA_ASSIGN_OR_THROW (xla::XlaComputation scatter_computation,
154
+ cb.Build (result));
155
+ return scatter_computation;
154
156
}
155
157
156
158
xla::XlaOp CreateIndexAlongDim (
@@ -543,8 +545,8 @@ std::vector<xla::XlaOp> CreateBroadcastTensors(
543
545
for (const xla::XlaOp operand : operands) {
544
546
const xla::Shape& operand_shape = ShapeHelper::ShapeOfXlaOp (operand);
545
547
operand_shapes.push_back (operand_shape);
546
- result_shape = GetValueOrThrow (
547
- XlaHelpers::GetPromotedShape ( result_shape, operand_shape));
548
+ XLA_ASSIGN_OR_THROW ( result_shape, XlaHelpers::GetPromotedShape (
549
+ result_shape, operand_shape));
548
550
}
549
551
std::vector<xla::XlaOp> result;
550
552
for (size_t i = 0 ; i < operands.size (); ++i) {
@@ -1366,54 +1368,59 @@ std::vector<xla::XlaOp> BuildBoxSelectionLoop(int64_t num_boxes,
1366
1368
// 3. The actual IoU threshold matrix.
1367
1369
init_values[2 ] = iou_threshold_mask;
1368
1370
1369
- return GetValueOrThrow (xla::WhileLoopHelper (
1370
- [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
1371
- xla::XlaOp box_index = values[0 ];
1372
- // Check: current loop counter is within bounds, i.e. has a
1373
- // corresponding box.
1374
- return xla::Lt (box_index,
1375
- xla::ConstantR0<IndexType>(builder, num_boxes));
1376
- },
1377
- [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
1378
- const xla::XlaOp ONE = xla::One (builder, XLAIndexType);
1379
- const xla::XlaOp ZERO = xla::Zero (builder, XLAIndexType);
1380
-
1381
- xla::XlaOp box_index = values[0 ];
1382
- xla::XlaOp state = values[1 ];
1383
- xla::XlaOp iou_threshold_mask = values[2 ];
1384
-
1385
- // Retrieve the IoU mask row corresponding to this box.
1386
- xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice (
1387
- iou_threshold_mask, {box_index, ZERO}, {1 , num_boxes});
1388
-
1389
- // Update the current state with the IoU mask.
1390
- // Basically, sets to false every box X whose IoU with the current box
1391
- // is less-than or equal than the given threshold.
1392
- xla::XlaOp updated_state = xla::And (
1393
- state,
1394
- // Update the mask so that if we select this box
1395
- // (i.e. state[box] == true), we don't de-select it.
1396
- xla::DynamicUpdateSlice (
1397
- // Before that, we need to pre-process the mask.
1398
- // 1. Negate the mask: if this box is selected, we only want
1399
- // those that have a low intersection ratio.
1400
- // 2. Reshape it to: [num_boxes].
1401
- xla::Reshape (xla::Not (box_iou_threshold_mask), {num_boxes}),
1402
- xla::ConstantR1<bool >(builder, {true }), {box_index}));
1403
-
1404
- // Flag: should this box (loop counter) be included in the output?
1405
- xla::XlaOp should_include = xla::DynamicSlice (state, {box_index}, {1 });
1406
- // Pick the new values of state, depending on whether we should include
1407
- // this box or not.
1408
- xla::XlaOp new_state =
1409
- xla::Select (xla::BroadcastInDim (should_include, {num_boxes}, {0 }),
1410
- updated_state, state);
1411
-
1412
- xla::XlaOp next_box_index = box_index + ONE;
1413
- return std::vector<xla::XlaOp>{next_box_index, new_state,
1414
- iou_threshold_mask};
1415
- },
1416
- init_values, " BoxSelectionLoop" , builder));
1371
+ XLA_ASSIGN_OR_THROW (
1372
+ std::vector<xla::XlaOp> result,
1373
+ xla::WhileLoopHelper (
1374
+ [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
1375
+ xla::XlaOp box_index = values[0 ];
1376
+ // Check: current loop counter is within bounds, i.e. has a
1377
+ // corresponding box.
1378
+ return xla::Lt (box_index,
1379
+ xla::ConstantR0<IndexType>(builder, num_boxes));
1380
+ },
1381
+ [=](absl::Span<const xla::XlaOp> values, xla::XlaBuilder* builder) {
1382
+ const xla::XlaOp ONE = xla::One (builder, XLAIndexType);
1383
+ const xla::XlaOp ZERO = xla::Zero (builder, XLAIndexType);
1384
+
1385
+ xla::XlaOp box_index = values[0 ];
1386
+ xla::XlaOp state = values[1 ];
1387
+ xla::XlaOp iou_threshold_mask = values[2 ];
1388
+
1389
+ // Retrieve the IoU mask row corresponding to this box.
1390
+ xla::XlaOp box_iou_threshold_mask = xla::DynamicSlice (
1391
+ iou_threshold_mask, {box_index, ZERO}, {1 , num_boxes});
1392
+
1393
+ // Update the current state with the IoU mask.
1394
+ // Basically, sets to false every box X whose IoU with the current
1395
+ // box is less-than or equal than the given threshold.
1396
+ xla::XlaOp updated_state = xla::And (
1397
+ state,
1398
+ // Update the mask so that if we select this box
1399
+ // (i.e. state[box] == true), we don't de-select it.
1400
+ xla::DynamicUpdateSlice (
1401
+ // Before that, we need to pre-process the mask.
1402
+ // 1. Negate the mask: if this box is selected, we only
1403
+ // want
1404
+ // those that have a low intersection ratio.
1405
+ // 2. Reshape it to: [num_boxes].
1406
+ xla::Reshape (xla::Not (box_iou_threshold_mask), {num_boxes}),
1407
+ xla::ConstantR1<bool >(builder, {true }), {box_index}));
1408
+
1409
+ // Flag: should this box (loop counter) be included in the output?
1410
+ xla::XlaOp should_include =
1411
+ xla::DynamicSlice (state, {box_index}, {1 });
1412
+ // Pick the new values of state, depending on whether we should
1413
+ // include this box or not.
1414
+ xla::XlaOp new_state = xla::Select (
1415
+ xla::BroadcastInDim (should_include, {num_boxes}, {0 }),
1416
+ updated_state, state);
1417
+
1418
+ xla::XlaOp next_box_index = box_index + ONE;
1419
+ return std::vector<xla::XlaOp>{next_box_index, new_state,
1420
+ iou_threshold_mask};
1421
+ },
1422
+ init_values, " BoxSelectionLoop" , builder));
1423
+ return result;
1417
1424
}
1418
1425
1419
1426
xla::XlaOp BuildNms (xla::XlaOp boxes, xla::XlaOp scores,
0 commit comments