diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index f9db5dcb88b4c..a3ee76bf7026c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1167,10 +1167,61 @@ bool checkErrorIfPad(Operation *op) { return true; } +// Returns true if the operation takes no input operands, excluding attributes. +static bool isNullaryOperation(Operation *op) { + if (isa(op) || isa(op) || + isa(op) || isa(op)) + return true; + return false; +} + +bool checkErrorIfCondIf(Operation *op) { + auto ifOp = dyn_cast(op); + if (!ifOp) + return true; + + // Whether the types and shapes of operands between the input/output list and + // internal regions are validated by the operation verifier. However, with + // support for the simplified form - where redundant operand notations are + // omitted - is not conformant to the specification. According to the + // specification, all operands passed into an operation must be explicitly + // declared at each operation's structure. This code section verify that the + // operation's form complies with this requirement. + + // Returns true if the region uses no external input operands. + auto isNullaryRegion = [](Region ®ion) -> bool { + bool noLiveInValue = true; + region.walk([&noLiveInValue](Operation *op) { + if (!isNullaryOperation(op)) { + noLiveInValue = false; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return noLiveInValue; + }; + + mlir::Region &thenGraph = ifOp.getThenGraph(); + mlir::Region &elseGraph = ifOp.getElseGraph(); + bool isThenGraphNullaryRegion = isNullaryRegion(thenGraph); + bool isElseGraphNullaryRegion = isNullaryRegion(elseGraph); + bool isInputListEmpty = ifOp.getInputList().size() == 0; + + if ((isInputListEmpty != isThenGraphNullaryRegion) || + (isInputListEmpty != isElseGraphNullaryRegion)) { + op->emitOpError() + << "the current simplified form is not strictly conformant to the " + "spec, please use the generic format\n"; + return false; + } + + return true; +} + LogicalResult TosaValidation::applyErrorIfCheck(Operation *op) { if (!checkErrorIfResize(op) || !checkErrorIfMul(op) || !checkErrorIfTable(op) || !checkErrorIfRescale(op) || - !checkErrorIfPad(op)) + !checkErrorIfPad(op) || !checkErrorIfCondIf(op)) return failure(); return success(); } diff --git a/mlir/test/Dialect/Tosa/error_if_check.mlir b/mlir/test/Dialect/Tosa/error_if_check.mlir index ac161128694cc..1f25132d6bcf3 100644 --- a/mlir/test/Dialect/Tosa/error_if_check.mlir +++ b/mlir/test/Dialect/Tosa/error_if_check.mlir @@ -225,3 +225,17 @@ func.func @test_error_i32_unsigned_output(%arg0: tensor<1xi8>) -> tensor<1xi32> %0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = false, output_unsigned = true} : (tensor<1xi8>, tensor<1xi16>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1xi32> return %0 : tensor<1xi32> } + +// ----- +// CHECK-LABEL: cond_if_simplified_form +func.func @test_cond_if_simplified_form(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op the current simplified form is not strictly conformant to the spec, please use the generic format}} + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +}