Skip to content

Commit 7c9a510

Browse files
chsiggcopybara-github
authored andcommitted
NFC: Use the free function variants for dyn_cast/cast/isa/....
The member functions in Type/Attribute/Value/Location/AffineExpr got [removed](llvm/llvm-project@0078cf7). PiperOrigin-RevId: 748263988
1 parent 6dc35e5 commit 7c9a510

File tree

16 files changed

+308
-278
lines changed

16 files changed

+308
-278
lines changed

BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,7 @@ tfrt_cc_library(
700700
"@llvm-project//llvm:Support",
701701
"@llvm-project//mlir:FuncDialect",
702702
"@llvm-project//mlir:IR",
703+
"@llvm-project//mlir:Support",
703704
],
704705
)
705706

@@ -1502,6 +1503,7 @@ tfrt_cc_library(
15021503
"@llvm-project//llvm:Support",
15031504
"@llvm-project//mlir:FuncDialect",
15041505
"@llvm-project//mlir:IR",
1506+
"@llvm-project//mlir:Support",
15051507
],
15061508
)
15071509

@@ -1596,6 +1598,7 @@ tfrt_cc_library(
15961598
":support",
15971599
"@llvm-project//llvm:Support",
15981600
"@llvm-project//mlir:IR",
1601+
"@llvm-project//mlir:Support",
15991602
],
16001603
)
16011604

cpp_tests/bef_converter/bef_attr_reader_test.cc

Lines changed: 35 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/BuiltinAttributes.h"
2828
#include "mlir/IR/BuiltinOps.h"
2929
#include "mlir/IR/MLIRContext.h"
30+
#include "mlir/Support/LLVM.h"
3031
#include "mlir/Tools/mlir-translate/Translation.h"
3132
#include "tfrt/bef/bef_encoding.h"
3233
#include "tfrt/cpp_tests/test_util.h"
@@ -65,7 +66,7 @@ class BefAttrReaderTest : public ::testing::Test {
6566
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
6667
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
6768

68-
EXPECT_EQ(static_cast<T>(mlir_attr.template cast<mlir::IntegerAttr>()
69+
EXPECT_EQ(static_cast<T>(mlir::cast<mlir::IntegerAttr>(mlir_attr)
6970
.getValue()
7071
.getLimitedValue()),
7172
value);
@@ -109,9 +110,10 @@ TEST_F(BefAttrReaderTest, ReadF32Attribute) {
109110
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
110111
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
111112

112-
EXPECT_EQ(static_cast<float>(
113-
mlir_attr.cast<mlir::FloatAttr>().getValue().convertToFloat()),
114-
kTestFloat);
113+
EXPECT_EQ(
114+
static_cast<float>(
115+
mlir::cast<mlir::FloatAttr>(mlir_attr).getValue().convertToFloat()),
116+
kTestFloat);
115117
}
116118

117119
constexpr double kTestDeouble = -3.141592;
@@ -126,9 +128,10 @@ TEST_F(BefAttrReaderTest, ReadF64Attribute) {
126128
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
127129
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
128130

129-
EXPECT_EQ(static_cast<double>(
130-
mlir_attr.cast<mlir::FloatAttr>().getValue().convertToDouble()),
131-
kTestDeouble);
131+
EXPECT_EQ(
132+
static_cast<double>(
133+
mlir::cast<mlir::FloatAttr>(mlir_attr).getValue().convertToDouble()),
134+
kTestDeouble);
132135
}
133136

134137
constexpr char kTestString[] = "Hello, World";
@@ -143,7 +146,7 @@ TEST_F(BefAttrReaderTest, ReadStringAttribute) {
143146
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
144147
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
145148

146-
EXPECT_EQ(mlir_attr.cast<mlir::StringAttr>().getValue(), kTestString);
149+
EXPECT_EQ(mlir::cast<mlir::StringAttr>(mlir_attr).getValue(), kTestString);
147150
}
148151

149152
TEST_F(BefAttrReaderTest, ReadI32TypeAttribute) {
@@ -160,7 +163,8 @@ TEST_F(BefAttrReaderTest, ReadI32TypeAttribute) {
160163
auto read_attr = reader.ReadAttribute(attribute_type, offset);
161164
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
162165

163-
EXPECT_EQ(read_attr.cast<mlir::TypeAttr>().getValue(), mlir_attr.getValue());
166+
EXPECT_EQ(mlir::cast<mlir::TypeAttr>(read_attr).getValue(),
167+
mlir_attr.getValue());
164168
}
165169

166170
constexpr int64_t kTestShape[] = {1, 2, 3};
@@ -176,7 +180,7 @@ TEST_F(BefAttrReaderTest, ReadRankedShapeAttribute) {
176180
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
177181
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
178182

179-
auto shape_attr = mlir_attr.cast<tfrt::corert::ShapeAttr>();
183+
auto shape_attr = mlir::cast<tfrt::corert::ShapeAttr>(mlir_attr);
180184
EXPECT_TRUE(shape_attr.hasRank());
181185
EXPECT_EQ(shape_attr.getRank(), kTestShapeRank);
182186
for (int i = 0; i < kTestShapeRank; ++i) {
@@ -195,7 +199,7 @@ TEST_F(BefAttrReaderTest, ReadUnrankedShapeAttribute) {
195199
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
196200
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
197201

198-
auto shape_attr = mlir_attr.cast<tfrt::corert::ShapeAttr>();
202+
auto shape_attr = mlir::cast<tfrt::corert::ShapeAttr>(mlir_attr);
199203
EXPECT_FALSE(shape_attr.hasRank());
200204
}
201205

@@ -211,13 +215,14 @@ TEST_F(BefAttrReaderTest, ReadI32ArrayAttribute) {
211215
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
212216
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
213217

214-
auto array_attr = mlir_attr.cast<mlir::ArrayAttr>().getValue();
218+
auto array_attr = mlir::cast<mlir::ArrayAttr>(mlir_attr).getValue();
215219

216220
EXPECT_EQ(array_attr.size(), kTestI32ArraySize);
217221
for (int idx = 0; idx < kTestI32ArraySize; ++idx) {
218-
EXPECT_EQ(
219-
array_attr[idx].cast<mlir::IntegerAttr>().getValue().getLimitedValue(),
220-
kTestI32Array[idx]);
222+
EXPECT_EQ(mlir::cast<mlir::IntegerAttr>(array_attr[idx])
223+
.getValue()
224+
.getLimitedValue(),
225+
kTestI32Array[idx]);
221226
}
222227
}
223228

@@ -233,13 +238,14 @@ TEST_F(BefAttrReaderTest, ReadF64ArrayAttribute) {
233238
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
234239
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
235240

236-
auto array_attr = mlir_attr.cast<mlir::ArrayAttr>().getValue();
241+
auto array_attr = mlir::cast<mlir::ArrayAttr>(mlir_attr).getValue();
237242

238243
EXPECT_EQ(array_attr.size(), kTestF64ArraySize);
239244
for (int idx = 0; idx < kTestF64ArraySize; ++idx) {
240-
EXPECT_EQ(
241-
array_attr[idx].cast<mlir::FloatAttr>().getValue().convertToDouble(),
242-
kTestF64Array[idx]);
245+
EXPECT_EQ(mlir::cast<mlir::FloatAttr>(array_attr[idx])
246+
.getValue()
247+
.convertToDouble(),
248+
kTestF64Array[idx]);
243249
}
244250
}
245251

@@ -272,7 +278,7 @@ TEST_F(BefAttrReaderTest, ReadDenseAttribute) {
272278
auto mlir_attr = reader.ReadAttribute(attribute_type, offset);
273279
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(mlir_attr), attribute_type);
274280

275-
auto dense_attr = mlir_attr.cast<mlir::DenseElementsAttr>();
281+
auto dense_attr = mlir::cast<mlir::DenseElementsAttr>(mlir_attr);
276282
const auto shaped_type = dense_attr.getType();
277283

278284
EXPECT_EQ(
@@ -283,7 +289,7 @@ TEST_F(BefAttrReaderTest, ReadDenseAttribute) {
283289
EXPECT_EQ(shaped_type.getShape()[1], 2);
284290

285291
for (auto element : dense_attr.getValues<mlir::Attribute>()) {
286-
EXPECT_EQ(element.cast<mlir::FloatAttr>().getValue().convertToFloat(),
292+
EXPECT_EQ(mlir::cast<mlir::FloatAttr>(element).getValue().convertToFloat(),
287293
1.5f);
288294
}
289295
}
@@ -316,26 +322,28 @@ TEST_F(BefAttrReaderTest, EmitAggregateAttribute) {
316322
auto read_attr = reader.ReadAttribute(attribute_type, offset);
317323
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(read_attr), attribute_type);
318324

319-
auto aggregate_attr = read_attr.cast<mlir::ArrayAttr>();
325+
auto aggregate_attr = mlir::cast<mlir::ArrayAttr>(read_attr);
320326
EXPECT_EQ(aggregate_attr.size(), 3);
321327

322328
auto first = aggregate_attr[0];
323329
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(first),
324330
static_cast<BEFAttributeType>(DType::I32));
325-
EXPECT_EQ(static_cast<int32_t>(
326-
first.cast<mlir::IntegerAttr>().getValue().getLimitedValue()),
327-
kTestAggregateAttr1);
331+
EXPECT_EQ(
332+
static_cast<int32_t>(
333+
mlir::cast<mlir::IntegerAttr>(first).getValue().getLimitedValue()),
334+
kTestAggregateAttr1);
328335

329336
auto second = aggregate_attr[1];
330337
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(second),
331338
static_cast<BEFAttributeType>(DType::String));
332-
EXPECT_EQ(second.cast<mlir::StringAttr>().getValue(), kTestAggregateAttr2);
339+
EXPECT_EQ(mlir::cast<mlir::StringAttr>(second).getValue(),
340+
kTestAggregateAttr2);
333341

334342
auto third = aggregate_attr[2];
335343
EXPECT_EQ(BefAttrEmitter::GetBefAttributeType(third),
336344
static_cast<BEFAttributeType>(DType::F32));
337345
EXPECT_EQ(static_cast<float>(
338-
third.cast<mlir::FloatAttr>().getValue().convertToFloat()),
346+
mlir::cast<mlir::FloatAttr>(third).getValue().convertToFloat()),
339347
kTestAggregateAttr3);
340348
}
341349

cpp_tests/bef_converter/bef_location_emitter_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "gtest/gtest.h"
2020
#include "mlir/IR/Location.h"
2121
#include "mlir/IR/MLIRContext.h"
22+
#include "mlir/Support/LLVM.h"
2223
#include "tfrt/bef/bef_location.h"
2324

2425
namespace tfrt {
@@ -40,7 +41,7 @@ TEST_F(BefLocationEmitterTest, IsSupportedLocationNamedLoc) {
4041
EXPECT_TRUE(BefLocationEmitter::IsSupportedLocation(loc));
4142

4243
auto child = loc.getChildLoc();
43-
EXPECT_TRUE(child.isa<mlir::UnknownLoc>());
44+
EXPECT_TRUE(mlir::isa<mlir::UnknownLoc>(child));
4445
}
4546

4647
TEST_F(BefLocationEmitterTest, IsSupportedLocationCallSiteLoc) {

cpp_tests/bef_converter/bef_location_reader_test.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "gtest/gtest.h"
2525
#include "mlir/IR/Location.h"
2626
#include "mlir/IR/MLIRContext.h"
27+
#include "mlir/Support/LLVM.h"
2728

2829
namespace tfrt {
2930
namespace {
@@ -43,7 +44,8 @@ TEST_F(BefLocationReaderTest, FileLineColLoc) {
4344

4445
auto reader = BefLocationReader(emitter.GetStringsSectionEmitter().result(),
4546
emitter.result(), &context_);
46-
auto read_loc = reader.ReadLocation(offset).dyn_cast<mlir::FileLineColLoc>();
47+
auto read_loc =
48+
mlir::dyn_cast<mlir::FileLineColLoc>(reader.ReadLocation(offset));
4749
EXPECT_EQ(read_loc, loc);
4850
}
4951

@@ -57,7 +59,7 @@ TEST_F(BefLocationReaderTest, NameLoc) {
5759

5860
auto reader = BefLocationReader(emitter.GetStringsSectionEmitter().result(),
5961
emitter.result(), &context_);
60-
auto read_loc = reader.ReadLocation(offset).dyn_cast<mlir::NameLoc>();
62+
auto read_loc = mlir::dyn_cast<mlir::NameLoc>(reader.ReadLocation(offset));
6163
EXPECT_EQ(read_loc, loc);
6264
}
6365

@@ -71,7 +73,8 @@ TEST_F(BefLocationReaderTest, CallSiteLoc) {
7173

7274
auto reader = BefLocationReader(emitter.GetStringsSectionEmitter().result(),
7375
emitter.result(), &context_);
74-
auto read_loc = reader.ReadLocation(offset).dyn_cast<mlir::CallSiteLoc>();
76+
auto read_loc =
77+
mlir::dyn_cast<mlir::CallSiteLoc>(reader.ReadLocation(offset));
7578
EXPECT_EQ(read_loc, loc);
7679
}
7780

@@ -87,7 +90,7 @@ TEST_F(BefLocationReaderTest, FusedLoc) {
8790

8891
auto reader = BefLocationReader(emitter.GetStringsSectionEmitter().result(),
8992
emitter.result(), &context_);
90-
auto read_loc = reader.ReadLocation(offset).dyn_cast<mlir::FusedLoc>();
93+
auto read_loc = mlir::dyn_cast<mlir::FusedLoc>(reader.ReadLocation(offset));
9194
EXPECT_EQ(read_loc, loc);
9295
}
9396

include/tfrt/core_runtime/opdefs/corert_utils.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#define TFRT_CORE_RUNTIME_OPDEFS_CORERT_UTILS_H_
2020

2121
#include "mlir/IR/OpImplementation.h"
22+
#include "mlir/Support/LLVM.h"
2223

2324
using namespace mlir;
2425

@@ -29,9 +30,9 @@ template <typename OpTy>
2930
LogicalResult VerifyExecuteOpImpl(OpTy op) {
3031
auto op_attr_array = op.getOpAttrs().getValue();
3132
for (auto op_attr : op_attr_array) {
32-
auto key_value = op_attr.template dyn_cast<ArrayAttr>();
33+
auto key_value = mlir::dyn_cast<ArrayAttr>(op_attr);
3334
if (!key_value || key_value.getValue().size() != 2 ||
34-
!key_value.getValue()[0].template isa<StringAttr>())
35+
!mlir::isa<StringAttr>(key_value.getValue()[0]))
3536
return op.emitOpError() << "each op_attr should be a key-value pair, "
3637
"where the key is a string";
3738
}
@@ -43,12 +44,12 @@ void PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter &p, OpTy op) {
4344
auto op_func_attrs = op.getOpFuncAttrs();
4445
if (!op_func_attrs.empty()) {
4546
auto print_key_value = [&](mlir::Attribute attr) {
46-
auto key_value = attr.cast<mlir::ArrayAttr>().getValue();
47+
auto key_value = mlir::cast<mlir::ArrayAttr>(attr).getValue();
4748
assert(key_value.size() == 2 && "invalid named attribute format.");
4849
auto key = key_value[0];
4950
auto value = key_value[1];
5051

51-
p << key.cast<mlir::StringAttr>().getValue();
52+
p << mlir::cast<mlir::StringAttr>(key).getValue();
5253
p << " = ";
5354
p << value;
5455
};
@@ -65,12 +66,12 @@ void PrintExecuteOpImpl(OpAsmPrinter &p, OpTy op) {
6566
auto op_attrs = op.getOpAttrs();
6667
if (!op_attrs.empty()) {
6768
auto print_key_value = [&](mlir::Attribute attr) {
68-
auto key_value = attr.cast<ArrayAttr>().getValue();
69+
auto key_value = mlir::cast<ArrayAttr>(attr).getValue();
6970
assert(key_value.size() == 2 && "invalid named attribute format.");
7071
auto key = key_value[0];
7172
auto value = key_value[1];
7273

73-
p << key.cast<StringAttr>().getValue();
74+
p << mlir::cast<StringAttr>(key).getValue();
7475
p << " = ";
7576
p << value;
7677
};

lib/basic_kernels/opdefs/tfrt_base.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "tfrt/basic_kernels/opdefs/tfrt_base.h"
1818

19+
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Transforms/InliningUtils.h"
2021
#include "tfrt/basic_kernels/opdefs/basic_kernels.h"
2122
#include "tfrt/basic_kernels/opdefs/types.h"
@@ -95,13 +96,13 @@ mlir::Type TFRTDialect::parseType(mlir::DialectAsmParser &parser) const {
9596

9697
void TFRTDialect::printType(mlir::Type type,
9798
mlir::DialectAsmPrinter &printer) const {
98-
if (type.isa<compiler::ChainType>()) {
99+
if (mlir::isa<compiler::ChainType>(type)) {
99100
printer << "chain";
100-
} else if (type.isa<compiler::StringType>()) {
101+
} else if (mlir::isa<compiler::StringType>(type)) {
101102
printer << "string";
102-
} else if (type.isa<compiler::TensorTypeType>()) {
103+
} else if (mlir::isa<compiler::TensorTypeType>(type)) {
103104
printer << "tensor_type";
104-
} else if (type.isa<compiler::DeviceType>()) {
105+
} else if (mlir::isa<compiler::DeviceType>(type)) {
105106
printer << "device";
106107
} else {
107108
llvm_unreachable("unknown tfrt type");

lib/bef_converter/bef_to_mlir/bef_to_mlir.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
#include "mlir/IR/Location.h"
4545
#include "mlir/IR/MLIRContext.h"
4646
#include "mlir/Parser/Parser.h"
47+
#include "mlir/Support/LLVM.h"
4748
#include "mlir/Support/LogicalResult.h"
4849
#include "tfrt/bef/bef_encoding.h"
4950
#include "tfrt/bef/bef_reader.h"
@@ -931,7 +932,7 @@ mlir::LogicalResult BEFFunctionReader::AddDefinition(mlir::Value value,
931932
return mlir::failure();
932933
}
933934
assert(reg_info.type == value.getType() ||
934-
reg_info.type.isa<mlir::NoneType>());
935+
mlir::isa<mlir::NoneType>(reg_info.type));
935936
reg_info.value = value;
936937
return mlir::success();
937938
}

0 commit comments

Comments
 (0)