Skip to content

Commit 0c3cbb4

Browse files
committed
[Heavy] Proper forwarding of init args
1 parent f64fb8a commit 0c3cbb4

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

heavy/lib/Nbdl/NbdlWriter.cpp

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <llvm/ADT/ScopedHashTable.h>
1818
#include <llvm/ADT/ScopeExit.h>
1919
#include <llvm/ADT/Twine.h>
20+
#include <llvm/ADT/TypeSwitch.h>
2021
#include <llvm/Support/Casting.h>
2122
#include <mlir/IR/Value.h>
2223
#include <mlir/IR/Operation.h>
@@ -169,6 +170,13 @@ class NbdlWriter {
169170
}
170171

171172
void WriteForwardedExpr(mlir::Value V) {
173+
assert(V.hasOneUse() &&
174+
"expecting exactly one use for forwarding reference");
175+
WriteExpr(V, /*IsFwd=*/true);
176+
}
177+
178+
// Special function to denote that we checked for multiple uses.
179+
void WriteLastUseForwardedExpr(mlir::Value V) {
172180
WriteExpr(V, /*IsFwd=*/true);
173181
}
174182

@@ -339,36 +347,36 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
339347
<< SetLocalVarName(Op.getResult(), "get_")
340348
<< " = ";
341349
if (MemberNameOp) {
342-
WriteForwardedExpr(Op.getState());
350+
WriteExpr(Op.getState());
343351
OS << '.' << MemberNameOp.getName()
344352
<< ";\n";
345353
} else {
346354
OS << "nbdl::get(";
347-
WriteForwardedExpr(Op.getState());
355+
WriteExpr(Op.getState());
348356
if (!isa<nbdl_gen::UnitType>(Op.getKey().getType())) {
349357
OS << ", ";
350-
WriteForwardedExpr(Op.getKey());
358+
WriteExpr(Op.getKey());
351359
}
352360
OS << ");\n";
353361
}
354362
}
355363

356364
void Visit(VisitOp Op) {
357-
WriteForwardedExpr(Op.getFn());
365+
WriteExpr(Op.getFn());
358366
OS << '(';
359367
llvm::interleave(Op.getArgs(), OS,
360368
[&](mlir::Value V) {
361-
WriteForwardedExpr(V);
369+
WriteExpr(V);
362370
}, ",\n");
363371
OS << ");\n";
364372
}
365373

366374
void Visit(MatchOp Op) {
367375
OS << "nbdl::match(";
368-
WriteForwardedExpr(Op.getStore());
376+
WriteExpr(Op.getStore());
369377
if (!isa<nbdl_gen::UnitType>(Op.getKey().getType())) {
370378
OS << ", ";
371-
WriteForwardedExpr(Op.getKey());
379+
WriteExpr(Op.getKey());
372380
}
373381
OS << ", ";
374382
OS << "\nboost::hana::overload_linearly(";
@@ -386,6 +394,9 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
386394
mlir::Region& Body = Op.getBody();
387395
if (Body.empty())
388396
return;
397+
llvm::StringRef TypeStr = Op.getType();
398+
if (TypeStr.empty())
399+
TypeStr = "auto&&";
389400
OS << "[&]";
390401
// Write parameters.
391402
OS << '(';
@@ -562,10 +573,17 @@ class ContextWriter : public NbdlWriter<ContextWriter> {
562573
if (auto StoreOp = Member.getDefiningOp<nbdl_gen::StoreOp>()) {
563574
llvm::interleaveComma(StoreOp.getArgs(), OS,
564575
[&](mlir::Value Arg) {
565-
WriteForwardedExpr(Arg);
576+
// Check for later use or possibly later use
577+
// (due to unspecified order of evaluation)
578+
if (Arg.hasOneUse() ||
579+
(llvm::count(StoreOp.getArgs(), Arg) == 1 &&
580+
!findLaterUse(Arg, StoreOp)))
581+
WriteLastUseForwardedExpr(Arg);
582+
else
583+
WriteExpr(Arg); // Not forwarded
566584
});
567585
} else {
568-
SetErrorV("unsupported operation arguments", Member);
586+
SetErrorV("unsupported operation (WriteInitArgs)", Member);
569587
}
570588
}
571589

@@ -575,6 +593,49 @@ class ContextWriter : public NbdlWriter<ContextWriter> {
575593
<< "() const {\n return " << GetLocalVal(Value) << ";\n}\n";
576594
}
577595
}
596+
597+
// Find any use of Arg in subsequent stores composed of Store
598+
// (not including Store itself.)
599+
StoreOp findLaterUse(mlir::Value Arg, mlir::Value Store) {
600+
assert(Store.hasOneUse() && "store result should be used exactly once");
601+
mlir::OpOperand StoreUse = *Store.getUsers().begin();
602+
603+
return llvm::TypeSwitch<mlir::Operation*, StoreOp>(StoreUse.getOwner())
604+
.Case<StoreOp>([&](StoreOp S) {
605+
return llvm::is_contained(S.getArgs(), Arg) ? S : StoreOp();
606+
})
607+
.Case<StoreComposeOp>([&](StoreComposeOp SC) {
608+
// Check that Arg is not used in LHS.
609+
if (SC.getLhs() != Store)
610+
if (StoreOp Result = findUse(Arg, SC.getLhs()))
611+
return Result;
612+
return findLaterUse(Arg, SC.getResult());
613+
})
614+
.Case<ContOp>([](auto) { return StoreOp(); })
615+
.Default([](auto) {
616+
llvm_unreachable("unexpected use of store");
617+
return StoreOp();
618+
});
619+
}
620+
621+
// Find any use of Arg in a Store.
622+
StoreOp findUse(mlir::Value Arg, mlir::Value Store) {
623+
return llvm::TypeSwitch<mlir::Operation*, StoreOp>(Store.getDefiningOp())
624+
.Case<StoreOp>([&](StoreOp S) {
625+
return llvm::is_contained(S.getArgs(), Arg) ? S : StoreOp();
626+
})
627+
.Case<StoreComposeOp>([&](StoreComposeOp SC) {
628+
// Check both sides.
629+
if (StoreOp Result = findUse(Arg, SC.getLhs()))
630+
return Result;
631+
else
632+
return findUse(Arg, SC.getRhs());
633+
})
634+
.Default([](auto) {
635+
llvm_unreachable("unexpected result for store");
636+
return StoreOp();
637+
});
638+
}
578639
};
579640

580641
} // end namespace

heavy/test/Nbdl/context.scm

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
(build-member (build-member-name 'foo) '::moo::foo_t foo-input)
5353
(build-member (build-member-name 'bar) '::moo::bar_t)
5454
(build-member (build-member-name 'baz) '::moo::baz_t BazArg)
55+
(build-member (build-member-name 'baz2) '::moo::baz_t BazArg)
5556
(old-create-op "nbdl.cont"
5657
(operands parent))
5758
))
@@ -68,6 +69,6 @@
6869
; CHECK-NEXT: ::moo::bar_t bar;
6970
; CHECK-NEXT: ::moo::baz_t baz;
7071
; CHECK: my_context(auto&& arg_0)
71-
; CHECK-NEXT: : foo(42), bar(), baz(static_cast<decltype(arg_0)>(arg_0)
72+
; CHECK-NEXT: : foo(42), bar(), baz(arg_0), baz2(static_cast<decltype(arg_0)>(arg_0))
7273
(translate-cpp my_context)
7374
(newline)

0 commit comments

Comments
 (0)