Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/tarai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def tarai(x: int, y: int, z: int) -> int:
if x > y:
return tarai( tarai(x - 1, y, z), tarai(y - 1, z, x), tarai(z - 1, x, y) )
else:
return y


print(tarai(14, 7, 0))
16 changes: 16 additions & 0 deletions examples/tarai_prim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from lyrt import from_prim, native
from lyrt.prim import Int

p1 = Int[8](1)


@native(gc="none")
def tarai(x: Int[8], y: Int[8], z: Int[8]) -> Int[8]:
if x > y:
return tarai( tarai(x - p1, y, z), tarai(y - p1, z, x), tarai(z - p1, x, y) )
else:
return y


ans = tarai(Int[8](14), Int[8](7), Int[8](0))
print(from_prim(ans))
60 changes: 60 additions & 0 deletions src/lython/dialects/cpp/PyVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,4 +860,64 @@ LogicalResult NumLeOp::verify() {
return success();
}

LogicalResult NumLtOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType != rhsType)
return emitOpError("operand types must match");
if (!isPyIntType(lhsType) && !isPyFloatType(lhsType))
return emitOpError("operands must be !py.int or !py.float");
if (!isPyBoolType(getResult().getType()))
return emitOpError("result must be !py.bool");
return success();
}

LogicalResult NumGtOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType != rhsType)
return emitOpError("operand types must match");
if (!isPyIntType(lhsType) && !isPyFloatType(lhsType))
return emitOpError("operands must be !py.int or !py.float");
if (!isPyBoolType(getResult().getType()))
return emitOpError("result must be !py.bool");
return success();
}

LogicalResult NumGeOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType != rhsType)
return emitOpError("operand types must match");
if (!isPyIntType(lhsType) && !isPyFloatType(lhsType))
return emitOpError("operands must be !py.int or !py.float");
if (!isPyBoolType(getResult().getType()))
return emitOpError("result must be !py.bool");
return success();
}

LogicalResult NumEqOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType != rhsType)
return emitOpError("operand types must match");
if (!isPyIntType(lhsType) && !isPyFloatType(lhsType))
return emitOpError("operands must be !py.int or !py.float");
if (!isPyBoolType(getResult().getType()))
return emitOpError("result must be !py.bool");
return success();
}

LogicalResult NumNeOp::verify() {
Type lhsType = getLhs().getType();
Type rhsType = getRhs().getType();
if (lhsType != rhsType)
return emitOpError("operand types must match");
if (!isPyIntType(lhsType) && !isPyFloatType(lhsType))
return emitOpError("operands must be !py.int or !py.float");
if (!isPyBoolType(getResult().getType()))
return emitOpError("result must be !py.bool");
return success();
}

} // namespace py
85 changes: 85 additions & 0 deletions src/lython/dialects/tablegen/PyDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,91 @@ def Py_NumLeOp : Py_Op<"num.le", [Pure]> {
let hasVerifier = 1;
}

def Py_NumLtOp : Py_Op<"num.lt", [Pure]> {
let summary = "Python numeric < comparison";
let description = [{
Compares two numeric values and returns a `!py.bool` object.
}];

let arguments = (ins
AnyType:$lhs,
AnyType:$rhs
);

let results = (outs Py_BoolType:$result);

let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)";
let hasVerifier = 1;
}

def Py_NumGtOp : Py_Op<"num.gt", [Pure]> {
let summary = "Python numeric > comparison";
let description = [{
Compares two numeric values and returns a `!py.bool` object.
}];

let arguments = (ins
AnyType:$lhs,
AnyType:$rhs
);

let results = (outs Py_BoolType:$result);

let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)";
let hasVerifier = 1;
}

def Py_NumGeOp : Py_Op<"num.ge", [Pure]> {
let summary = "Python numeric >= comparison";
let description = [{
Compares two numeric values and returns a `!py.bool` object.
}];

let arguments = (ins
AnyType:$lhs,
AnyType:$rhs
);

let results = (outs Py_BoolType:$result);

let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)";
let hasVerifier = 1;
}

def Py_NumEqOp : Py_Op<"num.eq", [Pure]> {
let summary = "Python numeric == comparison";
let description = [{
Compares two numeric values and returns a `!py.bool` object.
}];

let arguments = (ins
AnyType:$lhs,
AnyType:$rhs
);

let results = (outs Py_BoolType:$result);

let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)";
let hasVerifier = 1;
}

def Py_NumNeOp : Py_Op<"num.ne", [Pure]> {
let summary = "Python numeric != comparison";
let description = [{
Compares two numeric values and returns a `!py.bool` object.
}];

let arguments = (ins
AnyType:$lhs,
AnyType:$rhs
);

let results = (outs Py_BoolType:$result);

let assemblyFormat = "$lhs `:` type($lhs) `,` $rhs `:` type($rhs) attr-dict `->` type($result)";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// Class instantiation
//===----------------------------------------------------------------------===//
Expand Down
63 changes: 59 additions & 4 deletions src/lython/lowering/PyOptimizationPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ bool cleanupDeadTuples(ModuleOp module) {
return !toErase.empty();
}

/// Remove unused TupleEmptyOps.
void removeUnusedTupleEmpties(ModuleOp module) {
SmallVector<TupleEmptyOp> toErase;
module.walk([&](TupleEmptyOp op) {
if (op.getResult().use_empty())
toErase.push_back(op);
});
for (auto op : toErase)
op->erase();
}

/// Hoist integer constants to entry block and perform CSE.
void hoistIntConstants(ModuleOp module) {
module.walk([&](func::FuncOp func) {
Expand Down Expand Up @@ -391,23 +402,65 @@ void eliminateBoolBoxingUnboxing(ModuleOp module) {
}
}

// Only optimize if the pattern matches exactly:
// Only optimize if the pattern matches:
// - One LyBool_AsBool call
// - One Ly_DecRef call
// - Optional Ly_DecRef call (bool singletons are immortal)
// - No other users
if (!asBoolCall || !decRefCall || hasOtherUsers)
if (!asBoolCall || hasOtherUsers)
continue;

// Replace uses of the AsBool result with the original i1 value
asBoolCall.getResult().replaceAllUsesWith(i1Value);

// Erase the operations (in reverse order of dependencies)
decRefCall->erase();
if (decRefCall)
decRefCall->erase();
asBoolCall->erase();
fromBoolCall->erase();
}
}

/// CSE for LyLong_FromI64 for small integers (-5 to 256).
/// Small integers are immortal, so sharing is safe.
void cseSmallIntFromI64(ModuleOp module) {
module.walk([&](func::FuncOp func) {
if (func.isExternal())
return;

llvm::DenseMap<int64_t, LLVM::CallOp> cache;
SmallVector<LLVM::CallOp> toErase;

func.walk([&](LLVM::CallOp callOp) {
auto callee = callOp.getCallee();
if (!callee || *callee != "LyLong_FromI64")
return;
if (callOp.getNumOperands() != 1)
return;
auto constOp =
callOp.getOperand(0).getDefiningOp<LLVM::ConstantOp>();
if (!constOp)
return;
auto intAttr = llvm::dyn_cast<IntegerAttr>(constOp.getValue());
if (!intAttr)
return;
int64_t value = intAttr.getInt();
if (value < -5 || value > 256)
return;

auto it = cache.find(value);
if (it != cache.end()) {
callOp.getResult().replaceAllUsesWith(it->second.getResult());
toErase.push_back(callOp);
} else {
cache[value] = callOp;
}
});

for (auto op : toErase)
op->erase();
});
}

/// CSE for LLVM constants within each function.
void cseConstants(ModuleOp module) {
module.walk([&](func::FuncOp func) {
Expand Down Expand Up @@ -495,6 +548,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createPyOptimizationPass() {

void runPreLoweringOptimizations(ModuleOp module) {
cleanupDeadTuples(module);
removeUnusedTupleEmpties(module);
hoistIntConstants(module);
removeSmallIntDecrefs(module);
}
Expand All @@ -504,6 +558,7 @@ void runPostLoweringOptimizations(ModuleOp module) {
replaceEmptyTupleNew(module);
cseSingletonGetters(module);
eliminateBoolBoxingUnboxing(module);
cseSmallIntFromI64(module);
cseConstants(module);
eliminateDeadCode(module);
}
Expand Down
Loading