1717#include < llvm/ADT/ScopedHashTable.h>
1818#include < llvm/ADT/ScopeExit.h>
1919#include < llvm/ADT/Twine.h>
20+ #include < llvm/ADT/TypeSwitch.h>
2021#include < llvm/Support/Casting.h>
2122#include < mlir/IR/Value.h>
2223#include < mlir/IR/Operation.h>
@@ -169,6 +170,13 @@ class NbdlWriter {
169170 }
170171
171172 void WriteForwardedExpr (mlir::Value V) {
173+ assert (V.hasOneUse () &&
174+ " expecting exactly one use for forwarding reference" );
175+ WriteExpr (V, /* IsFwd=*/ true );
176+ }
177+
178+ // Special function to denote that we checked for multiple uses.
179+ void WriteLastUseForwardedExpr (mlir::Value V) {
172180 WriteExpr (V, /* IsFwd=*/ true );
173181 }
174182
@@ -339,36 +347,36 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
339347 << SetLocalVarName (Op.getResult (), " get_" )
340348 << " = " ;
341349 if (MemberNameOp) {
342- WriteForwardedExpr (Op.getState ());
350+ WriteExpr (Op.getState ());
343351 OS << ' .' << MemberNameOp.getName ()
344352 << " ;\n " ;
345353 } else {
346354 OS << " nbdl::get(" ;
347- WriteForwardedExpr (Op.getState ());
355+ WriteExpr (Op.getState ());
348356 if (!isa<nbdl_gen::UnitType>(Op.getKey ().getType ())) {
349357 OS << " , " ;
350- WriteForwardedExpr (Op.getKey ());
358+ WriteExpr (Op.getKey ());
351359 }
352360 OS << " );\n " ;
353361 }
354362 }
355363
356364 void Visit (VisitOp Op) {
357- WriteForwardedExpr (Op.getFn ());
365+ WriteExpr (Op.getFn ());
358366 OS << ' (' ;
359367 llvm::interleave (Op.getArgs (), OS,
360368 [&](mlir::Value V) {
361- WriteForwardedExpr (V);
369+ WriteExpr (V);
362370 }, " ,\n " );
363371 OS << " );\n " ;
364372 }
365373
366374 void Visit (MatchOp Op) {
367375 OS << " nbdl::match(" ;
368- WriteForwardedExpr (Op.getStore ());
376+ WriteExpr (Op.getStore ());
369377 if (!isa<nbdl_gen::UnitType>(Op.getKey ().getType ())) {
370378 OS << " , " ;
371- WriteForwardedExpr (Op.getKey ());
379+ WriteExpr (Op.getKey ());
372380 }
373381 OS << " , " ;
374382 OS << " \n boost::hana::overload_linearly(" ;
@@ -386,6 +394,9 @@ class FuncWriter : public NbdlWriter<FuncWriter> {
386394 mlir::Region& Body = Op.getBody ();
387395 if (Body.empty ())
388396 return ;
397+ llvm::StringRef TypeStr = Op.getType ();
398+ if (TypeStr.empty ())
399+ TypeStr = " auto&&" ;
389400 OS << " [&]" ;
390401 // Write parameters.
391402 OS << ' (' ;
@@ -562,10 +573,17 @@ class ContextWriter : public NbdlWriter<ContextWriter> {
562573 if (auto StoreOp = Member.getDefiningOp <nbdl_gen::StoreOp>()) {
563574 llvm::interleaveComma (StoreOp.getArgs (), OS,
564575 [&](mlir::Value Arg) {
565- WriteForwardedExpr (Arg);
576+ // Check for later use or possibly later use
577+ // (due to unspecified order of evaluation)
578+ if (Arg.hasOneUse () ||
579+ (llvm::count (StoreOp.getArgs (), Arg) == 1 &&
580+ !findLaterUse (Arg, StoreOp)))
581+ WriteLastUseForwardedExpr (Arg);
582+ else
583+ WriteExpr (Arg); // Not forwarded
566584 });
567585 } else {
568- SetErrorV (" unsupported operation arguments " , Member);
586+ SetErrorV (" unsupported operation (WriteInitArgs) " , Member);
569587 }
570588 }
571589
@@ -575,6 +593,49 @@ class ContextWriter : public NbdlWriter<ContextWriter> {
575593 << " () const {\n return " << GetLocalVal (Value) << " ;\n }\n " ;
576594 }
577595 }
596+
597+ // Find any use of Arg in subsequent stores composed of Store
598+ // (not including Store itself.)
599+ StoreOp findLaterUse (mlir::Value Arg, mlir::Value Store) {
600+ assert (Store.hasOneUse () && " store result should be used exactly once" );
601+ mlir::OpOperand StoreUse = *Store.getUsers ().begin ();
602+
603+ return llvm::TypeSwitch<mlir::Operation*, StoreOp>(StoreUse.getOwner ())
604+ .Case <StoreOp>([&](StoreOp S) {
605+ return llvm::is_contained (S.getArgs (), Arg) ? S : StoreOp ();
606+ })
607+ .Case <StoreComposeOp>([&](StoreComposeOp SC) {
608+ // Check that Arg is not used in LHS.
609+ if (SC.getLhs () != Store)
610+ if (StoreOp Result = findUse (Arg, SC.getLhs ()))
611+ return Result;
612+ return findLaterUse (Arg, SC.getResult ());
613+ })
614+ .Case <ContOp>([](auto ) { return StoreOp (); })
615+ .Default ([](auto ) {
616+ llvm_unreachable (" unexpected use of store" );
617+ return StoreOp ();
618+ });
619+ }
620+
621+ // Find any use of Arg in a Store.
622+ StoreOp findUse (mlir::Value Arg, mlir::Value Store) {
623+ return llvm::TypeSwitch<mlir::Operation*, StoreOp>(Store.getDefiningOp ())
624+ .Case <StoreOp>([&](StoreOp S) {
625+ return llvm::is_contained (S.getArgs (), Arg) ? S : StoreOp ();
626+ })
627+ .Case <StoreComposeOp>([&](StoreComposeOp SC) {
628+ // Check both sides.
629+ if (StoreOp Result = findUse (Arg, SC.getLhs ()))
630+ return Result;
631+ else
632+ return findUse (Arg, SC.getRhs ());
633+ })
634+ .Default ([](auto ) {
635+ llvm_unreachable (" unexpected result for store" );
636+ return StoreOp ();
637+ });
638+ }
578639};
579640
580641} // end namespace
0 commit comments