|
23 | 23 | #include "xla/hlo/builder/xla_builder.h"
|
24 | 24 | #include "xla/hlo/builder/xla_computation.h"
|
25 | 25 | #include "xla/literal.h"
|
26 |
| -#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" |
27 | 26 | #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
28 | 27 | #include "xla/pjrt/pjrt_api.h"
|
29 | 28 | #include "xla/pjrt/pjrt_c_api_client.h"
|
@@ -152,8 +151,6 @@ PjRtComputationClient::Create() {
|
152 | 151 | }
|
153 | 152 |
|
154 | 153 | PjRtComputationClient::~PjRtComputationClient() {
|
155 |
| - // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient |
156 |
| - // tracked in XlaCoordinator, so the PjRtClient must be destroyed first. |
157 | 154 | client_ = nullptr;
|
158 | 155 | coordinator_ = nullptr;
|
159 | 156 | }
|
@@ -1038,45 +1035,6 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
|
1038 | 1035 | };
|
1039 | 1036 | }
|
1040 | 1037 |
|
1041 |
| -void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name, |
1042 |
| - void* function_ptr, |
1043 |
| - const std::string& platform) { |
1044 |
| - if (platform != "CUDA") { |
1045 |
| - XLA_ERROR() << "Custom call targets can only be registered for " |
1046 |
| - "PJRT CUDA runtime."; |
1047 |
| - return; |
1048 |
| - } |
1049 |
| - |
1050 |
| - auto* c_api_client = dynamic_cast<xla::PjRtCApiClient*>(client_.get()); |
1051 |
| - if (!c_api_client) { |
1052 |
| - XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(fn_name, function_ptr, platform); |
1053 |
| - return; |
1054 |
| - } |
1055 |
| - const PJRT_Api* pjrt_api = c_api_client->pjrt_c_api(); |
1056 |
| - |
1057 |
| - // See openxla reference: |
1058 |
| - // https://github.com/openxla/xla/blob/b604c8d87df842002a7a8de79a434026329fbcb2/xla/pjrt/c/pjrt_c_api_gpu_test.cc#L414 |
1059 |
| - const PJRT_Extension_Base* next = |
1060 |
| - reinterpret_cast<const PJRT_Extension_Base*>(pjrt_api->extension_start); |
1061 |
| - while (next != nullptr && |
1062 |
| - next->type != |
1063 |
| - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { |
1064 |
| - next = next->next; |
1065 |
| - } |
1066 |
| - XLA_CHECK(next) << "Custom call extension not found"; |
1067 |
| - PJRT_Gpu_Register_Custom_Call_Args args; |
1068 |
| - args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; |
1069 |
| - args.function_name = fn_name.c_str(); |
1070 |
| - args.function_name_size = fn_name.size(); |
1071 |
| - args.api_version = 0; |
1072 |
| - args.handler_execute = function_ptr; |
1073 |
| - PJRT_Error* error = |
1074 |
| - reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args); |
1075 |
| - if (error) { |
1076 |
| - XLA_ERROR() << error->status; |
1077 |
| - } |
1078 |
| -} |
1079 |
| - |
1080 | 1038 | void PjRtComputationClient::OnReadyCallback(
|
1081 | 1039 | ComputationClient::DataPtr data, const std::function<void()>& callback) {
|
1082 | 1040 | std::shared_ptr<xla::PjRtBuffer> buffer;
|
|
0 commit comments