Skip to content

Commit 52492b7

Browse files
committed
pr: Add option to use interactive server in parser conversion pass.
1 parent 61a5d1b commit 52492b7

File tree

4 files changed

+239
-19
lines changed

4 files changed

+239
-19
lines changed

include/vast/Conversion/Parser/Passes.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "vast/Util/Warnings.hpp"
66

7+
#include "vast/server/server.hpp"
8+
79
VAST_RELAX_WARNINGS
810
#include <mlir/Pass/Pass.h>
911
#include <mlir/Pass/PassManager.h>

include/vast/Conversion/Parser/Passes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1010
let options = [
1111
Option< "config", "config", "std::string", "",
1212
"Configuration file for parser transformation."
13+
>,
14+
Option< "socket", "socket", "std::string", "",
15+
"Unix socket path to use for server"
1316
>
1417
];
1518

include/vast/server/types.hpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,34 @@
44

55
#include <concepts>
66
#include <cstdint>
7+
#include <optional>
78
#include <string>
89
#include <variant>
910

1011
#include <nlohmann/json.hpp>
1112

13+
namespace nlohmann {
14+
template< typename T >
15+
struct adl_serializer< std::optional< T > >
16+
{
17+
static void to_json(json &j, const std::optional< T > &opt) {
18+
if (!opt.has_value()) {
19+
j = nullptr;
20+
} else {
21+
j = *opt;
22+
}
23+
}
24+
25+
static void from_json(const json &j, std::optional< T > &opt) {
26+
if (j.is_null()) {
27+
opt = std::nullopt;
28+
} else {
29+
opt = j.template get< T >();
30+
}
31+
}
32+
};
33+
} // namespace nlohmann
34+
1235
namespace vast::server {
1336
template< typename T >
1437
concept json_convertible = requires(T obj, nlohmann::json &json) {
@@ -88,14 +111,33 @@ namespace vast::server {
88111
template< request_like request >
89112
using result_type = std::variant< typename request::response_type, error< request > >;
90113

114+
struct position
115+
{
116+
unsigned int line;
117+
unsigned int character;
118+
119+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(position, line, character)
120+
};
121+
122+
struct range
123+
{
124+
position start;
125+
position end;
126+
127+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(range, start, end)
128+
};
129+
91130
struct input_request
92131
{
93132
static constexpr const char *method = "input";
94133
static constexpr bool is_notification = false;
95134

96135
nlohmann::json type;
97136
std::string text;
98-
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text)
137+
std::optional< std::string > filePath;
138+
std::optional< range > range;
139+
140+
NLOHMANN_DEFINE_TYPE_INTRUSIVE(input_request, type, text, filePath, range)
99141

100142
struct response_type
101143
{

lib/vast/Conversion/Parser/ToParser.cpp

Lines changed: 191 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929

3030
#include "vast/Conversion/Parser/Config.hpp"
3131

32+
#include "vast/server/server.hpp"
33+
#include "vast/server/types.hpp"
34+
3235
#include <ranges>
3336

3437
namespace vast::conv {
@@ -75,6 +78,150 @@ namespace vast::conv {
7578

7679
using function_models = llvm::StringMap< function_model >;
7780

81+
struct location
82+
{
83+
std::string filePath;
84+
server::range range;
85+
};
86+
87+
location get_location(file_loc_t loc) {
88+
return {
89+
.filePath = loc.getFilename().str(),
90+
.range = {
91+
.start = { loc.getLine(), loc.getColumn(), },
92+
.end = { loc.getLine(), loc.getColumn(), },
93+
},
94+
};
95+
}
96+
97+
location get_location(name_loc_t loc) {
98+
return get_location(mlir::cast< file_loc_t >(loc.getChildLoc()));
99+
}
100+
101+
std::optional< location > get_location(loc_t loc) {
102+
if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+
return get_location(file_loc);
104+
} else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+
return get_location(name_loc);
106+
}
107+
108+
return std::nullopt;
109+
}
110+
111+
pr::data_type parse_type_name(const std::string &name) {
112+
if (name == "data") {
113+
return pr::data_type::data;
114+
} else if (name == "nodata") {
115+
return pr::data_type::nodata;
116+
} else {
117+
return pr::data_type::maybedata;
118+
}
119+
}
120+
121+
function_category ask_user_for_category(vast::server::server_base &server, core::function_op_interface op) {
122+
auto loc = op.getLoc();
123+
auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation());
124+
VAST_ASSERT(sym);
125+
auto name = sym.getSymbolName().str();
126+
127+
vast::server::input_request req{
128+
.type = {"nonparser", "sink", "source", "parser",},
129+
.text = "Please choose category for function `" + name + '`',
130+
.filePath = std::nullopt,
131+
.range = std::nullopt,
132+
};
133+
134+
if (auto req_loc = get_location(loc)) {
135+
req.filePath = req_loc->filePath;
136+
req.range = req_loc->range;
137+
}
138+
139+
auto response = server.send_request(req);
140+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
141+
{
142+
if (result->value == "nonparser") {
143+
return function_category::nonparser;
144+
} else if (result->value == "sink") {
145+
return function_category::sink;
146+
} else if (result->value == "source") {
147+
return function_category::source;
148+
} else if (result->value == "parser") {
149+
return function_category::parser;
150+
}
151+
}
152+
return function_category::nonparser;
153+
}
154+
155+
pr::data_type ask_user_for_return_type(vast::server::server_base &server, core::function_op_interface op) {
156+
auto loc = op.getLoc();
157+
auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation());
158+
VAST_ASSERT(sym);
159+
auto name = sym.getSymbolName().str();
160+
161+
vast::server::input_request req{
162+
.type = { "maybedata", "nodata", "data" },
163+
.text = "Please choose return type for function `" + name + '`',
164+
.filePath = std::nullopt,
165+
.range = std::nullopt,
166+
};
167+
168+
if (auto req_loc = get_location(loc)) {
169+
req.filePath = req_loc->filePath;
170+
req.range = req_loc->range;
171+
}
172+
173+
auto response = server.send_request(req);
174+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
175+
{
176+
return parse_type_name(result->value);
177+
}
178+
return pr::data_type::maybedata;
179+
}
180+
181+
pr::data_type ask_user_for_argument_type(
182+
vast::server::server_base &server, core::function_op_interface op, unsigned int idx
183+
) {
184+
auto num_body_args = op.getFunctionBody().getNumArguments();
185+
auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation());
186+
VAST_ASSERT(sym);
187+
auto name = sym.getSymbolName().str();
188+
189+
vast::server::input_request req{
190+
.type = { "maybedata", "nodata", "data" },
191+
.text = "Please choose a type for argument " + std::to_string(idx)
192+
+ " of function `" + name + '`',
193+
.filePath = std::nullopt,
194+
.range = std::nullopt,
195+
};
196+
197+
if (idx < num_body_args) {
198+
auto arg = op.getArgument(idx);
199+
auto loc = arg.getLoc();
200+
if (auto req_loc = get_location(loc)) {
201+
req.filePath = req_loc->filePath;
202+
req.range = req_loc->range;
203+
}
204+
}
205+
206+
auto response = server.send_request(req);
207+
if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
208+
{
209+
return parse_type_name(result->value);
210+
}
211+
return pr::data_type::maybedata;
212+
}
213+
214+
function_model
215+
ask_user_for_function_model(vast::server::server_base &server, core::function_op_interface op) {
216+
function_model model;
217+
model.return_type = ask_user_for_return_type(server, op);
218+
for (unsigned int i = 0; i < op.getNumArguments(); ++i) {
219+
model.arguments.push_back(ask_user_for_argument_type(server, op, i));
220+
}
221+
model.category = ask_user_for_category(server, op);
222+
return model;
223+
}
224+
78225
} // namespace vast::conv
79226

80227
LLVM_YAML_IS_SEQUENCE_VECTOR(vast::pr::data_type);
@@ -130,25 +277,28 @@ namespace vast::conv {
130277
using base = base_conversion_config;
131278

132279
parser_conversion_config(
133-
rewrite_pattern_set patterns, conversion_target target,
134-
const function_models &models
280+
rewrite_pattern_set patterns, conversion_target target, function_models &models,
281+
vast::server::server_base *server
135282
)
136-
: base(std::move(patterns), std::move(target)), models(models)
137-
{}
283+
: base(std::move(patterns), std::move(target)), models(models), server(server) {}
138284

139285
template< typename pattern >
140286
void add_pattern() {
141287
auto ctx = patterns.getContext();
142288
if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143289
patterns.template add< pattern >(ctx);
144-
} else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145-
patterns.template add< pattern >(ctx, models);
290+
} else if constexpr (std::is_constructible_v<
291+
pattern, mcontext_t *, function_models &,
292+
vast::server::server_base * >)
293+
{
294+
patterns.template add< pattern >(ctx, models, server);
146295
} else {
147296
static_assert(false, "pattern does not have a valid constructor");
148297
}
149298
}
150299

151-
const function_models &models;
300+
function_models &models;
301+
vast::server::server_base *server;
152302
};
153303

154304
struct function_type_converter
@@ -277,26 +427,35 @@ namespace vast::conv {
277427
{
278428
using base = mlir::OpConversionPattern< op_t >;
279429

280-
parser_conversion_pattern_base(mcontext_t *mctx, const function_models &models)
281-
: base(mctx), models(models)
282-
{}
430+
parser_conversion_pattern_base(
431+
mcontext_t *mctx, function_models &models, vast::server::server_base *server
432+
)
433+
: base(mctx), models(models), server(server) {}
283434

284-
static std::optional< function_model >
285-
get_model(const function_models &models, core::function_op_interface op) {
435+
static std::optional< function_model > get_model(
436+
function_models &models, core::function_op_interface op, vast::server::server_base *server
437+
) {
286438
auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation());
287439
VAST_ASSERT(sym);
288440
if (auto kv = models.find(sym.getSymbolName()); kv != models.end()) {
289441
return kv->second;
290442
}
291443

444+
if (server) {
445+
auto model = ask_user_for_function_model(*server, op);
446+
models[sym.getSymbolName()] = model;
447+
return model;
448+
}
449+
292450
return std::nullopt;
293451
}
294452

295453
std::optional< function_model > get_model(core::function_op_interface op) const {
296-
return get_model(models, op);
454+
return get_model(models, op, server);
297455
}
298456

299-
const function_models &models;
457+
function_models &models;
458+
vast::server::server_base *server;
300459
};
301460

302461
//
@@ -543,10 +702,13 @@ namespace vast::conv {
543702
return mlir::failure();
544703
}
545704

546-
static void legalize(parser_conversion_config &cfg) {
705+
static void
706+
legalize(parser_conversion_config &cfg, vast::server::server_base *server) {
547707
cfg.target.addLegalOp< mlir::UnrealizedConversionCastOp >();
548-
cfg.target.addDynamicallyLegalOp< op_t >([models = cfg.models](op_t op) {
549-
return function_type_converter(*op.getContext(), get_model(models, op))
708+
cfg.target.addDynamicallyLegalOp< op_t >([&cfg, server](op_t op) {
709+
return function_type_converter(
710+
*op.getContext(), get_model(cfg.models, op, server)
711+
)
550712
.isLegal(op.getFunctionType());
551713
});
552714
}
@@ -724,6 +886,9 @@ namespace vast::conv {
724886
{
725887
using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
726888

889+
struct server_handler
890+
{};
891+
727892
static conversion_target create_conversion_target(mcontext_t &mctx) {
728893
return conversion_target(mctx);
729894
}
@@ -738,6 +903,12 @@ namespace vast::conv {
738903
if (!config.empty()) {
739904
load_and_parse(config);
740905
}
906+
907+
if (!socket.empty()) {
908+
server = std::make_shared< vast::server::server< server_handler > >(
909+
vast::server::sock_adapter::create_unix_socket(socket)
910+
);
911+
}
741912
}
742913

743914
void load_and_parse(string_ref config) {
@@ -764,10 +935,12 @@ namespace vast::conv {
764935

765936
parser_conversion_config make_config() {
766937
auto &ctx = getContext();
767-
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models };
938+
return { rewrite_pattern_set(&ctx), create_conversion_target(ctx), models,
939+
server.get() };
768940
}
769941

770942
function_models models;
943+
std::shared_ptr< vast::server::server< server_handler > > server;
771944
};
772945

773946
} // namespace vast::conv

0 commit comments

Comments
 (0)