1
1
#include < memory>
2
+ #include < vector>
2
3
3
4
#include " ExprInserter.hpp"
4
5
#include " ToString.hpp"
@@ -30,6 +31,11 @@ bool ExprInserter::canInsertBefore(wasm::Expression *insertPosition) {
30
31
if (canInsertAfter (call->operands .back ()))
31
32
return true ;
32
33
}
34
+ if (insertPosition->is <wasm::LocalSet>()) {
35
+ wasm::LocalSet *const localSet = insertPosition->cast <wasm::LocalSet>();
36
+ if (canInsertAfter (localSet->value ))
37
+ return true ;
38
+ }
33
39
fmt::println (" [" PASS_NAME " ] fn '{}', failed to insert before {}" , func_->name .str , toString (insertPosition));
34
40
return false ;
35
41
}
@@ -53,55 +59,87 @@ void ExprInserter::insertBefore(wasm::Builder &b, wasm::Expression *insertedExpr
53
59
}
54
60
break ;
55
61
}
62
+ case wasm::Expression::LocalSetId: {
63
+ wasm::LocalSet *const localSet = insertPosition->cast <wasm::LocalSet>();
64
+ insertAfter (b, insertedExpr, &localSet->value );
65
+ break ;
66
+ }
56
67
default :
57
68
__builtin_unreachable ();
58
69
}
59
70
}
60
71
61
72
bool ExprInserter::canInsertAfter (wasm::Expression *insertPosition) {
62
- if (insertPosition->type == wasm::Type::none && !isTerminator (insertPosition))
63
- return true ;
64
- if (insertPosition->is <wasm::Return>())
65
- return true ;
66
- if (insertPosition->is <wasm::If>())
67
- return true ;
73
+ if (isTerminator (insertPosition)) {
74
+ // special handler for terminator
75
+ if (insertPosition->is <wasm::Unreachable>())
76
+ return true ;
77
+ if (wasm::Return *expr = insertPosition->dynCast <wasm::Return>(); expr != nullptr ) {
78
+ return true ;
79
+ if (expr->value == nullptr )
80
+ return true ;
81
+ if (canInsertAfter (expr->value ))
82
+ return true ;
83
+ }
84
+ if (wasm::Break *expr = insertPosition->dynCast <wasm::Break>(); expr != nullptr && expr->condition == nullptr ) {
85
+ return true ;
86
+ if (expr->value == nullptr )
87
+ return true ;
88
+ if (canInsertAfter (expr->value ))
89
+ return true ;
90
+ }
91
+ } else {
92
+ if (insertPosition->type == wasm::Type::none)
93
+ return true ;
94
+ if (insertPosition->type != wasm::Type::unreachable)
95
+ return true ;
96
+ }
68
97
fmt::println (" [" PASS_NAME " ] fn '{}', failed to insert after {}" , func_->name .str , toString (insertPosition));
69
98
return false ;
70
99
}
100
+
71
101
void ExprInserter::insertAfter (wasm::Builder &b, wasm::Expression *insertedExpr, wasm::Expression **insertPositionPtr) {
72
102
assert (insertedExpr->type == wasm::Type::none);
73
103
wasm::Expression *const insertPosition = *insertPositionPtr;
74
- if (insertPosition->type == wasm::Type::none) {
75
- *insertPositionPtr = b.makeBlock ({insertPosition, insertedExpr}, wasm::Type::none);
76
- return ;
77
- }
78
- if (auto *const returnExpr = insertPosition->dynCast <wasm::Return>()) {
79
- if (returnExpr->value == nullptr ) {
80
- *insertPositionPtr = b.makeBlock ({insertedExpr, insertPosition}, wasm::Type::none);
81
- } else {
82
- wasm::Type const localType = returnExpr->value ->type ;
83
- wasm::Index const tmpLocal = b.addVar (func_, localType);
84
- returnExpr->value = b.makeBlock (
85
- {b.makeLocalSet (tmpLocal, returnExpr->value ), insertedExpr, b.makeLocalGet (tmpLocal, localType)}, localType);
104
+
105
+ if (isTerminator (insertPosition)) {
106
+ if (insertPosition->is <wasm::Unreachable>()) {
107
+ *insertPositionPtr = b.makeBlock ({insertedExpr, insertPosition}, wasm::Type::unreachable);
108
+ return ;
86
109
}
87
- return ;
88
- }
89
- if (auto *const ifExpr = insertPosition->dynCast <wasm::If>()) {
90
- wasm::Type const type = ifExpr->type ;
91
- assert (type != wasm::Type::none);
92
- if (ifExpr->type == wasm::Type::unreachable) {
110
+ if (wasm::Return *expr = insertPosition->dynCast <wasm::Return>(); expr != nullptr ) {
111
+ if (expr->value == nullptr ) {
112
+ *insertPositionPtr = b.makeBlock ({insertedExpr, insertPosition}, wasm::Type::unreachable);
113
+ return ;
114
+ }
115
+ insertAfter (b, insertedExpr, &expr->value );
116
+ return ;
117
+ }
118
+ if (wasm::Break *expr = insertPosition->dynCast <wasm::Break>(); expr != nullptr && expr->condition == nullptr ) {
119
+ if (expr->value == nullptr ) {
120
+ *insertPositionPtr = b.makeBlock ({insertedExpr, insertPosition}, wasm::Type::unreachable);
121
+ return ;
122
+ }
123
+ insertAfter (b, insertedExpr, &expr->value );
124
+ return ;
125
+ }
126
+ } else {
127
+ wasm::Type const exprType = insertPosition->type ;
128
+ if (exprType == wasm::Type::none) {
93
129
*insertPositionPtr = b.makeBlock ({insertPosition, insertedExpr}, wasm::Type::none);
94
- } else {
95
- wasm::Index index = wasm::Builder::addVar (func_, type);
130
+ return ;
131
+ }
132
+ if (exprType != wasm::Type::unreachable) {
133
+ wasm::Index const tmpLocal = b.addVar (func_, exprType);
96
134
*insertPositionPtr = b.makeBlock (
97
135
{
98
- b.makeLocalSet (index, ifExpr ),
136
+ b.makeLocalSet (tmpLocal, insertPosition ),
99
137
insertedExpr,
100
- b.makeLocalGet (index, type ),
138
+ b.makeLocalGet (tmpLocal, exprType ),
101
139
},
102
- type);
140
+ exprType);
141
+ return ;
103
142
}
104
- return ;
105
143
}
106
144
__builtin_unreachable ();
107
145
}
@@ -113,10 +151,12 @@ void ExprInserter::insertAfter(wasm::Builder &b, wasm::Expression *insertedExpr,
113
151
#include < gtest/gtest.h>
114
152
115
153
#include " FindExpr.hpp"
154
+ #include " Matcher.hpp"
116
155
117
156
namespace warpo ::passes::ut {
118
157
119
- using wasm::Const, wasm::Block, wasm::Nop, wasm::Call, wasm::LocalGet, wasm::LocalSet, wasm::If, wasm::Return;
158
+ using wasm::Const, wasm::Block, wasm::Nop, wasm::Call, wasm::LocalGet, wasm::LocalSet, wasm::If, wasm::Return,
159
+ wasm::Loop, wasm::Break;
120
160
using wasm::Type;
121
161
122
162
TEST (ExprInserter, InsertBeforeNoOperand) {
@@ -173,6 +213,31 @@ TEST(ExprInserter, InsertBeforeCallWithOperands) {
173
213
ASSERT_TRUE (insertPos->cast <Call>()->operands [1 ]->cast <Block>()->list [2 ]->is <LocalGet>());
174
214
}
175
215
216
+ TEST (ExprInserter, InsertBeforeLocalSet) {
217
+ wasm::Module m{};
218
+ wasm::Builder b{m};
219
+ wasm::Expression *const insertPos = b.makeLocalSet (0 , b.makeConst (1 ));
220
+ std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction (" test" , wasm::Signature (), {}, insertPos);
221
+ ExprInserter inserter{f.get ()};
222
+
223
+ ASSERT_TRUE (inserter.canInsertBefore (insertPos));
224
+ inserter.insertBefore (b, b.makeNop (), findExprPointer (insertPos, f.get ()));
225
+
226
+ ASSERT_EQ (f->body , insertPos);
227
+
228
+ using namespace matcher ;
229
+ auto matcher = isLocalSet (local_set::v (isBlock (block::has (3 ), block::at (0 , isLocalSet (local_set::v (isConst ()))),
230
+ block::at (1 , isNop ()), block::at (2 , isLocalGet ()))));
231
+ EXPECT_TRUE (matcher (*f->body ));
232
+
233
+ ASSERT_EQ (f->body , insertPos);
234
+ ASSERT_TRUE (insertPos->cast <LocalSet>()->value ->is <Block>());
235
+ ASSERT_TRUE (insertPos->cast <LocalSet>()->value ->cast <Block>()->list [0 ]->is <LocalSet>());
236
+ ASSERT_TRUE (insertPos->cast <LocalSet>()->value ->cast <Block>()->list [0 ]->cast <LocalSet>()->value ->is <Const>());
237
+ ASSERT_TRUE (insertPos->cast <LocalSet>()->value ->cast <Block>()->list [1 ]->is <Nop>());
238
+ ASSERT_TRUE (insertPos->cast <LocalSet>()->value ->cast <Block>()->list [2 ]->is <LocalGet>());
239
+ }
240
+
176
241
TEST (ExprInserter, InsertAfterTypeNone) {
177
242
wasm::Module m{};
178
243
wasm::Builder b{m};
@@ -188,6 +253,38 @@ TEST(ExprInserter, InsertAfterTypeNone) {
188
253
ASSERT_TRUE (f->body ->cast <Block>()->list [1 ]->is <Nop>());
189
254
}
190
255
256
+ TEST (ExprInserter, InsertAfterLoopWithoutType) {
257
+ wasm::Module m{};
258
+ wasm::Builder b{m};
259
+ wasm::Expression *const insertPos = b.makeLoop (" " , b.makeNop (), Type::none);
260
+ std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction (" test" , wasm::Signature (), {}, insertPos);
261
+ ExprInserter inserter{f.get ()};
262
+
263
+ ASSERT_TRUE (inserter.canInsertAfter (insertPos));
264
+ inserter.insertAfter (b, b.makeNop (), findExprPointer (insertPos, f.get ()));
265
+
266
+ ASSERT_TRUE (f->body ->is <Block>());
267
+ ASSERT_TRUE (f->body ->cast <Block>()->list [0 ]->is <Loop>());
268
+ ASSERT_TRUE (f->body ->cast <Block>()->list [1 ]->is <Nop>());
269
+ }
270
+
271
+ TEST (ExprInserter, InsertAfterLoopWithType) {
272
+ wasm::Module m{};
273
+ wasm::Builder b{m};
274
+ wasm::Expression *const insertPos = b.makeLoop (" " , b.makeConst (1 ), Type::i32 );
275
+ std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction (" test" , wasm::Signature (), {}, insertPos);
276
+ ExprInserter inserter{f.get ()};
277
+
278
+ ASSERT_TRUE (inserter.canInsertAfter (insertPos));
279
+ inserter.insertAfter (b, b.makeNop (), findExprPointer (insertPos, f.get ()));
280
+
281
+ ASSERT_TRUE (f->body ->is <Block>());
282
+ ASSERT_TRUE (f->body ->cast <Block>()->list [0 ]->is <LocalSet>());
283
+ ASSERT_TRUE (f->body ->cast <Block>()->list [0 ]->cast <LocalSet>()->value ->is <Loop>());
284
+ ASSERT_TRUE (f->body ->cast <Block>()->list [1 ]->is <Nop>());
285
+ ASSERT_TRUE (f->body ->cast <Block>()->list [2 ]->is <LocalGet>());
286
+ }
287
+
191
288
TEST (ExprInserter, InsertAfterReturnWithoutValue) {
192
289
wasm::Module m{};
193
290
wasm::Builder b{m};
@@ -221,6 +318,39 @@ TEST(ExprInserter, InsertAfterReturnWithValue) {
221
318
ASSERT_TRUE (f->body ->cast <Return>()->value ->cast <Block>()->list [2 ]->is <LocalGet>());
222
319
}
223
320
321
+ TEST (ExprInserter, InsertAfterBrWithoutValue) {
322
+ wasm::Module m{};
323
+ wasm::Builder b{m};
324
+ wasm::Expression *const insertPos = b.makeBreak (" bb" );
325
+ std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction (" test" , wasm::Signature (), {}, insertPos);
326
+ ExprInserter inserter{f.get ()};
327
+
328
+ ASSERT_TRUE (inserter.canInsertAfter (insertPos));
329
+ inserter.insertAfter (b, b.makeNop (), findExprPointer (insertPos, f.get ()));
330
+
331
+ ASSERT_TRUE (f->body ->is <Block>());
332
+ ASSERT_TRUE (f->body ->cast <Block>()->list [0 ]->is <Nop>());
333
+ ASSERT_TRUE (f->body ->cast <Block>()->list [1 ]->is <Break>());
334
+ }
335
+
336
+ TEST (ExprInserter, InsertAfterBrWithValue) {
337
+ wasm::Module m{};
338
+ wasm::Builder b{m};
339
+ wasm::Expression *const insertPos = b.makeBreak (" bb" , b.makeConst (1 ));
340
+ std::unique_ptr<wasm::Function> f = wasm::Builder::makeFunction (" test" , wasm::Signature (), {}, insertPos);
341
+ ExprInserter inserter{f.get ()};
342
+
343
+ ASSERT_TRUE (inserter.canInsertAfter (insertPos));
344
+ inserter.insertAfter (b, b.makeNop (), findExprPointer (insertPos, f.get ()));
345
+
346
+ ASSERT_EQ (f->body , insertPos);
347
+ ASSERT_TRUE (f->body ->cast <Break>()->value ->is <Block>());
348
+ ASSERT_TRUE (f->body ->cast <Break>()->value ->cast <Block>()->list [0 ]->is <LocalSet>());
349
+ ASSERT_TRUE (f->body ->cast <Break>()->value ->cast <Block>()->list [0 ]->cast <LocalSet>()->value ->is <Const>());
350
+ ASSERT_TRUE (f->body ->cast <Break>()->value ->cast <Block>()->list [1 ]->is <Nop>());
351
+ ASSERT_TRUE (f->body ->cast <Break>()->value ->cast <Block>()->list [2 ]->is <LocalGet>());
352
+ }
353
+
224
354
} // namespace warpo::passes::ut
225
355
226
356
#endif
0 commit comments