Skip to content

Commit 26ed43d

Browse files
committed
feat(InsertTracePoint): support insert trace point (#128)
1 parent bfe75a0 commit 26ed43d

File tree

7 files changed

+421
-7
lines changed

7 files changed

+421
-7
lines changed

include/warpo/support/Opt.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ namespace warpo::cli {
88

99
enum class Category : uint32_t {
1010
None = 0,
11-
Frontend = 1 << 1,
12-
Optimization = 1 << 2,
13-
OnlyForTest = 1 << 3,
11+
OnlyForTest = 1 << 1,
12+
Frontend = 1 << 2,
13+
Optimization = 1 << 3,
14+
Transformation = 1 << 4,
1415
All = static_cast<uint32_t>(-1),
1516
};
1617

include/warpo/support/StatisticsKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#endif
77

88
PERF_ITEM_KIND_TOP(CompilationHIR)
9+
PERF_ITEM_KIND_TOP(Instrument)
910
PERF_ITEM_KIND_TOP(Optimization)
1011
PERF_ITEM_KIND_TOP(Validation)
1112
PERF_ITEM_KIND_TOP(Lowering)

passes/InsertTracePoint.cpp

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
// Copyright (C) 2025 Bayerische Motoren Werke Aktiengesellschaft (BMW AG)
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
#include <cassert>
18+
#include <cstddef>
19+
#include <fmt/base.h>
20+
#include <fmt/format.h>
21+
#include <fstream>
22+
#include <iostream>
23+
#include <memory>
24+
#include <string>
25+
26+
#include "InsertTracePoint.hpp"
27+
#include "literal.h"
28+
#include "pass.h"
29+
#include "warpo/support/FileSystem.hpp"
30+
#include "warpo/support/IncMap.hpp"
31+
#include "warpo/support/Opt.hpp"
32+
#include "wasm-builder.h"
33+
#include "wasm-type.h"
34+
#include "wasm.h"
35+
36+
#define PASS_NAME "Tracing"
37+
#define DEBUG_PREFIX "[Tracing] "
38+
39+
constexpr const char *tracePointFunctionName = "~/lib/trace_point";
40+
constexpr size_t traceOffset = 0x1'000000;
41+
42+
namespace warpo::passes {
43+
namespace {
44+
45+
struct FunctionIndexMap : private IncMap<wasm::Function *> {
46+
using IncMap::contains;
47+
size_t getIndex(wasm::Function *func) const { return IncMap::getIndex(func) + traceOffset; }
48+
49+
static FunctionIndexMap create(wasm::Module *m) {
50+
FunctionIndexMap functionIndexes{};
51+
for (std::unique_ptr<wasm::Function> const &func : m->functions) {
52+
if (func->name.startsWith("~lib"))
53+
continue;
54+
functionIndexes.insert(func.get());
55+
}
56+
return functionIndexes;
57+
}
58+
template <class Fn> void forEach(Fn &&fn) const {
59+
for (auto const &[func, index] : *this) {
60+
fn(func, index + traceOffset);
61+
}
62+
}
63+
};
64+
65+
struct TracingInserter : public wasm::Pass {
66+
FunctionIndexMap const &functionIndexes_;
67+
TracingInserter(FunctionIndexMap const &functionIndexes) noexcept : functionIndexes_(functionIndexes) {}
68+
bool isFunctionParallel() override { return true; }
69+
bool modifiesBinaryenIR() override { return true; }
70+
std::unique_ptr<wasm::Pass> create() override { return std::make_unique<TracingInserter>(functionIndexes_); }
71+
72+
void runOnFunction(wasm::Module *m, wasm::Function *func) override {
73+
struct WrapperCall : public wasm::PostWalker<WrapperCall> {
74+
FunctionIndexMap const &functionIndexes_;
75+
explicit WrapperCall(FunctionIndexMap const &functionIndexes) : functionIndexes_(functionIndexes) {}
76+
void visitCall(wasm::Call *expr) {
77+
wasm::Function *func = getModule()->getFunction(expr->target);
78+
if (!func->imported())
79+
return;
80+
assert(functionIndexes_.contains(func));
81+
int32_t index = static_cast<int32_t>(functionIndexes_.getIndex(func));
82+
wasm::Builder b{*getModule()};
83+
wasm::Type const resultType = func->getResults();
84+
if (func->getResults() == wasm::Type::none) {
85+
replaceCurrent(b.makeBlock(
86+
{
87+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(index))}, wasm::Type::none),
88+
expr,
89+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(-index))}, wasm::Type::none),
90+
},
91+
wasm::Type::none));
92+
} else {
93+
wasm::Index const localIdx = wasm::Builder::addVar(getFunction(), resultType);
94+
replaceCurrent(b.makeBlock(
95+
{
96+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(index))}, wasm::Type::none),
97+
b.makeLocalSet(localIdx, expr),
98+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(-index))}, wasm::Type::none),
99+
b.makeLocalGet(localIdx, resultType),
100+
},
101+
resultType));
102+
}
103+
}
104+
};
105+
106+
WrapperCall wrapperCall{functionIndexes_};
107+
wrapperCall.walkFunctionInModule(func, m);
108+
109+
if (!functionIndexes_.contains(func))
110+
return;
111+
112+
const int32_t currentIndex = static_cast<int32_t>(functionIndexes_.getIndex(func));
113+
assert(currentIndex > 0);
114+
115+
struct ReturnWithResultReplacer : public wasm::PostWalker<ReturnWithResultReplacer> {
116+
wasm::Index const scratchReturnValueLocalIndex_;
117+
int32_t const currentIndex_;
118+
wasm::Type const &resultType_;
119+
explicit ReturnWithResultReplacer(wasm::Index const scratchReturnValueLocalIndex, int32_t const currentIndex,
120+
wasm::Type const &returnType)
121+
: scratchReturnValueLocalIndex_(scratchReturnValueLocalIndex), currentIndex_(currentIndex),
122+
resultType_(returnType) {}
123+
void visitReturn(wasm::Return *expr) {
124+
wasm::Builder b{*getModule()};
125+
assert(expr->value);
126+
replaceCurrent(b.makeBlock(
127+
{
128+
b.makeLocalSet(scratchReturnValueLocalIndex_, expr->value),
129+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(-currentIndex_))}, wasm::Type::none),
130+
expr,
131+
},
132+
wasm::Type::unreachable));
133+
expr->value = b.makeLocalGet(scratchReturnValueLocalIndex_, resultType_);
134+
}
135+
};
136+
struct ReturnWithoutResultReplacer : public wasm::PostWalker<ReturnWithoutResultReplacer> {
137+
int32_t const currentIndex_;
138+
explicit ReturnWithoutResultReplacer(int32_t const currentIndex) : currentIndex_(currentIndex) {}
139+
void visitReturn(wasm::Return *expr) {
140+
wasm::Builder b{*getModule()};
141+
replaceCurrent(b.makeBlock(
142+
{
143+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal{-currentIndex_})}, wasm::Type::none),
144+
expr,
145+
},
146+
wasm::Type::unreachable));
147+
}
148+
};
149+
150+
wasm::Builder b{*m};
151+
wasm::Type const resultType = func->getResults();
152+
if (resultType == wasm::Type::none) {
153+
func->body = b.makeBlock(
154+
{
155+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(currentIndex))}, wasm::Type::none, false),
156+
func->body,
157+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(-currentIndex))}, wasm::Type::none, false),
158+
},
159+
func->getResults());
160+
ReturnWithoutResultReplacer returnReplacer{currentIndex};
161+
returnReplacer.walkFunctionInModule(func, m);
162+
} else {
163+
wasm::Index const scratchReturnValueLocalIndex = wasm::Builder::addVar(func, resultType);
164+
func->body = b.makeBlock(
165+
{
166+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(currentIndex))}, wasm::Type::none, false),
167+
b.makeLocalSet(scratchReturnValueLocalIndex, func->body),
168+
b.makeCall(tracePointFunctionName, {b.makeConst(wasm::Literal(-currentIndex))}, wasm::Type::none, false),
169+
b.makeLocalGet(scratchReturnValueLocalIndex, resultType),
170+
},
171+
func->getResults());
172+
ReturnWithResultReplacer returnReplacer{scratchReturnValueLocalIndex, currentIndex, resultType};
173+
returnReplacer.walkFunctionInModule(func, m);
174+
}
175+
}
176+
};
177+
178+
struct TracePointInserter : public wasm::Pass {
179+
std::string const tracePointMappingFile_;
180+
explicit TracePointInserter(std::string const &tracePointMappingFile)
181+
: tracePointMappingFile_(tracePointMappingFile) {}
182+
void run(wasm::Module *m) override {
183+
wasm::Builder b{*m};
184+
if (m->getFunctionOrNull(tracePointFunctionName) == nullptr) {
185+
std::unique_ptr<wasm::Function> func = wasm::Builder::makeFunction(
186+
tracePointFunctionName, wasm::Signature{wasm::Type::i32, wasm::Type::none}, {}, nullptr);
187+
func->module = "builtin";
188+
func->base = "tracePoint";
189+
m->addFunction(std::move(func));
190+
}
191+
FunctionIndexMap const functionIndexes = FunctionIndexMap::create(m);
192+
193+
wasm::PassRunner runner{getPassRunner()};
194+
runner.add(std::unique_ptr<wasm::Pass>(new TracingInserter(functionIndexes)));
195+
runner.run();
196+
197+
if (!tracePointMappingFile_.empty()) {
198+
std::ofstream of{tracePointMappingFile_, std::ios::out};
199+
ensureFileDirectory(tracePointMappingFile_);
200+
functionIndexes.forEach([&](wasm::Function *func, size_t index) { of << index << " " << func->name << "\n"; });
201+
}
202+
}
203+
};
204+
205+
struct TracePointInserterPlaceHolder : public wasm::Pass {
206+
bool modifiesBinaryenIR() override { return false; }
207+
bool isFunctionParallel() override { return true; }
208+
std::unique_ptr<wasm::Pass> create() override {
209+
return std::unique_ptr<wasm::Pass>{new TracePointInserterPlaceHolder()};
210+
}
211+
void run(wasm::Module *m) override {}
212+
void runOnFunction(wasm::Module *m, wasm::Function *func) override {}
213+
};
214+
215+
} // namespace
216+
} // namespace warpo::passes
217+
218+
namespace warpo {
219+
220+
static cli::Opt<std::string> tracePointMappingOption{
221+
cli::Category::Transformation,
222+
"--trace-point-mapping-file",
223+
[](argparse::Argument &arg) -> void { arg.help("File to write the trace output to."); },
224+
};
225+
226+
wasm::Pass *passes::createInsertTracePointPass() {
227+
std::string const &tracePointMappingFile = tracePointMappingOption.get();
228+
if (tracePointMappingFile.empty())
229+
return new TracePointInserterPlaceHolder();
230+
return new TracePointInserter(tracePointMappingFile);
231+
}
232+
233+
} // namespace warpo
234+
235+
#ifdef WARPO_ENABLE_UNIT_TESTS
236+
237+
#include <gtest/gtest.h>
238+
239+
#include "Runner.hpp"
240+
#include "helper/Matcher.hpp"
241+
#include "pass.h"
242+
243+
namespace warpo::passes::ut {
244+
namespace {
245+
246+
TEST(TracePointInserterTest, WithoutResult) {
247+
auto m = loadWat(R"(
248+
(module
249+
(func $empty)
250+
(func $fn_without_result
251+
call $empty
252+
call $empty
253+
call $empty
254+
call $empty
255+
)
256+
)
257+
)");
258+
259+
wasm::Function *fn = m->getFunction("fn_without_result");
260+
261+
wasm::PassRunner runner{m.get()};
262+
runner.add(std::unique_ptr<wasm::Pass>{new TracePointInserter(tracePointMappingOption.get())});
263+
runner.run();
264+
265+
using namespace matcher;
266+
auto const match = isBlock(block::list(allOf({
267+
has(3),
268+
at(0, isCall(call::callee(tracePointFunctionName), call::operands(allOf({at(0, isConst())})))),
269+
at(1, isBlock(block::list(allOf({
270+
has(4),
271+
at(0, isCall(call::callee("empty"))),
272+
at(1, isCall(call::callee("empty"))),
273+
at(2, isCall(call::callee("empty"))),
274+
at(3, isCall(call::callee("empty"))),
275+
})))),
276+
at(2, isCall(call::callee(tracePointFunctionName), call::operands(allOf({at(0, isConst())})))),
277+
})));
278+
279+
EXPECT_TRUE(match(*fn->body));
280+
}
281+
282+
TEST(TracePointInserterTest, WithResult) {
283+
auto m = loadWat(R"(
284+
(module
285+
(func $empty)
286+
(func $fn_with_result (result i32)
287+
call $empty
288+
call $empty
289+
call $empty
290+
i32.const 10
291+
)
292+
)
293+
)");
294+
295+
wasm::Function *fn = m->getFunction("fn_with_result");
296+
297+
wasm::PassRunner runner{m.get()};
298+
runner.add(std::unique_ptr<wasm::Pass>{new TracePointInserter(tracePointMappingOption.get())});
299+
runner.run();
300+
301+
using namespace matcher;
302+
auto const match = isBlock(block::list(allOf({
303+
has(4),
304+
at(0, isCall(call::callee(tracePointFunctionName), call::operands(allOf({at(0, isConst())})))),
305+
at(1, isLocalSet(local_set::v(isBlock(block::list(allOf({
306+
has(4),
307+
at(0, isCall(call::callee("empty"))),
308+
at(1, isCall(call::callee("empty"))),
309+
at(2, isCall(call::callee("empty"))),
310+
at(3, isConst()),
311+
})))))),
312+
at(2, isCall(call::callee(tracePointFunctionName), call::operands(allOf({at(0, isConst())})))),
313+
at(3, isLocalGet()),
314+
})));
315+
316+
EXPECT_TRUE(match(*fn->body));
317+
}
318+
319+
TEST(TracePointInserterTest, CallImport) {
320+
auto m = loadWat(R"(
321+
(module
322+
(import "env" "empty" (func $empty))
323+
(func $fn_call_import
324+
call $empty
325+
)
326+
)
327+
)");
328+
329+
wasm::Function *fn = m->getFunction("fn_call_import");
330+
331+
wasm::PassRunner runner{m.get()};
332+
runner.add(std::unique_ptr<wasm::Pass>{new TracePointInserter(tracePointMappingOption.get())});
333+
runner.run();
334+
335+
using namespace matcher;
336+
auto const match = isBlock(block::list(allOf({
337+
has(3),
338+
at(0, isCall(call::callee(tracePointFunctionName), call::operands(at(0, isConst())))),
339+
at(1, isBlock(block::list(allOf({
340+
has(3),
341+
at(0, isCall(call::callee(tracePointFunctionName))),
342+
at(1, isCall(call::callee("empty"))),
343+
at(2, isCall(call::callee(tracePointFunctionName))),
344+
})))),
345+
at(2, isCall(call::callee(tracePointFunctionName), call::operands(at(0, isConst())))),
346+
})));
347+
348+
EXPECT_TRUE(match(*fn->body));
349+
}
350+
351+
} // namespace
352+
} // namespace warpo::passes::ut
353+
354+
#endif

0 commit comments

Comments
 (0)