|
1 | 1 | #include "BasicBlockWalker.hpp"
|
| 2 | +#include <algorithm> |
2 | 3 | #include <cfg/cfg-traversal.h>
|
| 4 | +#include <iostream> |
| 5 | +#include <memory> |
| 6 | +#include <set> |
| 7 | +#include <utility> |
| 8 | +#include <wasm.h> |
3 | 9 | #include "ir/branch-utils.h"
|
4 | 10 |
|
5 | 11 | namespace wasmInstrumentation {
|
@@ -27,30 +33,72 @@ void BasicBlockWalker::visitExpression(wasm::Expression *curr) noexcept {
|
27 | 33 | }
|
28 | 34 | }
|
29 | 35 |
|
30 |
| -void BasicBlockWalker::unlinkEmptyBlock() noexcept { |
31 |
| - const auto lambda = [](std::unique_ptr<BasicBlock> &block) { |
32 |
| - if (block->contents.exprs.empty() && block->out.size() == 1) { |
33 |
| - const auto outBlock = block->out[0]; |
34 |
| - outBlock->in.erase(std::find(outBlock->in.begin(), outBlock->in.end(), block.get())); |
35 |
| - for (auto &inBlock : block->in) { |
36 |
| - inBlock->out.erase(std::find(inBlock->out.begin(), inBlock->out.end(), block.get())); |
37 |
| - inBlock->out.push_back(outBlock); |
38 |
| - outBlock->in.push_back(inBlock); |
| 36 | +static bool |
| 37 | +isBasicBlockContainUnreachable(BasicBlockWalker::BasicBlock &block, |
| 38 | + std::set<BasicBlockWalker::BasicBlock *> unreachableBlocks) { |
| 39 | + return (!block.contents.exprs.empty() && |
| 40 | + std::any_of(block.contents.exprs.begin(), block.contents.exprs.end(), |
| 41 | + [](wasm::Expression *expr) { |
| 42 | + return expr->is<wasm::Unreachable>(); |
| 43 | + })) || |
| 44 | + (!block.in.empty() && |
| 45 | + std::all_of(block.in.begin(), block.in.end(), |
| 46 | + [&unreachableBlocks](BasicBlockWalker::BasicBlock *inBlock) { |
| 47 | + return unreachableBlocks.find(inBlock) != unreachableBlocks.end(); |
| 48 | + })); |
| 49 | +}; |
| 50 | + |
| 51 | +static void removeDuplicates(std::vector<BasicBlockWalker::BasicBlock *> &list) { |
| 52 | + std::sort(list.begin(), list.end()); |
| 53 | + list.erase(std::unique(list.begin(), list.end()), list.end()); |
| 54 | +} |
| 55 | + |
| 56 | +void BasicBlockWalker::cleanBlock() noexcept { |
| 57 | + bool isModified = true; |
| 58 | + std::set<BasicBlock *> unreachableBlocks{}; |
| 59 | + while (isModified) { |
| 60 | + isModified = false; |
| 61 | + for (auto &block : basicBlocks) { |
| 62 | + if (isBasicBlockContainUnreachable(*block, unreachableBlocks)) { |
| 63 | + isModified |= unreachableBlocks.insert(block.get()).second; |
39 | 64 | }
|
40 |
| - block->in.clear(); |
41 |
| - block->out.clear(); |
42 |
| - return true; |
43 | 65 | }
|
44 |
| - return false; |
45 |
| - }; |
46 |
| - basicBlocks.erase(std::remove_if(basicBlocks.begin(), basicBlocks.end(), lambda), |
47 |
| - basicBlocks.end()); |
| 66 | + } |
| 67 | + std::set<BasicBlock *> emptyBlocks{}; |
| 68 | + for (auto &block : basicBlocks) { |
| 69 | + if (block->contents.exprs.empty() && block->out.size() == 1) { |
| 70 | + emptyBlocks.insert(block.get()); |
| 71 | + } |
| 72 | + } |
| 73 | + |
| 74 | + std::set<BasicBlock *> targetCleanBlocks{}; |
| 75 | + targetCleanBlocks.insert(unreachableBlocks.begin(), unreachableBlocks.end()); |
| 76 | + targetCleanBlocks.insert(emptyBlocks.begin(), emptyBlocks.end()); |
| 77 | + |
| 78 | + for (auto &block : targetCleanBlocks) { |
| 79 | + for (auto &outBlock : block->out) { |
| 80 | + outBlock->in.erase(std::find(outBlock->in.begin(), outBlock->in.end(), block)); |
| 81 | + outBlock->in.insert(outBlock->in.end(), block->in.begin(), block->in.end()); |
| 82 | + removeDuplicates(outBlock->in); |
| 83 | + } |
| 84 | + for (auto &inBlock : block->in) { |
| 85 | + inBlock->out.erase(std::find(inBlock->out.begin(), inBlock->out.end(), block)); |
| 86 | + inBlock->out.insert(inBlock->out.end(), block->out.begin(), block->out.end()); |
| 87 | + removeDuplicates(inBlock->out); |
| 88 | + } |
| 89 | + block->in.clear(); |
| 90 | + block->out.clear(); |
| 91 | + basicBlocks.erase(std::find_if(basicBlocks.begin(), basicBlocks.end(), |
| 92 | + [&block](std::unique_ptr<BasicBlock> const &b) -> bool { |
| 93 | + return b.get() == block; |
| 94 | + })); |
| 95 | + } |
48 | 96 | }
|
49 | 97 |
|
50 | 98 | void BasicBlockWalker::doWalkFunction(wasm::Function *const func) noexcept {
|
51 | 99 | wasm::CFGWalker<BasicBlockWalker, wasm::UnifiedExpressionVisitor<BasicBlockWalker>,
|
52 | 100 | BasicBlockInfo>::doWalkFunction(func);
|
53 |
| - unlinkEmptyBlock(); |
| 101 | + cleanBlock(); |
54 | 102 | // LCOV_EXCL_START
|
55 | 103 | if (basicBlocks.size() > UINT32_MAX) {
|
56 | 104 | std::cerr << "Error: BasicBlocks length exceeds UINT32_MAX\n";
|
|
0 commit comments