Skip to content

Commit 8d37c96

Browse files
committed
wasm3: validate function type
1 parent 7d4c8f1 commit 8d37c96

File tree

1 file changed

+49
-7
lines changed

1 file changed

+49
-7
lines changed

test/utils/wasm3_engine.cpp

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,33 @@ class Wasm3Engine final : public WasmEngine
3838

3939
namespace
4040
{
41+
M3ValueType translate_valtype(char input)
42+
{
43+
if (input == 'i')
44+
return M3ValueType::c_m3Type_i32;
45+
else if (input == 'I')
46+
return M3ValueType::c_m3Type_i64;
47+
else
48+
throw std::runtime_error{"invalid type"};
49+
}
50+
51+
std::pair<std::vector<M3ValueType>, std::vector<M3ValueType>> translate_signature(
52+
std::string_view signature)
53+
{
54+
const auto delimiter_pos = signature.find(':');
55+
assert(delimiter_pos != std::string_view::npos);
56+
const auto inputs = signature.substr(0, delimiter_pos);
57+
const auto outputs = signature.substr(delimiter_pos + 1);
58+
59+
std::vector<M3ValueType> input_types;
60+
std::vector<M3ValueType> output_types;
61+
std::transform(
62+
std::begin(inputs), std::end(inputs), std::back_inserter(input_types), translate_valtype);
63+
std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_types),
64+
translate_valtype);
65+
return {std::move(input_types), std::move(output_types)};
66+
}
67+
4168
const void* env_adler32(
4269
IM3Runtime /*runtime*/, IM3ImportContext /*context*/, uint64_t* stack, void* mem)
4370
{
@@ -115,14 +142,29 @@ fizzy::bytes_view Wasm3Engine::get_memory() const
115142
}
116143

117144
std::optional<WasmEngine::FuncRef> Wasm3Engine::find_function(
118-
std::string_view name, std::string_view) const
145+
std::string_view name, std::string_view signature) const
119146
{
120147
IM3Function function;
121-
if (m3_FindFunction(&function, m_runtime, name.data()) == m3Err_none)
122-
// TODO: validate input/output types
123-
// (m3_GetArgCount/m3_GetArgType/m3_GetRetCount/m3_GetRetType)
124-
return reinterpret_cast<WasmEngine::FuncRef>(function);
125-
return std::nullopt;
148+
if (m3_FindFunction(&function, m_runtime, name.data()) != m3Err_none)
149+
return std::nullopt;
150+
151+
std::vector<M3ValueType> inputs;
152+
std::vector<M3ValueType> outputs;
153+
std::tie(inputs, outputs) = translate_signature(signature);
154+
155+
if (inputs.size() != m3_GetArgCount(function))
156+
return std::nullopt;
157+
for (unsigned i = 0; i < m3_GetArgCount(function); i++)
158+
if (inputs[i] != m3_GetArgType(function, i))
159+
return std::nullopt;
160+
161+
if (outputs.size() != m3_GetRetCount(function))
162+
return std::nullopt;
163+
for (unsigned i = 0; i < m3_GetRetCount(function); i++)
164+
if (outputs[i] != m3_GetRetType(function, i))
165+
return std::nullopt;
166+
167+
return reinterpret_cast<WasmEngine::FuncRef>(function);
126168
}
127169

128170
WasmEngine::Result Wasm3Engine::execute(
@@ -137,7 +179,7 @@ WasmEngine::Result Wasm3Engine::execute(
137179

138180
// This ensures input count/type matches. For the return value we assume find_function did the
139181
// validation.
140-
if (m3_Call(function, static_cast<uint32_t>(args.size()), argPtrs.data()) == m3Err_none)
182+
if (m3_Call(function, static_cast<uint32_t>(argPtrs.size()), argPtrs.data()) == m3Err_none)
141183
{
142184
if (m3_GetRetCount(function) == 0)
143185
return {false, std::nullopt};

0 commit comments

Comments
 (0)