@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929
3030#include " vast/Conversion/Parser/Config.hpp"
3131
32+ #include " vast/server/server.hpp"
33+ #include " vast/server/types.hpp"
34+
3235#include < ranges>
3336
3437namespace vast ::conv {
@@ -75,6 +78,154 @@ namespace vast::conv {
7578
7679 using function_models = llvm::StringMap< function_model >;
7780
81+ struct location
82+ {
83+ std::string filePath;
84+ server::range range;
85+ };
86+
87+ location get_location (file_loc_t loc) {
88+ return {
89+ .filePath = loc.getFilename ().str (),
90+ .range = {
91+ .start = { loc.getLine (), loc.getColumn (), },
92+ .end = { loc.getLine (), loc.getColumn (), },
93+ },
94+ };
95+ }
96+
97+ location get_location (name_loc_t loc) {
98+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99+ }
100+
101+ std::optional< location > get_location (loc_t loc) {
102+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+ return get_location (file_loc);
104+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+ return get_location (name_loc);
106+ }
107+
108+ return std::nullopt ;
109+ }
110+
111+ pr::data_type parse_type_name (const std::string &name) {
112+ if (name == " data" ) {
113+ return pr::data_type::data;
114+ } else if (name == " nodata" ) {
115+ return pr::data_type::nodata;
116+ } else {
117+ return pr::data_type::maybedata;
118+ }
119+ }
120+
121+ function_category
122+ ask_user_for_category (vast::server::server_base &server, core::function_op_interface op) {
123+ auto loc = op.getLoc ();
124+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
125+ VAST_ASSERT (sym);
126+ auto name = sym.getSymbolName ().str ();
127+
128+ vast::server::input_request req{
129+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
130+ .text = " Please choose category for function `" + name + ' `' ,
131+ .filePath = std::nullopt ,
132+ .range = std::nullopt ,
133+ };
134+
135+ if (auto req_loc = get_location (loc)) {
136+ req.filePath = req_loc->filePath ;
137+ req.range = req_loc->range ;
138+ }
139+
140+ auto response = server.send_request (req);
141+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
142+ {
143+ if (result->value == " nonparser" ) {
144+ return function_category::nonparser;
145+ } else if (result->value == " sink" ) {
146+ return function_category::sink;
147+ } else if (result->value == " source" ) {
148+ return function_category::source;
149+ } else if (result->value == " parser" ) {
150+ return function_category::parser;
151+ }
152+ }
153+ return function_category::nonparser;
154+ }
155+
156+ pr::data_type ask_user_for_return_type (
157+ vast::server::server_base &server, core::function_op_interface op
158+ ) {
159+ auto loc = op.getLoc ();
160+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
161+ VAST_ASSERT (sym);
162+ auto name = sym.getSymbolName ().str ();
163+
164+ vast::server::input_request req{
165+ .type = { " maybedata" , " nodata" , " data" },
166+ .text = " Please choose return type for function `" + name + ' `' ,
167+ .filePath = std::nullopt ,
168+ .range = std::nullopt ,
169+ };
170+
171+ if (auto req_loc = get_location (loc)) {
172+ req.filePath = req_loc->filePath ;
173+ req.range = req_loc->range ;
174+ }
175+
176+ auto response = server.send_request (req);
177+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
178+ {
179+ return parse_type_name (result->value );
180+ }
181+ return pr::data_type::maybedata;
182+ }
183+
184+ pr::data_type ask_user_for_argument_type (
185+ vast::server::server_base &server, core::function_op_interface op, unsigned int idx
186+ ) {
187+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
188+ auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
189+ VAST_ASSERT (sym);
190+ auto name = sym.getSymbolName ().str ();
191+
192+ vast::server::input_request req{
193+ .type = { " maybedata" , " nodata" , " data" },
194+ .text = " Please choose a type for argument " + std::to_string (idx)
195+ + " of function `" + name + ' `' ,
196+ .filePath = std::nullopt ,
197+ .range = std::nullopt ,
198+ };
199+
200+ if (idx < num_body_args) {
201+ auto arg = op.getArgument (idx);
202+ auto loc = arg.getLoc ();
203+ if (auto req_loc = get_location (loc)) {
204+ req.filePath = req_loc->filePath ;
205+ req.range = req_loc->range ;
206+ }
207+ }
208+
209+ auto response = server.send_request (req);
210+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
211+ {
212+ return parse_type_name (result->value );
213+ }
214+ return pr::data_type::maybedata;
215+ }
216+
217+ function_model ask_user_for_function_model (
218+ vast::server::server_base &server, core::function_op_interface op
219+ ) {
220+ function_model model;
221+ model.return_type = ask_user_for_return_type (server, op);
222+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
223+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
224+ }
225+ model.category = ask_user_for_category (server, op);
226+ return model;
227+ }
228+
78229} // namespace vast::conv
79230
80231LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +281,28 @@ namespace vast::conv {
130281 using base = base_conversion_config;
131282
132283 parser_conversion_config (
133- rewrite_pattern_set patterns, conversion_target target,
134- const function_models &models
284+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
285+ vast::server::server_base *server
135286 )
136- : base(std::move(patterns), std::move(target)), models(models)
137- {}
287+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138288
139289 template < typename pattern >
140290 void add_pattern () {
141291 auto ctx = patterns.getContext ();
142292 if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143293 patterns.template add < pattern >(ctx);
144- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145- patterns.template add < pattern >(ctx, models);
294+ } else if constexpr (std::is_constructible_v<
295+ pattern, mcontext_t *, function_models &,
296+ vast::server::server_base * >)
297+ {
298+ patterns.template add < pattern >(ctx, models, server);
146299 } else {
147300 static_assert (false , " pattern does not have a valid constructor" );
148301 }
149302 }
150303
151- const function_models ⊧
304+ function_models ⊧
305+ vast::server::server_base *server;
152306 };
153307
154308 using signature_conversion_t = mlir::TypeConverter::SignatureConversion;
@@ -308,26 +462,36 @@ namespace vast::conv {
308462 {
309463 using base = mlir::OpConversionPattern< op_t >;
310464
311- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
312- : base(mctx), models(models)
313- {}
465+ parser_conversion_pattern_base (
466+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
467+ )
468+ : base(mctx), models(models), server(server) {}
314469
315- static std::optional< function_model >
316- get_model (const function_models &models, core::function_op_interface op) {
470+ static std::optional< function_model > get_model (
471+ function_models &models, core::function_op_interface op,
472+ vast::server::server_base *server
473+ ) {
317474 auto sym = mlir::dyn_cast< core::SymbolOpInterface >(op.getOperation ());
318475 VAST_ASSERT (sym);
319476 if (auto kv = models.find (sym.getSymbolName ()); kv != models.end ()) {
320477 return kv->second ;
321478 }
322479
480+ if (server) {
481+ auto model = ask_user_for_function_model (*server, op);
482+ models[sym.getSymbolName ()] = model;
483+ return model;
484+ }
485+
323486 return std::nullopt ;
324487 }
325488
326489 std::optional< function_model > get_model (core::function_op_interface op) const {
327- return get_model (models, op);
490+ return get_model (models, op, server );
328491 }
329492
330- const function_models ⊧
493+ function_models ⊧
494+ vast::server::server_base *server;
331495 };
332496
333497 //
@@ -606,11 +770,12 @@ namespace vast::conv {
606770 return mlir::success ();
607771 }
608772
609- static void legalize (parser_conversion_config &cfg) {
773+ static void
774+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
610775 cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
611- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg. models ](op_t op) {
776+ cfg.target .addDynamicallyLegalOp < op_t >([& cfg, server ](op_t op) {
612777 auto tc = function_type_converter (
613- *op.getContext (), get_model (models, op)
778+ *op.getContext (), get_model (cfg. models , op, server )
614779 );
615780 return tc.isSignatureLegal (op.getFunctionType ());
616781 });
@@ -870,6 +1035,9 @@ namespace vast::conv {
8701035 {
8711036 using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
8721037
1038+ struct server_handler
1039+ {};
1040+
8731041 static conversion_target create_conversion_target (mcontext_t &mctx) {
8741042 return conversion_target (mctx);
8751043 }
@@ -884,6 +1052,12 @@ namespace vast::conv {
8841052 if (!config.empty ()) {
8851053 load_and_parse (config);
8861054 }
1055+
1056+ if (!socket.empty ()) {
1057+ server = std::make_shared< vast::server::server< server_handler > >(
1058+ vast::server::sock_adapter::create_unix_socket (socket)
1059+ );
1060+ }
8871061 }
8881062
8891063 void load_and_parse (string_ref config) {
@@ -910,10 +1084,12 @@ namespace vast::conv {
9101084
9111085 parser_conversion_config make_config () {
9121086 auto &ctx = getContext ();
913- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
1087+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
1088+ server.get () };
9141089 }
9151090
9161091 function_models models;
1092+ std::shared_ptr< vast::server::server< server_handler > > server;
9171093 };
9181094
9191095} // namespace vast::conv
0 commit comments