@@ -280,7 +280,6 @@ template <typename To, typename From>
280280void convert_and_store (From f, void * dst) {
281281 *reinterpret_cast <To*>(dst) = static_cast <To>(f);
282282}
283- } // namespace internal
284283
285284template <typename CTYPE_COMMON>
286285using load_to_common_fn = CTYPE_COMMON (*)(const void *);
@@ -296,6 +295,17 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
296295 return result;
297296}
298297
298+ template <typename CTYPE_COMMON, const char * op_name>
299+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
300+ const Tensor& t) {
301+ CTYPE_COMMON (*result)(const void *) = nullptr ;
302+ ET_SWITCH_TWO_TYPES (
303+ Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
304+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
305+ });
306+ return result;
307+ }
308+
299309template <typename CTYPE_COMMON>
300310using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void *);
301311
@@ -310,6 +320,75 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
310320 return result;
311321}
312322
323+ template <typename CTYPE_COMMON, const char * op_name>
324+ store_common_to_tensor_fn<CTYPE_COMMON>
325+ get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
326+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
327+ ET_SWITCH_TWO_TYPES (
328+ Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
329+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
330+ });
331+ return result;
332+ }
333+ } // namespace internal
334+
335+ enum class SupportedTensorDtypes {
336+ REALHBBF16,
337+ BOOL_OR_BYTE,
338+ SAME_AS_COMMON,
339+ };
340+
341+ namespace internal {
342+ template <typename CTYPE_COMMON, const char * op_name>
343+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn (
344+ const Tensor& t,
345+ SupportedTensorDtypes dtypes) {
346+ switch (dtypes) {
347+ case SupportedTensorDtypes::REALHBBF16:
348+ return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
349+ case SupportedTensorDtypes::BOOL_OR_BYTE:
350+ return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
351+ case SupportedTensorDtypes::SAME_AS_COMMON: {
352+ constexpr auto common_scalar_type =
353+ CppTypeToScalarType<CTYPE_COMMON>::value;
354+ ET_CHECK_MSG (
355+ t.scalar_type () == common_scalar_type,
356+ " Unhandled dtype %s for %s" ,
357+ ::executorch::runtime::toString (common_scalar_type),
358+ op_name);
359+ return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
360+ }
361+ }
362+ ET_CHECK (false );
363+ return nullptr ;
364+ }
365+
366+ template <typename CTYPE_COMMON, const char * op_name>
367+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn (
368+ const Tensor& t,
369+ SupportedTensorDtypes dtypes) {
370+ switch (dtypes) {
371+ case SupportedTensorDtypes::REALHBBF16:
372+ return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
373+ case SupportedTensorDtypes::BOOL_OR_BYTE:
374+ return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
375+ t);
376+ case SupportedTensorDtypes::SAME_AS_COMMON: {
377+ constexpr auto common_scalar_type =
378+ CppTypeToScalarType<CTYPE_COMMON>::value;
379+ ET_CHECK_MSG (
380+ t.scalar_type () == common_scalar_type,
381+ " Unhandled dtype %s for %s" ,
382+ ::executorch::runtime::toString (common_scalar_type),
383+ op_name);
384+ return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
385+ }
386+ }
387+ ET_CHECK (false );
388+ return nullptr ;
389+ }
390+ } // namespace internal
391+
313392/* *
314393 * Useful for binary elementwise operators. For each element of the inputs,
315394 * perform a computation and write to the corresponding element of the output.
@@ -356,33 +435,45 @@ inline void apply_binary_elementwise_fn(
356435 *
357436 * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
358437 * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
359- * are passed as CTYPE_COMMON. We require compute_fun to return
360- * CTYPE_COMMON, and we require loading conversion functions from each
361- * input type to CTYPE_COMMON and a storing conversion from
362- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
363- * must take a void* pointing to an element of the corresponding
364- * tensor, load that element, and convert it to CTYPE_COMMON. The
365- * storing conversion function must have the signature
366- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
367- * and store it to the given location.
438+ * are passed as CTYPE_COMMON.
439+ *
440+ * Each tensor's supported dtypes set must be provided. The tensor
441+ * will be checked to ensure that its dtype falls into that set.
442+ *
443+ * op_name is used to support dtype selective build, as with the
444+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
445+ * can't pass a string literal for op_name. Instead, you should do the
446+ * following:
447+ *
448+ * static constexpr const char op_name[] = "my_op";
449+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
368450 */
369- template <typename CTYPE_COMMON, typename Op>
451+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
370452inline void apply_ternary_elementwise_fn (
371453 const Op& compute_fun,
372454 const Tensor& a,
455+ SupportedTensorDtypes a_dtypes,
373456 const Tensor& b,
457+ SupportedTensorDtypes b_dtypes,
374458 const Tensor& c,
459+ SupportedTensorDtypes c_dtypes,
375460 const Tensor& out,
376- CTYPE_COMMON (*load_a_to_common)(const void *),
377- CTYPE_COMMON (*load_b_to_common)(const void *),
378- CTYPE_COMMON (*load_c_to_common)(const void *),
379- void (*store_common_to_out)(CTYPE_COMMON, void *)) {
461+ SupportedTensorDtypes out_dtypes) {
380462 const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
381463 const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
382464 const bool c_is_broadcasted = !out.sizes ().equals (c.sizes ());
383465 const bool any_is_broadcasted =
384466 (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
385467
468+ const auto load_a_to_common =
469+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
470+ const auto load_b_to_common =
471+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
472+ const auto load_c_to_common =
473+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
474+ const auto store_common_to_out =
475+ internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
476+ out, out_dtypes);
386477 const char * const data_a = reinterpret_cast <const char *>(a.const_data_ptr ());
387478 const char * const data_b = reinterpret_cast <const char *>(b.const_data_ptr ());
388479 const char * const data_c = reinterpret_cast <const char *>(c.const_data_ptr ());
0 commit comments