@@ -27,6 +27,9 @@ limitations under the License.
2727#include " tensorflow/lite/kernels/internal/compatibility.h"
2828#include " tensorflow/lite/schema/schema_generated.h"
2929
30+ // TODO(sosagarcia): Rework all function implementations to wrap around the
31+ // compiler flatbuffer_conversions.
32+ // LINT.IfChange
3033namespace tflite {
3134
3235namespace {
@@ -928,6 +931,9 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
928931 return ParseStablehloShiftLeft (op, error_reporter, allocator,
929932 builtin_data);
930933 }
934+ case BuiltinOperator_STABLEHLO_CASE: {
935+ return ParseStablehloCase (op, error_reporter, allocator, builtin_data);
936+ }
931937 // TODO: skip param parsing for now since ops below don't have kernels
932938 case BuiltinOperator_STABLEHLO_SLICE:
933939 case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
@@ -2421,6 +2427,46 @@ TfLiteStatus ParseStablehloShiftLeft(const Operator* op,
24212427 return kTfLiteOk ;
24222428}
24232429
2430+ TfLiteStatus ParseStablehloCase (const Operator* op,
2431+ ErrorReporter* error_reporter,
2432+ BuiltinDataAllocator* allocator,
2433+ void ** builtin_data) {
2434+ CheckParsePointerParams (op, error_reporter, allocator, builtin_data);
2435+
2436+ SafeBuiltinDataAllocator safe_allocator (allocator);
2437+ auto params = safe_allocator.Allocate <TfLiteStablehloCaseParams>();
2438+
2439+ const StablehloCaseOptions* schema_params =
2440+ op->builtin_options_2_as_StablehloCaseOptions ();
2441+ if (schema_params) {
2442+ auto LoadAttr =
2443+ [&error_reporter](
2444+ int32_t * params_array, const size_t params_array_size_bytes,
2445+ const flatbuffers::Vector<int32_t >* const flatbuffer_vector,
2446+ const char * const attr_name) -> TfLiteStatus {
2447+ TfLiteStatus status = FlatBufferIntVectorToArray (
2448+ params_array_size_bytes, flatbuffer_vector, params_array,
2449+ error_reporter, " stablehlo.case" );
2450+ if (status != kTfLiteOk ) {
2451+ TF_LITE_REPORT_ERROR (error_reporter, " Check the '%s' attribute." ,
2452+ attr_name);
2453+ }
2454+ return status;
2455+ };
2456+
2457+ TF_LITE_ENSURE_STATUS (LoadAttr (params->branch_subgraph_indices ,
2458+ sizeof (params->branch_subgraph_indices ),
2459+ schema_params->branch_subgraph_indices (),
2460+ " branch subgraph indices" ));
2461+ params->num_branches = schema_params->branch_subgraph_indices ()->size ();
2462+ *builtin_data = params.release ();
2463+ return kTfLiteOk ;
2464+ }
2465+ TF_LITE_REPORT_ERROR (error_reporter,
2466+ " Could not get 'stablehlo.case' operation parameters." );
2467+ return kTfLiteError ;
2468+ }
2469+
24242470// We have this parse function instead of directly returning kTfLiteOk from the
24252471// switch-case in ParseOpData because this function is used as part of the
24262472// selective registration for the OpResolver implementation in micro.
@@ -2943,3 +2989,4 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
29432989}
29442990
29452991} // namespace tflite
2992+ // LINT.ThenChange(//tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.cc)
0 commit comments