Skip to content

Commit 7441665

Browse files
wecingtensorflower-gardener
authored andcommitted
Remove mlir::Type::isa() usages.
This is removed in MLIR upstream: llvm/llvm-project#135556 PiperOrigin-RevId: 748822888
1 parent 820834f commit 7441665

File tree

6 files changed

+12
-13
lines changed

6 files changed

+12
-13
lines changed

tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern<quantfork::StatisticsOp> {
226226
// Per axis quantization (or per channel quantization)
227227
int stats_num = op.getAxisStats()->getNumElements();
228228
if (stats_num == 0 || stats_num % 2 != 0) return failure();
229-
auto stats = op.getAxisStats()->dyn_cast<DenseFPElementsAttr>();
229+
auto stats =
230+
llvm::dyn_cast<DenseFPElementsAttr>(op.getAxisStats().value());
230231
if (!stats) return failure();
231232

232233
for (auto it = stats.begin(), e = stats.end(); it != e; ++it) {
@@ -671,9 +672,7 @@ class QuantizationPattern : public RewritePattern {
671672
void RewireFloatModelBackbone(Operation* quantized_op,
672673
Operation* float_op) const {
673674
for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) {
674-
if (!float_op->getResult(i)
675-
.getType()
676-
.cast<ShapedType>()
675+
if (!llvm::cast<ShapedType>(float_op->getResult(i).getType())
677676
.getElementType()
678677
.isF32()) {
679678
continue;

tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ def TfExecutor_Dialect : Dialect {
4747
}
4848

4949
// Control type.
50-
def TfeControlType : Type<CPred<"$_self.isa<ControlType>()">, "control">,
50+
def TfeControlType : Type<CPred<"llvm::isa<ControlType>($_self)">, "control">,
5151
BuildableType<"$_builder.getType<ControlType>()">;
5252

5353
// Token type.
54-
def TfeTokenType : Type<CPred<"$_self.isa<TokenType>()">, "token">,
54+
def TfeTokenType : Type<CPred<"llvm::isa<TokenType>($_self)">, "token">,
5555
BuildableType<"$_builder.getType<TokenType>()">;
5656

5757
// TODO(hinsu): Define and use TensorType instead of AnyType for data operands

tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ LogicalResult ParseExampleV2Op::verify() {
561561
template <typename CallOpClass>
562562
static LogicalResult VerifyPartitionedCall(CallOpClass op,
563563
SymbolTableCollection &symbolTable) {
564-
SymbolRefAttr func = op->getAttr("f").template cast<SymbolRefAttr>();
564+
SymbolRefAttr func = llvm::cast<SymbolRefAttr>(op->getAttr("f"));
565565
auto function = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(op, func);
566566
if (!function) {
567567
return op.emitError("'f' attribute refers to an undefined function: ")

tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ absl::Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
180180
absl::StrCat("Converting ", debugString(type), " to DataType"));
181181
}
182182

183-
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
184-
if (type.isa<mlir::tf_type::tftype##Type>()) { \
185-
*dtype = DT_##enumerant; \
186-
return OkStatus(); \
183+
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
184+
if (llvm::isa<mlir::tf_type::tftype##Type>(type)) { \
185+
*dtype = DT_##enumerant; \
186+
return OkStatus(); \
187187
}
188188
// NOLINTNEXTLINE
189189
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def"

tensorflow/core/ir/importexport/convert_types.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
171171
}
172172

173173
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
174-
if (type.isa<tftype##Type>()) { \
174+
if (llvm::isa<tftype##Type>(type)) { \
175175
*dtype = tensorflow::DT_##enumerant; \
176176
return ::tensorflow::OkStatus(); \
177177
}

tensorflow/core/ir/types/dialect.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ Type TensorFlowRefType::RemoveRef() {
631631
if (mlir::isa<Complex128RefType>(*this))
632632
return ComplexType::get(Float64Type::get(ctx));
633633
#define HANDLE_TF_TYPE(tftype, enumerant, name) \
634-
if (isa<tftype##RefType>()) return tftype##Type::get(ctx);
634+
if (mlir::isa<tftype##RefType>(*this)) return tftype##Type::get(ctx);
635635

636636
#define HANDLE_TF_REF_TYPE(tftype, enumerant, name)
637637
// NOLINTNEXTLINE

0 commit comments

Comments
 (0)