diff --git a/examples/tarai.py b/examples/tarai.py new file mode 100644 index 0000000..daf1609 --- /dev/null +++ b/examples/tarai.py @@ -0,0 +1,8 @@ +def tarai(x: int, y: int, z: int) -> int: + if x > y: + return tarai( tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y) ) + else: + return y + + +print(tarai(14, 7, 0)) diff --git a/examples/tarai_prim.py b/examples/tarai_prim.py new file mode 100644 index 0000000..d14bcf4 --- /dev/null +++ b/examples/tarai_prim.py @@ -0,0 +1,16 @@ +from lyrt import from_prim, native +from lyrt.prim import Int + +p1 = Int[8](1) + + +@native(gc="none") +def tarai(x: Int[8], y: Int[8], z: Int[8]) -> Int[8]: + if x > y: + return tarai( tarai(x - p1, y, z), tarai(y - p1, z, x), tarai(z - p1, x, y) ) + else: + return y + + +ans = tarai(Int[8](14), Int[8](7), Int[8](0)) +print(from_prim(ans)) diff --git a/src/lython/dialects/cpp/PyVerifier.cpp b/src/lython/dialects/cpp/PyVerifier.cpp index 6cd265d..04958e8 100644 --- a/src/lython/dialects/cpp/PyVerifier.cpp +++ b/src/lython/dialects/cpp/PyVerifier.cpp @@ -860,4 +860,64 @@ LogicalResult NumLeOp::verify() { return success(); } +LogicalResult NumLtOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + if (lhsType != rhsType) + return emitOpError("operand types must match"); + if (!isPyIntType(lhsType) && !isPyFloatType(lhsType)) + return emitOpError("operands must be !py.int or !py.float"); + if (!isPyBoolType(getResult().getType())) + return emitOpError("result must be !py.bool"); + return success(); +} + +LogicalResult NumGtOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + if (lhsType != rhsType) + return emitOpError("operand types must match"); + if (!isPyIntType(lhsType) && !isPyFloatType(lhsType)) + return emitOpError("operands must be !py.int or !py.float"); + if (!isPyBoolType(getResult().getType())) + return emitOpError("result must be !py.bool"); + return success(); +} + +LogicalResult NumGeOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + if (lhsType != rhsType) + return emitOpError("operand types must match"); + if (!isPyIntType(lhsType) && !isPyFloatType(lhsType)) + return emitOpError("operands must be !py.int or !py.float"); + if (!isPyBoolType(getResult().getType())) + return emitOpError("result must be !py.bool"); + return success(); +} + +LogicalResult NumEqOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + if (lhsType != rhsType) + return emitOpError("operand types must match"); + if (!isPyIntType(lhsType) && !isPyFloatType(lhsType)) + return emitOpError("operands must be !py.int or !py.float"); + if (!isPyBoolType(getResult().getType())) + return emitOpError("result must be !py.bool"); + return success(); +} + +LogicalResult NumNeOp::verify() { + Type lhsType = getLhs().getType(); + Type rhsType = getRhs().getType(); + if (lhsType != rhsType) + return emitOpError("operand types must match"); + if (!isPyIntType(lhsType) && !isPyFloatType(lhsType)) + return emitOpError("operands must be !py.int or !py.float"); + if (!isPyBoolType(getResult().getType())) + return emitOpError("result must be !py.bool"); + return success(); +} + } // namespace py diff --git a/src/lython/dialects/tablegen/PyDialect.td b/src/lython/dialects/tablegen/PyDialect.td index 21d3b83..e8712e6 100644 --- a/src/lython/dialects/tablegen/PyDialect.td +++ b/src/lython/dialects/tablegen/PyDialect.td @@ -529,6 +529,91 @@ def Py_NumLeOp : Py_Op<"num.le", [Pure]> { let hasVerifier = 1; } +def Py_NumLtOp : Py_Op<"num.lt", [Pure]> { + let summary = "Python numeric < comparison"; + let description = [{ + Compares two numeric values and returns a `!py.bool` object. + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + + let results = (outs Py_BoolType:$result); + + let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)"; + let hasVerifier = 1; +} + +def Py_NumGtOp : Py_Op<"num.gt", [Pure]> { + let summary = "Python numeric > comparison"; + let description = [{ + Compares two numeric values and returns a `!py.bool` object. + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + + let results = (outs Py_BoolType:$result); + + let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)"; + let hasVerifier = 1; +} + +def Py_NumGeOp : Py_Op<"num.ge", [Pure]> { + let summary = "Python numeric >= comparison"; + let description = [{ + Compares two numeric values and returns a `!py.bool` object. + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + + let results = (outs Py_BoolType:$result); + + let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)"; + let hasVerifier = 1; +} + +def Py_NumEqOp : Py_Op<"num.eq", [Pure]> { + let summary = "Python numeric == comparison"; + let description = [{ + Compares two numeric values and returns a `!py.bool` object. + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + + let results = (outs Py_BoolType:$result); + + let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)"; + let hasVerifier = 1; +} + +def Py_NumNeOp : Py_Op<"num.ne", [Pure]> { + let summary = "Python numeric != comparison"; + let description = [{ + Compares two numeric values and returns a `!py.bool` object. + }]; + + let arguments = (ins + AnyType:$lhs, + AnyType:$rhs + ); + + let results = (outs Py_BoolType:$result); + + let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)"; + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // Class instantiation //===----------------------------------------------------------------------===// diff --git a/src/lython/lowering/PyOptimizationPass.cpp b/src/lython/lowering/PyOptimizationPass.cpp index 2a27858..5bef4bb 100644 --- a/src/lython/lowering/PyOptimizationPass.cpp +++ b/src/lython/lowering/PyOptimizationPass.cpp @@ -135,6 +135,17 @@ bool cleanupDeadTuples(ModuleOp module) { return !toErase.empty(); } +/// Remove unused TupleEmptyOps. +void removeUnusedTupleEmpties(ModuleOp module) { + SmallVector toErase; + module.walk([&](TupleEmptyOp op) { + if (op.getResult().use_empty()) + toErase.push_back(op); + }); + for (auto op : toErase) + op->erase(); +} + /// Hoist integer constants to entry block and perform CSE. void hoistIntConstants(ModuleOp module) { module.walk([&](func::FuncOp func) { @@ -391,23 +402,65 @@ void eliminateBoolBoxingUnboxing(ModuleOp module) { } } - // Only optimize if the pattern matches exactly: + // Only optimize if the pattern matches: // - One LyBool_AsBool call - // - One Ly_DecRef call + // - Optional Ly_DecRef call (bool singletons are immortal) // - No other users - if (!asBoolCall || !decRefCall || hasOtherUsers) + if (!asBoolCall || hasOtherUsers) continue; // Replace uses of the AsBool result with the original i1 value asBoolCall.getResult().replaceAllUsesWith(i1Value); // Erase the operations (in reverse order of dependencies) - decRefCall->erase(); + if (decRefCall) + decRefCall->erase(); asBoolCall->erase(); fromBoolCall->erase(); } } +/// CSE for LyLong_FromI64 for small integers (-5 to 256). +/// Small integers are immortal, so sharing is safe. +void cseSmallIntFromI64(ModuleOp module) { + module.walk([&](func::FuncOp func) { + if (func.isExternal()) + return; + + llvm::DenseMap cache; + SmallVector toErase; + + func.walk([&](LLVM::CallOp callOp) { + auto callee = callOp.getCallee(); + if (!callee || *callee != "LyLong_FromI64") + return; + if (callOp.getNumOperands() != 1) + return; + auto constOp = + callOp.getOperand(0).getDefiningOp(); + if (!constOp) + return; + auto intAttr = llvm::dyn_cast(constOp.getValue()); + if (!intAttr) + return; + int64_t value = intAttr.getInt(); + if (value < -5 || value > 256) + return; + + auto it = cache.find(value); + if (it != cache.end()) { + callOp.getResult().replaceAllUsesWith(it->second.getResult()); + toErase.push_back(callOp); + } else { + cache[value] = callOp; + } + }); + + for (auto op : toErase) + op->erase(); + }); +} + /// CSE for LLVM constants within each function. void cseConstants(ModuleOp module) { module.walk([&](func::FuncOp func) { @@ -495,6 +548,7 @@ std::unique_ptr> createPyOptimizationPass() { void runPreLoweringOptimizations(ModuleOp module) { cleanupDeadTuples(module); + removeUnusedTupleEmpties(module); hoistIntConstants(module); removeSmallIntDecrefs(module); } @@ -504,6 +558,7 @@ void runPostLoweringOptimizations(ModuleOp module) { replaceEmptyTupleNew(module); cseSingletonGetters(module); eliminateBoolBoxingUnboxing(module); + cseSmallIntFromI64(module); cseConstants(module); eliminateDeadCode(module); } diff --git a/src/lython/lowering/PyValueLowering.cpp b/src/lython/lowering/PyValueLowering.cpp index a125cd2..8fb9455 100644 --- a/src/lython/lowering/PyValueLowering.cpp +++ b/src/lython/lowering/PyValueLowering.cpp @@ -268,6 +268,215 @@ struct NumLeLowering : public OpConversionPattern { } }; +// Type-specialized lowering for NumLtOp +struct NumLtLowering : public OpConversionPattern { + NumLtLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(NumLtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + if (!module) + return failure(); + auto *typeConverter = + static_cast(getTypeConverter()); + RuntimeAPI runtime(module, rewriter, *typeConverter); + Type resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) + return failure(); + + // For int operands, use LyLong_Compare + LyBool_FromBool + if (isPyIntType(op.getLhs().getType()) && + isPyIntType(op.getRhs().getType())) { + // LyLong_Compare returns int: -1 (less), 0 (equal), 1 (greater) + auto cmpCall = runtime.call(op.getLoc(), RuntimeSymbols::kLongCompare, + rewriter.getI32Type(), adaptor.getOperands()); + // For <: compare result < 0 + Value zero = rewriter.create( + op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value ltZero = rewriter.create( + op.getLoc(), LLVM::ICmpPredicate::slt, cmpCall.getResult(), zero); + // Convert i1 to LyBool + auto boolCall = runtime.call(op.getLoc(), RuntimeSymbols::kBoolFromBool, + resultType, ValueRange{ltZero}); + rewriter.replaceOp(op, boolCall.getResults()); + } else { + auto call = runtime.call(op.getLoc(), RuntimeSymbols::kNumberLt, + resultType, adaptor.getOperands()); + rewriter.replaceOp(op, call.getResults()); + } + return success(); + } +}; + +// Type-specialized lowering for NumGtOp +struct NumGtLowering : public OpConversionPattern { + NumGtLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(NumGtOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + if (!module) + return failure(); + auto *typeConverter = + static_cast(getTypeConverter()); + RuntimeAPI runtime(module, rewriter, *typeConverter); + Type resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) + return failure(); + + // For int operands, use LyLong_Compare + LyBool_FromBool + if (isPyIntType(op.getLhs().getType()) && + isPyIntType(op.getRhs().getType())) { + // LyLong_Compare returns int: -1 (less), 0 (equal), 1 (greater) + auto cmpCall = runtime.call(op.getLoc(), RuntimeSymbols::kLongCompare, + rewriter.getI32Type(), adaptor.getOperands()); + // For >: compare result > 0 + Value zero = rewriter.create( + op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value gtZero = rewriter.create( + op.getLoc(), LLVM::ICmpPredicate::sgt, cmpCall.getResult(), zero); + // Convert i1 to LyBool + auto boolCall = runtime.call(op.getLoc(), RuntimeSymbols::kBoolFromBool, + resultType, ValueRange{gtZero}); + rewriter.replaceOp(op, boolCall.getResults()); + } else { + auto call = runtime.call(op.getLoc(), RuntimeSymbols::kNumberGt, + resultType, adaptor.getOperands()); + rewriter.replaceOp(op, call.getResults()); + } + return success(); + } +}; + +// Type-specialized lowering for NumGeOp +struct NumGeLowering : public OpConversionPattern { + NumGeLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(NumGeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + if (!module) + return failure(); + auto *typeConverter = + static_cast(getTypeConverter()); + RuntimeAPI runtime(module, rewriter, *typeConverter); + Type resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) + return failure(); + + // For int operands, use LyLong_Compare + LyBool_FromBool + if (isPyIntType(op.getLhs().getType()) && + isPyIntType(op.getRhs().getType())) { + // LyLong_Compare returns int: -1 (less), 0 (equal), 1 (greater) + auto cmpCall = runtime.call(op.getLoc(), RuntimeSymbols::kLongCompare, + rewriter.getI32Type(), adaptor.getOperands()); + // For >=: compare result >= 0 + Value zero = rewriter.create( + op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value geZero = rewriter.create( + op.getLoc(), LLVM::ICmpPredicate::sge, cmpCall.getResult(), zero); + // Convert i1 to LyBool + auto boolCall = runtime.call(op.getLoc(), RuntimeSymbols::kBoolFromBool, + resultType, ValueRange{geZero}); + rewriter.replaceOp(op, boolCall.getResults()); + } else { + auto call = runtime.call(op.getLoc(), RuntimeSymbols::kNumberGe, + resultType, adaptor.getOperands()); + rewriter.replaceOp(op, call.getResults()); + } + return success(); + } +}; + +// Type-specialized lowering for NumEqOp +struct NumEqLowering : public OpConversionPattern { + NumEqLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(NumEqOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + if (!module) + return failure(); + auto *typeConverter = + static_cast(getTypeConverter()); + RuntimeAPI runtime(module, rewriter, *typeConverter); + Type resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) + return failure(); + + // For int operands, use LyLong_Compare + LyBool_FromBool + if (isPyIntType(op.getLhs().getType()) && + isPyIntType(op.getRhs().getType())) { + // LyLong_Compare returns int: -1 (less), 0 (equal), 1 (greater) + auto cmpCall = runtime.call(op.getLoc(), RuntimeSymbols::kLongCompare, + rewriter.getI32Type(), adaptor.getOperands()); + // For ==: compare result == 0 + Value zero = rewriter.create( + op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value eqZero = rewriter.create( + op.getLoc(), LLVM::ICmpPredicate::eq, cmpCall.getResult(), zero); + // Convert i1 to LyBool + auto boolCall = runtime.call(op.getLoc(), RuntimeSymbols::kBoolFromBool, + resultType, ValueRange{eqZero}); + rewriter.replaceOp(op, boolCall.getResults()); + } else { + auto call = runtime.call(op.getLoc(), RuntimeSymbols::kNumberEq, + resultType, adaptor.getOperands()); + rewriter.replaceOp(op, call.getResults()); + } + return success(); + } +}; + +// Type-specialized lowering for NumNeOp +struct NumNeLowering : public OpConversionPattern { + NumNeLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) + : OpConversionPattern(converter, ctx) {} + + LogicalResult + matchAndRewrite(NumNeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = op->getParentOfType(); + if (!module) + return failure(); + auto *typeConverter = + static_cast(getTypeConverter()); + RuntimeAPI runtime(module, rewriter, *typeConverter); + Type resultType = typeConverter->convertType(op.getResult().getType()); + if (!resultType) + return failure(); + + // For int operands, use LyLong_Compare + LyBool_FromBool + if (isPyIntType(op.getLhs().getType()) && + isPyIntType(op.getRhs().getType())) { + // LyLong_Compare returns int: -1 (less), 0 (equal), 1 (greater) + auto cmpCall = runtime.call(op.getLoc(), RuntimeSymbols::kLongCompare, + rewriter.getI32Type(), adaptor.getOperands()); + // For !=: compare result != 0 + Value zero = rewriter.create( + op.getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); + Value neZero = rewriter.create( + op.getLoc(), LLVM::ICmpPredicate::ne, cmpCall.getResult(), zero); + // Convert i1 to LyBool + auto boolCall = runtime.call(op.getLoc(), RuntimeSymbols::kBoolFromBool, + resultType, ValueRange{neZero}); + rewriter.replaceOp(op, boolCall.getResults()); + } else { + auto call = runtime.call(op.getLoc(), RuntimeSymbols::kNumberNe, + resultType, adaptor.getOperands()); + rewriter.replaceOp(op, call.getResults()); + } + return success(); + } +}; struct CastToPrimLowering : public OpConversionPattern { CastToPrimLowering(PyLLVMTypeConverter &converter, MLIRContext *ctx) : OpConversionPattern(converter, ctx) {} @@ -474,10 +683,11 @@ void populatePyValueLoweringPatterns(PyLLVMTypeConverter &typeConverter, auto *ctx = patterns.getContext(); patterns .add(typeConverter, - ctx); + FloatConstantLowering, NumAddLowering, NumSubLowering, NumLtLowering, + NumLeLowering, NumGtLowering, NumGeLowering, NumEqLowering, + NumNeLowering, CastToPrimLowering, CastFromPrimLowering, + ClassNewLowering, AttrGetLowering, AttrSetLowering, ClassOpLowering>( + typeConverter, ctx); } } // namespace py diff --git a/src/lython/lowering/RuntimeLoweringPass.cpp b/src/lython/lowering/RuntimeLoweringPass.cpp index f473598..7a24d59 100644 --- a/src/lython/lowering/RuntimeLoweringPass.cpp +++ b/src/lython/lowering/RuntimeLoweringPass.cpp @@ -276,8 +276,9 @@ struct RuntimeLoweringPass target.addIllegalOp< StrConstantOp, IntConstantOp, FloatConstantOp, TupleEmptyOp, TupleCreateOp, DictEmptyOp, DictInsertOp, NoneOp, FuncObjectOp, - NumAddOp, NumSubOp, NumLeOp, CastToPrimOp, CastIdentityOp, UpcastOp, - IncRefOp, DecRefOp, ClassNewOp, AttrGetOp, AttrSetOp, ClassOp>(); + NumAddOp, NumSubOp, NumLtOp, NumLeOp, NumGtOp, NumGeOp, NumEqOp, + NumNeOp, CastToPrimOp, CastIdentityOp, UpcastOp, IncRefOp, + DecRefOp, ClassNewOp, AttrGetOp, AttrSetOp, ClassOp>(); ScopedDiagnosticHandler diagHandler(ctx, materializationFilter); auto result = diff --git a/src/lython/lowering/RuntimeSupport.h b/src/lython/lowering/RuntimeSupport.h index dc67d8f..14399da 100644 --- a/src/lython/lowering/RuntimeSupport.h +++ b/src/lython/lowering/RuntimeSupport.h @@ -29,7 +29,12 @@ struct RuntimeSymbols { // Generic numeric operations (with type dispatch) static constexpr llvm::StringLiteral kNumberAdd{"LyNumber_Add"}; static constexpr llvm::StringLiteral kNumberSub{"LyNumber_Sub"}; + static constexpr llvm::StringLiteral kNumberLt{"LyNumber_Lt"}; static constexpr llvm::StringLiteral kNumberLe{"LyNumber_Le"}; + static constexpr llvm::StringLiteral kNumberGt{"LyNumber_Gt"}; + static constexpr llvm::StringLiteral kNumberGe{"LyNumber_Ge"}; + static constexpr llvm::StringLiteral kNumberEq{"LyNumber_Eq"}; + static constexpr llvm::StringLiteral kNumberNe{"LyNumber_Ne"}; // Type-specialized integer operations (inlinable fast paths) static constexpr llvm::StringLiteral kLongAdd{"LyLong_Add"}; static constexpr llvm::StringLiteral kLongSub{"LyLong_Sub"}; diff --git a/src/lython/runtime/lyrt.h b/src/lython/runtime/lyrt.h index a083664..e59a6e4 100644 --- a/src/lython/runtime/lyrt.h +++ b/src/lython/runtime/lyrt.h @@ -28,7 +28,12 @@ LyLongObject *LyLong_FromI64(std::int64_t value); LyFloatObject *LyFloat_FromDouble(double value); LyObject *LyNumber_Add(LyObject *lhs, LyObject *rhs); LyObject *LyNumber_Sub(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Lt(LyObject *lhs, LyObject *rhs); LyObject *LyNumber_Le(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Gt(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Ge(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Eq(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Ne(LyObject *lhs, LyObject *rhs); bool LyBool_AsBool(LyObject *object); LyObject *Ly_CallVectorcall(LyObject *callable, LyTupleObject *posargs, diff --git a/src/lython/runtime/objects/long.cpp b/src/lython/runtime/objects/long.cpp index b85e842..31b7f9b 100644 --- a/src/lython/runtime/objects/long.cpp +++ b/src/lython/runtime/objects/long.cpp @@ -646,15 +646,11 @@ LyObject *LyNumber_Add(LyObject *lhs, LyObject *rhs) { return reinterpret_cast(result); } - // float + any or any + float - if (lhs->ob_type == floatType || rhs->ob_type == floatType) { - double left = (lhs->ob_type == floatType) - ? reinterpret_cast(lhs)->value - : LyLong_AsDouble(reinterpret_cast(lhs)); - double right = (rhs->ob_type == floatType) - ? reinterpret_cast(rhs)->value - : LyLong_AsDouble(reinterpret_cast(rhs)); - return reinterpret_cast(LyFloat_FromDouble(left + right)); + // float + float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + auto *result = LyFloat_Add(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(result); } return nullptr; @@ -675,15 +671,36 @@ LyObject *LyNumber_Sub(LyObject *lhs, LyObject *rhs) { return reinterpret_cast(result); } - // float - any or any - float - if (lhs->ob_type == floatType || rhs->ob_type == floatType) { - double left = (lhs->ob_type == floatType) - ? reinterpret_cast(lhs)->value - : LyLong_AsDouble(reinterpret_cast(lhs)); - double right = (rhs->ob_type == floatType) - ? reinterpret_cast(rhs)->value - : LyLong_AsDouble(reinterpret_cast(rhs)); - return reinterpret_cast(LyFloat_FromDouble(left - right)); + // float - float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + auto *result = LyFloat_Sub(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(result); + } + + return nullptr; +} + +LyObject *LyNumber_Lt(LyObject *lhs, LyObject *rhs) { + if (!lhs || !rhs) { + return nullptr; + } + + auto *longType = &LyLong_Type(); + auto *floatType = &LyFloat_Type(); + + // int < int + if (lhs->ob_type == longType && rhs->ob_type == longType) { + int cmp = LyLong_Compare(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(LyBool_FromBool(cmp < 0)); + } + + // float < float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; + return reinterpret_cast(LyBool_FromBool(left < right)); } return nullptr; @@ -704,20 +721,116 @@ LyObject *LyNumber_Le(LyObject *lhs, LyObject *rhs) { return reinterpret_cast(LyBool_FromBool(cmp <= 0)); } - // float <= any or any <= float - if (lhs->ob_type == floatType || rhs->ob_type == floatType) { - double left = (lhs->ob_type == floatType) - ? reinterpret_cast(lhs)->value - : LyLong_AsDouble(reinterpret_cast(lhs)); - double right = (rhs->ob_type == floatType) - ? reinterpret_cast(rhs)->value - : LyLong_AsDouble(reinterpret_cast(rhs)); + // float <= float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; return reinterpret_cast(LyBool_FromBool(left <= right)); } return nullptr; } +LyObject *LyNumber_Gt(LyObject *lhs, LyObject *rhs) { + if (!lhs || !rhs) { + return nullptr; + } + + auto *longType = &LyLong_Type(); + auto *floatType = &LyFloat_Type(); + + // int > int + if (lhs->ob_type == longType && rhs->ob_type == longType) { + int cmp = LyLong_Compare(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(LyBool_FromBool(cmp > 0)); + } + + // float > float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; + return reinterpret_cast(LyBool_FromBool(left > right)); + } + + return nullptr; +} + +LyObject *LyNumber_Ge(LyObject *lhs, LyObject *rhs) { + if (!lhs || !rhs) { + return nullptr; + } + + auto *longType = &LyLong_Type(); + auto *floatType = &LyFloat_Type(); + + // int >= int + if (lhs->ob_type == longType && rhs->ob_type == longType) { + int cmp = LyLong_Compare(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(LyBool_FromBool(cmp >= 0)); + } + + // float >= float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; + return reinterpret_cast(LyBool_FromBool(left >= right)); + } + + return nullptr; +} + +LyObject *LyNumber_Eq(LyObject *lhs, LyObject *rhs) { + if (!lhs || !rhs) { + return nullptr; + } + + auto *longType = &LyLong_Type(); + auto *floatType = &LyFloat_Type(); + + // int == int + if (lhs->ob_type == longType && rhs->ob_type == longType) { + int cmp = LyLong_Compare(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(LyBool_FromBool(cmp == 0)); + } + + // float == float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; + return reinterpret_cast(LyBool_FromBool(left == right)); + } + + return nullptr; +} + +LyObject *LyNumber_Ne(LyObject *lhs, LyObject *rhs) { + if (!lhs || !rhs) { + return nullptr; + } + + auto *longType = &LyLong_Type(); + auto *floatType = &LyFloat_Type(); + + // int != int + if (lhs->ob_type == longType && rhs->ob_type == longType) { + int cmp = LyLong_Compare(reinterpret_cast(lhs), + reinterpret_cast(rhs)); + return reinterpret_cast(LyBool_FromBool(cmp != 0)); + } + + // float != float + if (lhs->ob_type == floatType && rhs->ob_type == floatType) { + double left = reinterpret_cast(lhs)->value; + double right = reinterpret_cast(rhs)->value; + return reinterpret_cast(LyBool_FromBool(left != right)); + } + + return nullptr; +} + // Exported wrapper for inline LyLong_Add (for JIT/AOT linking) LyLongObject *LyLong_Add(const LyLongObject *a, const LyLongObject *b) { if (LyLong_BothAreCompact(a, b)) { diff --git a/src/lython/runtime/objects/long.h b/src/lython/runtime/objects/long.h index e97f1b6..7367f7b 100644 --- a/src/lython/runtime/objects/long.h +++ b/src/lython/runtime/objects/long.h @@ -132,7 +132,12 @@ int LyLong_Compare(const LyLongObject *lhs, const LyLongObject *rhs); double LyLong_AsDouble(const LyLongObject *value); LyObject *LyNumber_Add(LyObject *lhs, LyObject *rhs); LyObject *LyNumber_Sub(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Lt(LyObject *lhs, LyObject *rhs); LyObject *LyNumber_Le(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Gt(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Ge(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Eq(LyObject *lhs, LyObject *rhs); +LyObject *LyNumber_Ne(LyObject *lhs, LyObject *rhs); // Exported wrappers for JIT/AOT linking (also available as inline in header) LyLongObject *LyLong_Add(const LyLongObject *a, const LyLongObject *b); LyLongObject *LyLong_Sub(const LyLongObject *a, const LyLongObject *b); diff --git a/src/lython/visitors/expr.py b/src/lython/visitors/expr.py index f962acb..2f993dd 100644 --- a/src/lython/visitors/expr.py +++ b/src/lython/visitors/expr.py @@ -305,12 +305,24 @@ def visit_Compare(self, node: ast.Compare) -> ir.Value: if self._in_native_func: return self._handle_primitive_compare(op, lhs, rhs, self._loc(node)) - # In object mode, use py.num.le - if not isinstance(op, ast.LtE): - raise NotImplementedError("Only <= comparison supported in object mode") + # In object mode, use py.num.* comparisons bool_type = self.get_py_type("!py.bool") with self._loc(node), self.insertion_point(): - return py_ops.NumLeOp(bool_type, lhs, rhs).result + if isinstance(op, ast.LtE): + return py_ops.NumLeOp(bool_type, lhs, rhs).result + if isinstance(op, ast.Gt): + return py_ops.NumGtOp(bool_type, lhs, rhs).result + if isinstance(op, ast.Lt): + return py_ops.NumLtOp(bool_type, lhs, rhs).result + if isinstance(op, ast.GtE): + return py_ops.NumGeOp(bool_type, lhs, rhs).result + if isinstance(op, ast.Eq): + return py_ops.NumEqOp(bool_type, lhs, rhs).result + if isinstance(op, ast.NotEq): + return py_ops.NumNeOp(bool_type, lhs, rhs).result + raise NotImplementedError( + "Only <, <=, >, >=, ==, != comparisons supported in object mode" + ) def _handle_primitive_compare( self, op: ast.cmpop, lhs: ir.Value, rhs: ir.Value, loc: ir.Location diff --git a/src/lython/visitors/stmt.py b/src/lython/visitors/stmt.py index c3387f6..0e036d8 100644 --- a/src/lython/visitors/stmt.py +++ b/src/lython/visitors/stmt.py @@ -136,9 +136,9 @@ def hoge(n: int) -> int: ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args ] arg_names_attr = ( - ir.ArrayAttr.get( + ir.ArrayAttr.get( # pyright: ignore[reportUnknownMemberType] arg_name_attrs, context=self.ctx - ) # pyright: ignore[reportUnknownMemberType] + ) if arg_name_attrs else None ) @@ -405,9 +405,9 @@ def _visit_method_def( ir.StringAttr.get(arg.arg, self.ctx) for arg in node.args.args ] arg_names_attr = ( - ir.ArrayAttr.get( + ir.ArrayAttr.get( # pyright: ignore[reportUnknownMemberType] arg_name_attrs, context=self.ctx - ) # pyright: ignore[reportUnknownMemberType] + ) if arg_name_attrs else None ) @@ -697,29 +697,39 @@ def visit_If(self, node: ast.If) -> None: assert self.current_block is not None parent_region = self.current_block.region true_block = ( - parent_region.blocks.append() - ) # pyright: ignore[reportUnknownMemberType] + parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] + ) false_block = ( - parent_region.blocks.append() - ) # pyright: ignore[reportUnknownMemberType] - merge_block = ( - parent_region.blocks.append() - ) # pyright: ignore[reportUnknownMemberType] + parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] + ) with self._loc(node), self.insertion_point(): cf_ops.CondBranchOp(cond, [], [], true_block, false_block) - def handle_branch(block: ir.Block, statements: list[ast.stmt]) -> None: + def handle_branch(block: ir.Block, statements: list[ast.stmt]) -> bool: self._set_insertion_block(block) self.push_scope() for stmt in statements: self.visit(stmt) - if not self._block_terminated(block): - with self._loc(node), ir.InsertionPoint(block): - cf_ops.BranchOp([], merge_block) + terminated = self._block_terminated(block) self.pop_scope() + return terminated - handle_branch(true_block, node.body) - handle_branch(false_block, node.orelse or []) + true_terminated = handle_branch(true_block, node.body) + false_terminated = handle_branch(false_block, node.orelse or []) + + if true_terminated and false_terminated: + self._set_insertion_block(None) + return + + merge_block = ( + parent_region.blocks.append() # pyright: ignore[reportUnknownMemberType] + ) + if not true_terminated: + with self._loc(node), ir.InsertionPoint(true_block): + cf_ops.BranchOp([], merge_block) + if not false_terminated: + with self._loc(node), ir.InsertionPoint(false_block): + cf_ops.BranchOp([], merge_block) self._set_insertion_block(merge_block)