
Commit 763e5b7

Remove gpu_custom_call logic. (#9600)
This PR removes the implementation of `gpu_custom_call`, in line with the CUDA deprecation that started in release 2.8.

**Key Changes:**

- Delete both `ops/gpu_custom_call.cpp` and `ops/gpu_custom_call.h`
- (`tensor_methods.{h,cpp}`) Remove `tensor_methods::gpu_custom_call`
- (`ops/xla_ops.{h,cpp}`) Remove the `OpKindWrapper xla_gpu_custom_call` global variable
- (`init_python_bindings.cpp`) Remove the Python API function `_xla_gpu_custom_call`
- (`init_python_bindings.cpp`) Turn the `XlaCustomCall` function into the TPU-specific function `TpuCustomCall`

A brief usage sketch of the remaining TPU-only entry point follows below.
1 parent 004f19e commit 763e5b7
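For context, here is a minimal, illustrative Python sketch of the surviving TPU-only entry point. The signature is taken from the `_xla_tpu_custom_call` binding kept in `init_python_bindings.cpp`; the payload, shapes, and dtypes are placeholders, and a real call needs a TPU device plus a serialized kernel payload (e.g. produced by a Pallas/Mosaic compiler), so this is a sketch rather than a working program.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Placeholder only: a real payload is a serialized TPU kernel produced by a
# kernel compiler; this string will not execute.
PAYLOAD = "<serialized TPU custom-call payload>"

def run_tpu_custom_call():
    device = xm.xla_device()
    x = torch.randn(8, 128, device=device)

    # Binding kept by this commit:
    #   (inputs, payload, output_shapes, output_dtypes) -> list of Tensors
    outputs = torch_xla._XLAC._xla_tpu_custom_call(
        [x],              # inputs
        PAYLOAD,          # serialized kernel payload
        [[8, 128]],       # output_shapes
        [torch.float32],  # output_dtypes
    )
    return outputs

# The GPU counterpart, `_xla_gpu_custom_call`, is removed by this commit, so
# there is no longer an is_tpu-style dispatch on the Python side.
```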

9 files changed: +5 −157 lines


torch_xla/csrc/init_python_bindings.cpp

Lines changed: 4 additions & 18 deletions
@@ -38,7 +38,6 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
-#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -345,22 +344,18 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
   return replica_groups;
 }
 
-std::vector<at::Tensor> XlaCustomCall(
+std::vector<at::Tensor> TpuCustomCall(
     const std::vector<at::Tensor>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<py::object>& output_dtypes, bool is_tpu) {
+    const std::vector<py::object>& output_dtypes) {
   std::vector<at::ScalarType> dtypes;
   dtypes.reserve(output_dtypes.size());
   for (auto& dtype : output_dtypes) {
     dtypes.push_back(reinterpret_cast<THPDtype*>(dtype.ptr())->scalar_type);
   }
   XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_inputs,
                       bridge::GetXlaTensors(inputs));
-  if (is_tpu) {
-    return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
-        xla_inputs, payload, output_shapes, dtypes));
-  }
-  return bridge::AtenFromXlaTensors(tensor_methods::gpu_custom_call(
+  return bridge::AtenFromXlaTensors(tensor_methods::tpu_custom_call(
       xla_inputs, payload, output_shapes, dtypes));
 }
 
@@ -3058,8 +3053,7 @@ void InitXlaModuleBindings(py::module m) {
              const std::vector<std::vector<int64_t>>& output_shapes,
              const std::vector<py::object>& output_dtypes)
               -> std::vector<at::Tensor> {
-           return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
-                                /*is_tpu=*/true);
+           return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
          })
      .def("_has_cuda_support",
           []() {
@@ -3069,14 +3063,6 @@ void InitXlaModuleBindings(py::module m) {
            return false;
 #endif
          })
-     .def("_xla_gpu_custom_call",
-          [](const std::vector<at::Tensor>& inputs, const std::string& payload,
-             const std::vector<std::vector<int64_t>>& output_shapes,
-             const std::vector<py::object>& output_dtypes)
-              -> std::vector<at::Tensor> {
-            return XlaCustomCall(inputs, payload, output_shapes, output_dtypes,
-                                 /*is_tpu=*/false);
-          })
      .def("_xla_register_custom_call_target",
           [](const std::string& fn_name, const py::capsule& function_ptr,
              const std::string& platform) {
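A hedged note for callers: with the `.def("_xla_gpu_custom_call", ...)` entry removed above, the attribute simply no longer exists on the native `torch_xla._XLAC` module, so code that needs to run across versions can feature-detect it as in this sketch (the variable names are illustrative).

```python
import torch_xla

# Only the TPU entry point remains after this commit; the GPU one is gone.
has_gpu_custom_call = hasattr(torch_xla._XLAC, "_xla_gpu_custom_call")
has_tpu_custom_call = hasattr(torch_xla._XLAC, "_xla_tpu_custom_call")

if not has_gpu_custom_call:
    # Builds containing this commit: use the TPU path, or fall back to a
    # different mechanism (e.g. the `_xla_register_custom_call_target`
    # binding kept above).
    pass
```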

torch_xla/csrc/ops/gpu_custom_call.cpp

Lines changed: 0 additions & 37 deletions
This file was deleted.

torch_xla/csrc/ops/gpu_custom_call.h

Lines changed: 0 additions & 25 deletions
This file was deleted.

torch_xla/csrc/ops/xla_ops.cpp

Lines changed: 0 additions & 1 deletion
@@ -39,6 +39,5 @@ const OpKindWrapper xla_unselect("xla::unselect");
 const OpKindWrapper xla_update_slice("xla::update_slice");
 const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
 const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");
-const OpKindWrapper xla_gpu_custom_call("xla::gpu_custom_call");
 
 }  // namespace torch_xla

torch_xla/csrc/ops/xla_ops.h

Lines changed: 1 addition & 2 deletions
@@ -64,8 +64,7 @@ extern const OpKindWrapper xla_unselect;
 extern const OpKindWrapper xla_update_slice;
 extern const OpKindWrapper xla_custom_sharding;
 extern const OpKindWrapper xla_tpu_custom_call;
-extern const OpKindWrapper xla_gpu_custom_call;
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_
+#endif  // XLA_TORCH_XLA_CSRC_OPS_XLA_OPS_H_

torch_xla/csrc/tensor_methods.cpp

Lines changed: 0 additions & 40 deletions
@@ -65,7 +65,6 @@
 #include "torch_xla/csrc/ops/generic.h"
 #include "torch_xla/csrc/ops/generic_slice.h"
 #include "torch_xla/csrc/ops/get_dimensions_size.h"
-#include "torch_xla/csrc/ops/gpu_custom_call.h"
 #include "torch_xla/csrc/ops/hardtanh_backward.h"
 #include "torch_xla/csrc/ops/index_ops.h"
 #include "torch_xla/csrc/ops/index_select.h"
@@ -767,45 +766,6 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }
 
-std::vector<XLATensorPtr> gpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<at::ScalarType>& output_dtypes) {
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-
-  std::vector<torch::lazy::Value> values;
-  values.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    values.push_back(input->GetIrValue());
-  }
-
-  XLA_CHECK_EQ(output_shapes.size(), output_dtypes.size());
-  std::vector<xla::Shape> output_xla_shapes;
-  output_xla_shapes.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    output_xla_shapes.push_back(xla::ShapeUtil::MakeShape(
-        MakeXlaPrimitiveType(output_dtypes[i], &(inputs[0]->GetDevice())),
-        output_shapes[i]));
-  }
-
-  auto node = torch_xla::MakeNode<GpuCustomCall>(
-      values, xla::ShapeUtil::MakeTupleShape(output_xla_shapes), payload);
-
-  std::vector<XLATensorPtr> outputs;
-  outputs.reserve(output_shapes.size());
-  for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
-                                            output_dtypes[i],
-                                            /*delay_eager_execution=*/true));
-  }
-  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
-  if (graph_executor->UseEagerMode()) {
-    // Execute the HLO that will run the `custom` and in one hlo
-    graph_executor->ApplyEagerSync(outputs);
-  }
-  return outputs;
-}
-
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,
torch_xla/csrc/tensor_methods.h

Lines changed: 0 additions & 5 deletions
@@ -103,11 +103,6 @@ void custom_sharding_(
     const std::shared_ptr<XLATensor::ShardingSpec>& spec,
     const CustomSharding::Type& type = CustomSharding::Type::kSharding);
 
-std::vector<XLATensorPtr> gpu_custom_call(
-    const std::vector<XLATensorPtr>& inputs, const std::string& payload,
-    const std::vector<std::vector<int64_t>>& output_shapes,
-    const std::vector<at::ScalarType>& output_dtypes);
-
 std::vector<XLATensorPtr> tpu_custom_call(
     const std::vector<XLATensorPtr>& inputs, const std::string& payload,
     const std::vector<std::vector<int64_t>>& output_shapes,

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 0 additions & 25 deletions
@@ -1281,31 +1281,6 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input, const std::string& type,
                               output_shape);
 }
 
-std::vector<xla::XlaOp> BuildGpuCustomCall(
-    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
-    const std::string& payload) {
-  std::vector<xla::Shape> input_shapes;
-  input_shapes.reserve(inputs.size());
-  for (const auto& input : inputs) {
-    input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
-  }
-
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-  xla::XlaOp outputs = xla::CustomCallWithLayout(
-      inputs[0].builder(),
-      /*call_target_name=*/"triton_kernel_call", inputs, output_shape,
-      input_shapes, payload, false, {}, nullptr,
-      xla::CustomCallSchedule::SCHEDULE_NONE,
-      xla::CustomCallApiVersion::API_VERSION_STATUS_RETURNING);
-  std::vector<xla::XlaOp> result;
-  int num_outputs = output_shape.tuple_shapes_size();
-  result.reserve(num_outputs);
-  for (int i = 0; i < num_outputs; ++i) {
-    result.push_back(xla::GetTupleElement(outputs, i));
-  }
-  return result;
-}
-
 std::vector<xla::XlaOp> BuildTpuCustomCall(
     const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
     const std::string& payload) {

torch_xla/csrc/xla_lower_util.h

Lines changed: 0 additions & 4 deletions
@@ -162,10 +162,6 @@ std::vector<xla::XlaOp> BuildTpuCustomCall(
 xla::XlaOp BuildNms(xla::XlaOp boxes, xla::XlaOp scores,
                     xla::XlaOp iou_threshold);
 
-std::vector<xla::XlaOp> BuildGpuCustomCall(
-    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
-    const std::string& payload);
-
 }  // namespace torch_xla
 
 #endif  // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_
