Skip to content

Commit 6be2588

Browse files
committed
pr: Add option to use interactive server in parser conversion pass.
1 parent ca110c2 commit 6be2588

File tree

4 files changed

+242
-19
lines changed

4 files changed

+242
-19
lines changed

include/vast/Conversion/Parser/Passes.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "vast/Util/Warnings.hpp"
66

7+
#include "vast/server/server.hpp"
8+
79
VAST_RELAX_WARNINGS
810
#include <mlir/Pass/Pass.h>
911
#include <mlir/Pass/PassManager.h>

include/vast/Conversion/Parser/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1010
let options = [
1111
Option< "config", "config", "std::string", "",
1212
"Configuration file for parser transformation."
13+
>,
14+
Option< "socket", "socket", "std::string", "",
15+
"Unix socket path to use for server"
1316
>
1417
];
1518

include/vast/server/types.hpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,34 @@
44

55
#include <concepts>
66
#include <cstdint>
7+
#include <optional>
78
#include <string>
89
#include <variant>
910

1011
#include <nlohmann/json.hpp>
1112

13+
namespace nlohmann {
14+
template< typename T >
15+
struct adl_serializer< std::optional< T > >
16+
{
17+
static void to_json(json &j, const std::optional< T > &opt) {
18+
if (!opt.has_value()) {
19+
j = nullptr;
20+
} else {
21+
j = *opt;
22+
}
23+
}
24+
25+
static void from_json(const json &j, std::optional< T > &opt) {
26+
if (j.is_null()) {
27+
opt = std::nullopt;
28+
} else {
29+
opt = j.template get< T >();
30+
}
31+
}
32+
};
33+
} // namespace nlohmann
34+
1235
namespace vast::server {
1336
template< typename T >
1437
concept json_convertible = requires(T obj, nlohmann::json &json) {
@@ -88,14 +111,33 @@ namespace vast::server {
88111
template< request_like request >
89112
using result_type = std::variant< typename request::response_type, error< request > >;
90113

114+
struct position
115+
{
116+
unsigned int line;
117+
unsigned int character;
118+
119+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
120+
};
121+
122+
struct range
123+
{
124+
position start;
125+
position end;
126+
127+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
128+
};
129+
91130
struct input_request
92131
{
93132
static constexpr const char *method = "input";
94133
static constexpr bool is_notification = false;
95134

96135
nlohmann::json type;
97136
std::string text;
98-
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
137+
std::optional< std::string > filePath;
138+
std::optional< range > range;
139+
140+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)
99141

100142
struct response_type
101143
{

lib/vast/Conversion/Parser/ToParser.cpp

Lines changed: 194 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929

3030
#include "vast/Conversion/Parser/Config.hpp"
3131

32+
#include "vast/server/server.hpp"
33+
#include "vast/server/types.hpp"
34+
3235
#include <ranges>
3336

3437
namespace vast::conv {
@@ -75,6 +78,154 @@ namespace vast::conv {
7578

7679
using function_models = llvm::StringMap< function_model >;
7780

81+
struct location
82+
{
83+
std::string filePath;
84+
server::range range;
85+
};
86+
87+
location get_location(file_loc_t loc) {
88+
return {
89+
.filePath = loc.getFilename().str(),
90+
.range = {
91+
.start = { loc.getLine(), loc.getColumn(), },
92+
.end = { loc.getLine(), loc.getColumn(), },
93+
},
94+
};
95+
}
96+
97+
location get_location(name_loc_t loc) {
98+
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
99+
}
100+
101+
std::optional< location > get_location(loc_t loc) {
102+
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+
return get_location(file_loc);
104+
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+
return get_location(name_loc);
106+
}
107+
108+
return std::nullopt;
109+
}
110+
111+
pr::data_type parse_type_name(const std::string &name) {
112+
if (name == "data") {
113+
return pr::data_type::data;
114+
} else if (name == "nodata") {
115+
return pr::data_type::nodata;
116+
} else {
117+
return pr::data_type::maybedata;
118+
}
119+
}
120+
121+
function_category
122+
ask_user_for_category(vast::server::server_base &server, core::function_op_interface op) {
123+
auto loc = op.getLoc();
124+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
125+
VAST_ASSERT(sym);
126+
auto name = sym.getSymbolName().str();
127+
128+
vast::server::input_request req{
129+
.type = {"nonparser", "sink", "source", "parser",},
130+
.text = "Please choose category for function `" + name + '`',
131+
.filePath = std::nullopt,
132+
.range = std::nullopt,
133+
};
134+
135+
if (auto req_loc = get_location(loc)) {
136+
req.filePath = req_loc->filePath;
137+
req.range = req_loc->range;
138+
}
139+
140+
auto response = server.send_request(req);
141+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
142+
{
143+
if (result->value == "nonparser") {
144+
return function_category::nonparser;
145+
} else if (result->value == "sink") {
146+
return function_category::sink;
147+
} else if (result->value == "source") {
148+
return function_category::source;
149+
} else if (result->value == "parser") {
150+
return function_category::parser;
151+
}
152+
}
153+
return function_category::nonparser;
154+
}
155+
156+
pr::data_type ask_user_for_return_type(
157+
vast::server::server_base &server, core::function_op_interface op
158+
) {
159+
auto loc = op.getLoc();
160+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
161+
VAST_ASSERT(sym);
162+
auto name = sym.getSymbolName().str();
163+
164+
vast::server::input_request req{
165+
.type = { "maybedata", "nodata", "data" },
166+
.text = "Please choose return type for function `" + name + '`',
167+
.filePath = std::nullopt,
168+
.range = std::nullopt,
169+
};
170+
171+
if (auto req_loc = get_location(loc)) {
172+
req.filePath = req_loc->filePath;
173+
req.range = req_loc->range;
174+
}
175+
176+
auto response = server.send_request(req);
177+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
178+
{
179+
return parse_type_name(result->value);
180+
}
181+
return pr::data_type::maybedata;
182+
}
183+
184+
pr::data_type ask_user_for_argument_type(
185+
vast::server::server_base &server, core::function_op_interface op, unsigned int idx
186+
) {
187+
auto num_body_args = op.getFunctionBody().getNumArguments();
188+
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
189+
VAST_ASSERT(sym);
190+
auto name = sym.getSymbolName().str();
191+
192+
vast::server::input_request req{
193+
.type = { "maybedata", "nodata", "data" },
194+
.text = "Please choose a type for argument " + std::to_string(idx)
195+
+ " of function `" + name + '`',
196+
.filePath = std::nullopt,
197+
.range = std::nullopt,
198+
};
199+
200+
if (idx < num_body_args) {
201+
auto arg = op.getArgument(idx);
202+
auto loc = arg.getLoc();
203+
if (auto req_loc = get_location(loc)) {
204+
req.filePath = req_loc->filePath;
205+
req.range = req_loc->range;
206+
}
207+
}
208+
209+
auto response = server.send_request(req);
210+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
211+
{
212+
return parse_type_name(result->value);
213+
}
214+
return pr::data_type::maybedata;
215+
}
216+
217+
function_model ask_user_for_function_model(
218+
vast::server::server_base &server, core::function_op_interface op
219+
) {
220+
function_model model;
221+
model.return_type = ask_user_for_return_type(server, op);
222+
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
223+
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
224+
}
225+
model.category = ask_user_for_category(server, op);
226+
return model;
227+
}
228+
78229
} // namespace vast::conv
79230

80231
LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +281,28 @@ namespace vast::conv {
130281
using base = base_conversion_config;
131282

132283
parser_conversion_config(
133-
rewrite_pattern_set patterns, conversion_target target,
134-
const function_models &models
284+
rewrite_pattern_set patterns, conversion_target target, function_models &models,
285+
vast::server::server_base *server
135286
)
136-
: base(std::move(patterns), std::move(target)), models(models)
137-
{}
287+
: base(std::move(patterns), std::move(target)), models(models), server(server) {}
138288

139289
template< typename pattern >
140290
void add_pattern() {
141291
auto ctx = patterns.getContext();
142292
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143293
patterns.template add< pattern >(ctx);
144-
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145-
patterns.template add< pattern >(ctx, models);
294+
} else if constexpr (std::is_constructible_v<
295+
pattern, mcontext_t *, function_models &,
296+
vast::server::server_base * >)
297+
{
298+
patterns.template add< pattern >(ctx, models, server);
146299
} else {
147300
static_assert(false, "pattern does not have a valid constructor");
148301
}
149302
}
150303

151-
const function_models &models;
304+
function_models &models;
305+
vast::server::server_base *server;
152306
};
153307

154308
using signature_conversion_t = mlir::TypeConverter::SignatureConversion;
@@ -308,26 +462,36 @@ namespace vast::conv {
308462
{
309463
using base = mlir::OpConversionPattern< op_t >;
310464

311-
parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
312-
: base(mctx), models(models)
313-
{}
465+
parser_conversion_pattern_base(
466+
mcontext_t *mctx, function_models &models, vast::server::server_base *server
467+
)
468+
: base(mctx), models(models), server(server) {}
314469

315-
static std::optional< function_model >
316-
get_model(const function_models &models, core::function_op_interface op) {
470+
static std::optional< function_model > get_model(
471+
function_models &models, core::function_op_interface op,
472+
vast::server::server_base *server
473+
) {
317474
auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation());
318475
VAST_ASSERT(sym);
319476
if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
320477
return kv->second;
321478
}
322479

480+
if (server) {
481+
auto model = ask_user_for_function_model(*server, op);
482+
models[sym.getSymbolName()] = model;
483+
return model;
484+
}
485+
323486
return std::nullopt;
324487
}
325488

326489
std::optional< function_model > get_model(core::function_op_interface op) const {
327-
return get_model(models, op);
490+
return get_model(models, op, server);
328491
}
329492

330-
const function_models &models;
493+
function_models &models;
494+
vast::server::server_base *server;
331495
};
332496

333497
//
@@ -606,11 +770,12 @@ namespace vast::conv {
606770
return mlir::success();
607771
}
608772

609-
static void legalize(parser_conversion_config &cfg) {
773+
static void
774+
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
610775
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
611-
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
776+
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
612777
auto tc = function_type_converter(
613-
*op.getContext(), get_model(models, op)
778+
*op.getContext(), get_model(cfg.models, op, server)
614779
);
615780
return tc.isSignatureLegal(op.getFunctionType());
616781
});
@@ -870,6 +1035,9 @@ namespace vast::conv {
8701035
{
8711036
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
8721037

1038+
struct server_handler
1039+
{};
1040+
8731041
static conversion_target create_conversion_target(mcontext_t &mctx) {
8741042
return conversion_target(mctx);
8751043
}
@@ -884,6 +1052,12 @@ namespace vast::conv {
8841052
if (!config.empty()) {
8851053
load_and_parse(config);
8861054
}
1055+
1056+
if (!socket.empty()) {
1057+
server = std::make_shared< vast::server::server< server_handler > >(
1058+
vast::server::sock_adapter::create_unix_socket(socket)
1059+
);
1060+
}
8871061
}
8881062

8891063
void load_and_parse(string_ref config) {
@@ -910,10 +1084,12 @@ namespace vast::conv {
9101084

9111085
parser_conversion_config make_config() {
9121086
auto &ctx = getContext();
913-
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
1087+
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
1088+
server.get() };
9141089
}
9151090

9161091
function_models models;
1092+
std::shared_ptr< vast::server::server< server_handler > > server;
9171093
};
9181094

9191095
} // namespace vast::conv

0 commit comments

Comments
 (0)