Skip to content

Commit 551b315

Browse files
HerrCai0907atc-github
authored andcommitted
feat(inserter): support for more insertion point instructions (#113)
1 parent 2ad1dfa commit 551b315

File tree

2 files changed

+175
-31
lines changed

2 files changed

+175
-31
lines changed

passes/helper/ExprInserter.cpp

Lines changed: 161 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <memory>
2+
#include <vector>
23

34
#include "ExprInserter.hpp"
45
#include "ToString.hpp"
@@ -30,6 +31,11 @@ bool ExprInserter::canInsertBefore(wasm::Expression *insertPosition) {
3031
if (canInsertAfter(call->operands.back()))
3132
return true;
3233
}
34+
if (insertPosition->is<wasm::LocalSet>()) {
35+
wasm::LocalSet *const localSet = insertPosition->cast<wasm::LocalSet>();
36+
if (canInsertAfter(localSet->value))
37+
return true;
38+
}
3339
fmt::println("[" PASS_NAME "] fn '{}', failed to insert before {}", func_->name.str, toString(insertPosition));
3440
return false;
3541
}
@@ -53,55 +59,87 @@ void ExprInserter::insertBefore(wasm::Builder &b, wasm::Expression *insertedExpr
5359
}
5460
break;
5561
}
62+
case wasm::Expression::LocalSetId: {
63+
wasm::LocalSet *const localSet = insertPosition->cast<wasm::LocalSet>();
64+
insertAfter(b, insertedExpr, &localSet->value);
65+
break;
66+
}
5667
default:
5768
__builtin_unreachable();
5869
}
5970
}
6071

6172
bool ExprInserter::canInsertAfter(wasm::Expression *insertPosition) {
62-
if (insertPosition->type == wasm::Type::none && !isTerminator(insertPosition))
63-
return true;
64-
if (insertPosition->is<wasm::Return>())
65-
return true;
66-
if (insertPosition->is<wasm::If>())
67-
return true;
73+
if (isTerminator(insertPosition)) {
74+
// special handler for terminator
75+
if (insertPosition->is<wasm::Unreachable>())
76+
return true;
77+
if (wasm::Return *expr = insertPosition->dynCast<wasm::Return>(); expr != nullptr) {
78+
return true;
79+
if (expr->value == nullptr)
80+
return true;
81+
if (canInsertAfter(expr->value))
82+
return true;
83+
}
84+
if (wasm::Break *expr = insertPosition->dynCast<wasm::Break>(); expr != nullptr && expr->condition == nullptr) {
85+
return true;
86+
if (expr->value == nullptr)
87+
return true;
88+
if (canInsertAfter(expr->value))
89+
return true;
90+
}
91+
} else {
92+
if (insertPosition->type == wasm::Type::none)
93+
return true;
94+
if (insertPosition->type != wasm::Type::unreachable)
95+
return true;
96+
}
6897
fmt::println("[" PASS_NAME "] fn '{}', failed to insert after {}", func_->name.str, toString(insertPosition));
6998
return false;
7099
}
100+
71101
void ExprInserter::insertAfter(wasm::Builder &b, wasm::Expression *insertedExpr, wasm::Expression **insertPositionPtr) {
72102
assert(insertedExpr->type == wasm::Type::none);
73103
wasm::Expression *const insertPosition = *insertPositionPtr;
74-
if (insertPosition->type == wasm::Type::none) {
75-
*insertPositionPtr = b.makeBlock({insertPosition, insertedExpr}, wasm::Type::none);
76-
return;
77-
}
78-
if (auto *const returnExpr = insertPosition->dynCast<wasm::Return>()) {
79-
if (returnExpr->value == nullptr) {
80-
*insertPositionPtr = b.makeBlock({insertedExpr, insertPosition}, wasm::Type::none);
81-
} else {
82-
wasm::Type const localType = returnExpr->value->type;
83-
wasm::Index const tmpLocal = b.addVar(func_, localType);
84-
returnExpr->value = b.makeBlock(
85-
{b.makeLocalSet(tmpLocal, returnExpr->value), insertedExpr, b.makeLocalGet(tmpLocal, localType)}, localType);
104+
105+
if (isTerminator(insertPosition)) {
106+
if (insertPosition->is<wasm::Unreachable>()) {
107+
*insertPositionPtr = b.makeBlock({insertedExpr, insertPosition}, wasm::Type::unreachable);
108+
return;
86109
}
87-
return;
88-
}
89-
if (auto *const ifExpr = insertPosition->dynCast<wasm::If>()) {
90-
wasm::Type const type = ifExpr->type;
91-
assert(type != wasm::Type::none);
92-
if (ifExpr->type == wasm::Type::unreachable) {
110+
if (wasm::Return *expr = insertPosition->dynCast<wasm::Return>(); expr != nullptr) {
111+
if (expr->value == nullptr) {
112+
*insertPositionPtr = b.makeBlock({insertedExpr, insertPosition}, wasm::Type::unreachable);
113+
return;
114+
}
115+
insertAfter(b, insertedExpr, &expr->value);
116+
return;
117+
}
118+
if (wasm::Break *expr = insertPosition->dynCast<wasm::Break>(); expr != nullptr && expr->condition == nullptr) {
119+
if (expr->value == nullptr) {
120+
*insertPositionPtr = b.makeBlock({insertedExpr, insertPosition}, wasm::Type::unreachable);
121+
return;
122+
}
123+
insertAfter(b, insertedExpr, &expr->value);
124+
return;
125+
}
126+
} else {
127+
wasm::Type const exprType = insertPosition->type;
128+
if (exprType == wasm::Type::none) {
93129
*insertPositionPtr = b.makeBlock({insertPosition, insertedExpr}, wasm::Type::none);
94-
} else {
95-
wasm::Index index = wasm::Builder::addVar(func_, type);
130+
return;
131+
}
132+
if (exprType != wasm::Type::unreachable) {
133+
wasm::Index const tmpLocal = b.addVar(func_, exprType);
96134
*insertPositionPtr = b.makeBlock(
97135
{
98-
b.makeLocalSet(index, ifExpr),
136+
b.makeLocalSet(tmpLocal, insertPosition),
99137
insertedExpr,
100-
b.makeLocalGet(index, type),
138+
b.makeLocalGet(tmpLocal, exprType),
101139
},
102-
type);
140+
exprType);
141+
return;
103142
}
104-
return;
105143
}
106144
__builtin_unreachable();
107145
}
@@ -113,10 +151,12 @@ void ExprInserter::insertAfter(wasm::Builder &b, wasm::Expression *insertedExpr,
113151
#include <gtest/gtest.h>
114152

115153
#include "FindExpr.hpp"
154+
#include "Matcher.hpp"
116155

117156
namespace warpo::passes::ut {
118157

119-
using wasm::Const, wasm::Block, wasm::Nop, wasm::Call, wasm::LocalGet, wasm::LocalSet, wasm::If, wasm::Return;
158+
using wasm::Const, wasm::Block, wasm::Nop, wasm::Call, wasm::LocalGet, wasm::LocalSet, wasm::If, wasm::Return,
159+
wasm::Loop, wasm::Break;
120160
using wasm::Type;
121161

122162
TEST(ExprInserter, InsertBeforeNoOperand) {
@@ -173,6 +213,31 @@ TEST(ExprInserter, InsertBeforeCallWithOperands) {
173213
ASSERT_TRUE(insertPos->cast<Call>()->operands[1]->cast<Block>()->list[2]->is<LocalGet>());
174214
}
175215

216+
TEST(ExprInserter, InsertBeforeLocalSet) {
217+
wasm::Module m{};
218+
wasm::Builder b{m};
219+
wasm::Expression *const insertPos = b.makeLocalSet(0, b.makeConst(1));
220+
std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction("test", wasm::Signature(), {}, insertPos);
221+
ExprInserter inserter{f.get()};
222+
223+
ASSERT_TRUE(inserter.canInsertBefore(insertPos));
224+
inserter.insertBefore(b, b.makeNop(), findExprPointer(insertPos, f.get()));
225+
226+
ASSERT_EQ(f->body, insertPos);
227+
228+
using namespace matcher;
229+
auto matcher = isLocalSet(local_set::v(isBlock(block::has(3), block::at(0, isLocalSet(local_set::v(isConst()))),
230+
block::at(1, isNop()), block::at(2, isLocalGet()))));
231+
EXPECT_TRUE(matcher(*f->body));
232+
233+
ASSERT_EQ(f->body, insertPos);
234+
ASSERT_TRUE(insertPos->cast<LocalSet>()->value->is<Block>());
235+
ASSERT_TRUE(insertPos->cast<LocalSet>()->value->cast<Block>()->list[0]->is<LocalSet>());
236+
ASSERT_TRUE(insertPos->cast<LocalSet>()->value->cast<Block>()->list[0]->cast<LocalSet>()->value->is<Const>());
237+
ASSERT_TRUE(insertPos->cast<LocalSet>()->value->cast<Block>()->list[1]->is<Nop>());
238+
ASSERT_TRUE(insertPos->cast<LocalSet>()->value->cast<Block>()->list[2]->is<LocalGet>());
239+
}
240+
176241
TEST(ExprInserter, InsertAfterTypeNone) {
177242
wasm::Module m{};
178243
wasm::Builder b{m};
@@ -188,6 +253,38 @@ TEST(ExprInserter, InsertAfterTypeNone) {
188253
ASSERT_TRUE(f->body->cast<Block>()->list[1]->is<Nop>());
189254
}
190255

256+
TEST(ExprInserter, InsertAfterLoopWithoutType) {
257+
wasm::Module m{};
258+
wasm::Builder b{m};
259+
wasm::Expression *const insertPos = b.makeLoop("", b.makeNop(), Type::none);
260+
std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction("test", wasm::Signature(), {}, insertPos);
261+
ExprInserter inserter{f.get()};
262+
263+
ASSERT_TRUE(inserter.canInsertAfter(insertPos));
264+
inserter.insertAfter(b, b.makeNop(), findExprPointer(insertPos, f.get()));
265+
266+
ASSERT_TRUE(f->body->is<Block>());
267+
ASSERT_TRUE(f->body->cast<Block>()->list[0]->is<Loop>());
268+
ASSERT_TRUE(f->body->cast<Block>()->list[1]->is<Nop>());
269+
}
270+
271+
TEST(ExprInserter, InsertAfterLoopWithType) {
272+
wasm::Module m{};
273+
wasm::Builder b{m};
274+
wasm::Expression *const insertPos = b.makeLoop("", b.makeConst(1), Type::i32);
275+
std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction("test", wasm::Signature(), {}, insertPos);
276+
ExprInserter inserter{f.get()};
277+
278+
ASSERT_TRUE(inserter.canInsertAfter(insertPos));
279+
inserter.insertAfter(b, b.makeNop(), findExprPointer(insertPos, f.get()));
280+
281+
ASSERT_TRUE(f->body->is<Block>());
282+
ASSERT_TRUE(f->body->cast<Block>()->list[0]->is<LocalSet>());
283+
ASSERT_TRUE(f->body->cast<Block>()->list[0]->cast<LocalSet>()->value->is<Loop>());
284+
ASSERT_TRUE(f->body->cast<Block>()->list[1]->is<Nop>());
285+
ASSERT_TRUE(f->body->cast<Block>()->list[2]->is<LocalGet>());
286+
}
287+
191288
TEST(ExprInserter, InsertAfterReturnWithoutValue) {
192289
wasm::Module m{};
193290
wasm::Builder b{m};
@@ -221,6 +318,39 @@ TEST(ExprInserter, InsertAfterReturnWithValue) {
221318
ASSERT_TRUE(f->body->cast<Return>()->value->cast<Block>()->list[2]->is<LocalGet>());
222319
}
223320

321+
TEST(ExprInserter, InsertAfterBrWithoutValue) {
322+
wasm::Module m{};
323+
wasm::Builder b{m};
324+
wasm::Expression *const insertPos = b.makeBreak("bb");
325+
std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction("test", wasm::Signature(), {}, insertPos);
326+
ExprInserter inserter{f.get()};
327+
328+
ASSERT_TRUE(inserter.canInsertAfter(insertPos));
329+
inserter.insertAfter(b, b.makeNop(), findExprPointer(insertPos, f.get()));
330+
331+
ASSERT_TRUE(f->body->is<Block>());
332+
ASSERT_TRUE(f->body->cast<Block>()->list[0]->is<Nop>());
333+
ASSERT_TRUE(f->body->cast<Block>()->list[1]->is<Break>());
334+
}
335+
336+
TEST(ExprInserter, InsertAfterBrWithValue) {
337+
wasm::Module m{};
338+
wasm::Builder b{m};
339+
wasm::Expression *const insertPos = b.makeBreak("bb", b.makeConst(1));
340+
std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction("test", wasm::Signature(), {}, insertPos);
341+
ExprInserter inserter{f.get()};
342+
343+
ASSERT_TRUE(inserter.canInsertAfter(insertPos));
344+
inserter.insertAfter(b, b.makeNop(), findExprPointer(insertPos, f.get()));
345+
346+
ASSERT_EQ(f->body, insertPos);
347+
ASSERT_TRUE(f->body->cast<Break>()->value->is<Block>());
348+
ASSERT_TRUE(f->body->cast<Break>()->value->cast<Block>()->list[0]->is<LocalSet>());
349+
ASSERT_TRUE(f->body->cast<Break>()->value->cast<Block>()->list[0]->cast<LocalSet>()->value->is<Const>());
350+
ASSERT_TRUE(f->body->cast<Break>()->value->cast<Block>()->list[1]->is<Nop>());
351+
ASSERT_TRUE(f->body->cast<Break>()->value->cast<Block>()->list[2]->is<LocalGet>());
352+
}
353+
224354
} // namespace warpo::passes::ut
225355

226356
#endif

passes/helper/Matcher.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,4 +197,18 @@ static inline M<wasm::If> hasFalse() {
197197

198198
constexpr IsMatcherImpl<wasm::Return, wasm::Expression> isReturn;
199199

200+
constexpr IsMatcherImpl<wasm::Block, wasm::Expression> isBlock;
201+
namespace block {
202+
static inline M<wasm::Block> has(size_t n) {
203+
return M<wasm::Block>([n](wasm::Block const &expr, Context &ctx) -> bool { return expr.list.size() == n; });
204+
}
205+
static inline M<wasm::Block> at(size_t n, M<wasm::Expression> const &m) {
206+
return M<wasm::Block>([n, m](wasm::Block const &expr, Context &ctx) -> bool {
207+
if (n >= expr.list.size())
208+
return false;
209+
return m(*expr.list[n], ctx);
210+
});
211+
}
212+
} // namespace block
213+
200214
} // namespace warpo::passes::matcher

0 commit comments

Comments
 (0)