|
| 1 | +From c79f202be6fde802b4e5d697a5925d7eccea3d25 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Hugo Mano <hugo@zml.ai> |
| 3 | +Date: Wed, 5 Feb 2025 19:25:03 +0100 |
| 4 | +Subject: [PATCH] Added FFI handler registration API to the FFI PjRt |
| 5 | + |
| 6 | +PR: https://github.com/openxla/xla/pull/13420 |
| 7 | +--- |
| 8 | + xla/pjrt/c/BUILD | 5 ++++ |
| 9 | + xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++ |
| 10 | + xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++-- |
| 11 | + 3 files changed, 54 insertions(+), 2 deletions(-) |
| 12 | + |
| 13 | +diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD |
| 14 | +index ad1b3987fe..0598281ad1 100644 |
| 15 | +--- a/xla/pjrt/c/BUILD |
| 16 | ++++ b/xla/pjrt/c/BUILD |
| 17 | +@@ -69,7 +69,12 @@ cc_library( |
| 18 | + ":pjrt_c_api_wrapper_impl", |
| 19 | + "//xla/ffi:execution_context", |
| 20 | + "//xla/ffi:type_id_registry", |
| 21 | ++ "//xla/ffi:ffi_api", |
| 22 | ++ "//xla/ffi/api:c_api", |
| 23 | ++ "//xla/ffi/api:ffi", |
| 24 | ++ "//xla/service:custom_call_target_registry", |
| 25 | + "@com_google_absl//absl/status", |
| 26 | ++ "@com_google_absl//absl/strings:str_format", |
| 27 | + ], |
| 28 | + ) |
| 29 | + |
| 30 | +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h |
| 31 | +index c5766f2a19..3d74e7cbf3 100644 |
| 32 | +--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h |
| 33 | ++++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h |
| 34 | +@@ -67,12 +67,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); |
| 35 | + // Adds a user data to the execute context. |
| 36 | + typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); |
| 37 | + |
| 38 | ++struct PJRT_FFI_Register_Handler_Args { |
| 39 | ++ size_t struct_size; |
| 40 | ++ const char* target_name; |
| 41 | ++ size_t target_name_size; |
| 42 | ++ int api_version; // 0 for an untyped call, 1 -- for typed |
| 43 | ++ void* handler; |
| 44 | ++ const char* platform_name; |
| 45 | ++ size_t platform_name_size; |
| 46 | ++}; |
| 47 | ++PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, handler); |
| 48 | ++ |
| 49 | ++// Registers an FFI call handler for a specific platform. |
| 50 | ++typedef PJRT_Error* PJRT_FFI_Register_Handler( |
| 51 | ++ PJRT_FFI_Register_Handler_Args* args); |
| 52 | ++ |
| 53 | + typedef struct PJRT_FFI_Extension { |
| 54 | + size_t struct_size; |
| 55 | + PJRT_Extension_Type type; |
| 56 | + PJRT_Extension_Base* next; |
| 57 | + PJRT_FFI_TypeID_Register* type_id_register; |
| 58 | + PJRT_FFI_UserData_Add* user_data_add; |
| 59 | ++ PJRT_FFI_Register_Handler* register_handler; |
| 60 | + } PJRT_FFI; |
| 61 | + PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add); |
| 62 | + |
| 63 | +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc |
| 64 | +index 0375b39d0b..3527a0756e 100644 |
| 65 | +--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc |
| 66 | ++++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc |
| 67 | +@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and |
| 68 | + limitations under the License. |
| 69 | + ==============================================================================*/ |
| 70 | + |
| 71 | +-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" |
| 72 | ++#include <string> |
| 73 | + |
| 74 | + #include "absl/status/status.h" |
| 75 | ++#include "absl/strings/str_format.h" |
| 76 | ++#include "xla/ffi/api/c_api.h" |
| 77 | ++#include "xla/ffi/api/ffi.h" |
| 78 | + #include "xla/ffi/execution_context.h" |
| 79 | +-#include "xla/ffi/type_id_registry.h" |
| 80 | ++ #include "xla/ffi/type_id_registry.h" |
| 81 | ++#include "xla/ffi/ffi_api.h" |
| 82 | + #include "xla/pjrt/c/pjrt_c_api.h" |
| 83 | + #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" |
| 84 | + #include "xla/pjrt/c/pjrt_c_api_helpers.h" |
| 85 | + #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" |
| 86 | ++#include "xla/service/custom_call_target_registry.h" |
| 87 | + |
| 88 | + namespace pjrt { |
| 89 | + |
| 90 | +@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { |
| 91 | + return nullptr; |
| 92 | + } |
| 93 | + |
| 94 | ++static PJRT_Error* PJRT_FFI_Register_Handler( |
| 95 | ++ PJRT_FFI_Register_Handler_Args* args) { |
| 96 | ++ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( |
| 97 | ++ "PJRT_FFI_Register_Handler_Args", |
| 98 | ++ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size)); |
| 99 | ++ std::string target_name(args->target_name, args->target_name_size); |
| 100 | ++ std::string platform_name(args->platform_name, args->platform_name_size); |
| 101 | ++ switch (args->api_version) { |
| 102 | ++ case 0: |
| 103 | ++ xla::CustomCallTargetRegistry::Global()->Register( |
| 104 | ++ target_name, args->handler, platform_name); |
| 105 | ++ return nullptr; |
| 106 | ++ case 1: |
| 107 | ++ xla::ffi::Ffi::RegisterStaticHandler( |
| 108 | ++ xla::ffi::GetXlaFfiApi(), target_name, platform_name, |
| 109 | ++ reinterpret_cast<XLA_FFI_Handler*>(args->handler)); |
| 110 | ++ return nullptr; |
| 111 | ++ default: |
| 112 | ++ return new PJRT_Error{absl::UnimplementedError( |
| 113 | ++ absl::StrFormat("API version %d not supported for PJRT GPU plugin. " |
| 114 | ++ "Supported versions are 0 and 1.", |
| 115 | ++ args->api_version))}; |
| 116 | ++ } |
| 117 | ++} |
| 118 | ++ |
| 119 | + PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { |
| 120 | + return { |
| 121 | + /*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE, |
| 122 | +@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { |
| 123 | + /*next=*/next, |
| 124 | + /*type_id_register=*/PJRT_FFI_TypeID_Register, |
| 125 | + /*user_data_add=*/PJRT_FFI_UserData_Add, |
| 126 | ++ /*register_handler=*/PJRT_FFI_Register_Handler, |
| 127 | + }; |
| 128 | + } |
| 129 | + |
| 130 | +-- |
| 131 | +2.39.5 (Apple Git-154) |
0 commit comments