Skip to content

Commit a4a2527

Browse files
authored
Sync from upstream TF. (#3488)
1 parent f2b2b3f commit a4a2527

File tree

3 files changed

+45
-33
lines changed

3 files changed

+45
-33
lines changed

tensorflow/lite/core/api/flatbuffer_conversions.cc

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,9 +1882,7 @@ TfLiteStatus ParseMul(const Operator* op, ErrorReporter* error_reporter,
18821882
params->activation =
18831883
ConvertActivation(schema_params->fused_activation_function());
18841884
} else {
1885-
// TODO(b/157480169): We should either return kTfLiteError or fill in some
1886-
// reasonable defaults in the params struct. We are not doing so until we
1887-
// better understand the ramifications of changing the legacy behavior.
1885+
// Default activation is none.
18881886
}
18891887

18901888
*builtin_data = params.release();
@@ -2430,6 +2428,18 @@ TfLiteStatus ParseStablehloComposite(const Operator* op,
24302428
const StableHLOCompositeOptions* schema_params =
24312429
op->builtin_options_2_as_StableHLOCompositeOptions();
24322430
if (schema_params) {
2431+
if (schema_params->name() == nullptr) {
2432+
TF_LITE_REPORT_ERROR(
2433+
error_reporter,
2434+
"'stablehlo.composite' missing required option 'name'.");
2435+
return kTfLiteError;
2436+
}
2437+
if (schema_params->composite_attributes() == nullptr) {
2438+
TF_LITE_REPORT_ERROR(error_reporter,
2439+
"'stablehlo.composite' missing required option "
2440+
"'composite_attributes'.");
2441+
return kTfLiteError;
2442+
}
24332443
params->name = schema_params->name()->c_str();
24342444
params->version = schema_params->version();
24352445
params->subgraph_index = schema_params->decomposition_subgraph_index();

tensorflow/lite/kernels/kernel_util.cc

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -528,51 +528,50 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
528528

529529
// Size of string is not constant, return 0 in such case.
530530
int TfLiteTypeGetSize(TfLiteType type) {
531+
int size_bits = TfLiteTypeGetSizeBits(type);
532+
if (size_bits % 8 == 0) {
533+
return size_bits / 8;
534+
} else {
535+
// For non-byte sized types, return 0.
536+
return 0;
537+
}
538+
}
539+
540+
int TfLiteTypeGetSizeBits(TfLiteType type) {
531541
switch (type) {
542+
case kTfLiteInt2:
543+
return 2;
544+
case kTfLiteInt4:
545+
case kTfLiteUInt4:
546+
return 4;
532547
case kTfLiteUInt8:
533-
static_assert(sizeof(uint8_t) == 1, "");
534-
return 1;
535548
case kTfLiteInt8:
536-
static_assert(sizeof(int8_t) == 1, "");
537-
return 1;
538-
case kTfLiteBool:
539-
return sizeof(bool);
549+
return 8;
540550
case kTfLiteUInt16:
541-
static_assert(sizeof(uint16_t) == 2, "");
542-
return 2;
543551
case kTfLiteInt16:
544-
static_assert(sizeof(int16_t) == 2, "");
545-
return 2;
546552
case kTfLiteFloat16:
547-
static_assert(sizeof(int16_t) == 2, "");
548-
return 2;
553+
case kTfLiteBFloat16:
554+
return 16;
549555
case kTfLiteFloat32:
550-
static_assert(sizeof(float) == 4, "");
551-
return 4;
552556
case kTfLiteInt32:
553-
static_assert(sizeof(int32_t) == 4, "");
554-
return 4;
555557
case kTfLiteUInt32:
556-
static_assert(sizeof(uint32_t) == 4, "");
557-
return 4;
558+
return 32;
558559
case kTfLiteInt64:
559-
static_assert(sizeof(int64_t) == 8, "");
560-
return 8;
561560
case kTfLiteUInt64:
562-
static_assert(sizeof(uint64_t) == 8, "");
563-
return 8;
564561
case kTfLiteFloat64:
565-
static_assert(sizeof(double) == 8, "");
566-
return 8;
567562
case kTfLiteComplex64:
568-
static_assert(sizeof(std::complex<float>) == 8, "");
569-
return 8;
563+
return 64;
570564
case kTfLiteComplex128:
571-
static_assert(sizeof(std::complex<double>) == 16, "");
572-
return 16;
573-
default:
574-
return 0;
565+
return 128;
566+
case kTfLiteBool:
567+
return sizeof(bool) * 8;
568+
case kTfLiteString:
569+
case kTfLiteNoType:
570+
case kTfLiteResource:
571+
case kTfLiteVariant:
572+
break;
575573
}
574+
return 0;
576575
}
577576

578577
bool IsMobilePlatform() {

tensorflow/lite/kernels/kernel_util.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,9 @@ TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context,
332332
// Return the size of given type in bytes. Return 0 in case of string.
333333
int TfLiteTypeGetSize(TfLiteType type);
334334

335+
// Return the size of given type in bits. Returns 0 in case of string.
336+
int TfLiteTypeGetSizeBits(TfLiteType type);
337+
335338
// Whether the current platform is mobile (Android or iOS).
336339
bool IsMobilePlatform();
337340

0 commit comments

Comments
 (0)