Skip to content

Commit b737b4a

Browse files
Deprecate ShapeOfXlaOp in favor of GetShape (#9381)
Co-authored-by: Zhanyong Wan <[email protected]>
1 parent f8bd6c4 commit b737b4a

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

torch_xla/csrc/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,8 @@ ptxla_cc_library(
243243
srcs = ["shape_builder.cpp"],
244244
hdrs = ["shape_builder.h"],
245245
deps = [
246+
"@com_google_absl//absl/base:core_headers",
247+
"@com_google_absl//absl/base:nullability",
246248
"@com_google_absl//absl/types:span",
247249
"@xla//xla:shape_util",
248250
"@xla//xla:types",

torch_xla/csrc/shape_helper.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
namespace torch_xla {
77

88
const xla::Shape& ShapeHelper::ShapeOfXlaOp(xla::XlaOp op) {
9-
const xla::Shape* shape = ConsumeValue(op.builder()->GetShapePtr(op));
10-
return *shape;
9+
return *ConsumeValue(GetShape(op));
10+
}
11+
12+
absl::StatusOr<const xla::Shape * absl_nonnull> GetShape(xla::XlaOp op) {
13+
return op.builder()->GetShapePtr(op);
1114
}
1215

1316
} // namespace torch_xla

torch_xla/csrc/shape_helper.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,24 @@
11
#ifndef XLA_TORCH_XLA_SHAPE_HELPER_H_
22
#define XLA_TORCH_XLA_SHAPE_HELPER_H_
33

4+
#include "absl/base/attributes.h"
5+
#include "absl/base/nullability.h"
46
#include "xla/hlo/builder/xla_builder.h"
57

68
namespace torch_xla {
79

810
class ShapeHelper {
911
public:
1012
// Returns the shape of the given XLA operation.
13+
ABSL_DEPRECATED(
14+
"Use GetShape(op) instead. ShapeOfXlaOp() blindly "
15+
"de-references StatusOr returned by XLA, which is unsafe.")
1116
static const xla::Shape& ShapeOfXlaOp(xla::XlaOp op);
1217
};
1318

19+
// Returns the shape of the given XLA operation.
20+
absl::StatusOr<const xla::Shape * absl_nonnull> GetShape(xla::XlaOp op);
21+
1422
} // namespace torch_xla
1523

1624
#endif // XLA_TORCH_XLA_SHAPE_HELPER_H_

0 commit comments

Comments
 (0)