@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929
3030#include " vast/Conversion/Parser/Config.hpp"
3131
32+ #include " vast/server/server.hpp"
33+ #include " vast/server/types.hpp"
34+
3235#include < ranges>
3336
3437namespace vast ::conv {
@@ -75,6 +78,150 @@ namespace vast::conv {
7578
7679 using function_models = llvm::StringMap< function_model >;
7780
81+ struct location
82+ {
83+ std::string filePath;
84+ server::range range;
85+ };
86+
87+ location get_location (file_loc_t loc) {
88+ return {
89+ .filePath = loc.getFilename ().str (),
90+ .range = {
91+ .start = { loc.getLine (), loc.getColumn (), },
92+ .end = { loc.getLine (), loc.getColumn (), },
93+ },
94+ };
95+ }
96+
97+ location get_location (name_loc_t loc) {
98+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99+ }
100+
101+ std::optional< location > get_location (loc_t loc) {
102+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+ return get_location (file_loc);
104+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+ return get_location (name_loc);
106+ }
107+
108+ return std::nullopt ;
109+ }
110+
111+ pr::data_type parse_type_name (const std::string &name) {
112+ if (name == " data" ) {
113+ return pr::data_type::data;
114+ } else if (name == " nodata" ) {
115+ return pr::data_type::nodata;
116+ } else {
117+ return pr::data_type::maybedata;
118+ }
119+ }
120+
121+ function_category ask_user_for_category (vast::server::server_base &server, core::function_op_interface op) {
122+ auto loc = op.getLoc ();
123+ auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation ());
124+ VAST_ASSERT (sym);
125+ auto name = sym.getSymbolName ().str ();
126+
127+ vast::server::input_request req{
128+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
129+ .text = " Please choose category for function `" + name + ' `' ,
130+ .filePath = std::nullopt ,
131+ .range = std::nullopt ,
132+ };
133+
134+ if (auto req_loc = get_location (loc)) {
135+ req.filePath = req_loc->filePath ;
136+ req.range = req_loc->range ;
137+ }
138+
139+ auto response = server.send_request (req);
140+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
141+ {
142+ if (result->value == " nonparser" ) {
143+ return function_category::nonparser;
144+ } else if (result->value == " sink" ) {
145+ return function_category::sink;
146+ } else if (result->value == " source" ) {
147+ return function_category::source;
148+ } else if (result->value == " parser" ) {
149+ return function_category::parser;
150+ }
151+ }
152+ return function_category::nonparser;
153+ }
154+
155+ pr::data_type ask_user_for_return_type (vast::server::server_base &server, core::function_op_interface op) {
156+ auto loc = op.getLoc ();
157+ auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation ());
158+ VAST_ASSERT (sym);
159+ auto name = sym.getSymbolName ().str ();
160+
161+ vast::server::input_request req{
162+ .type = { " maybedata" , " nodata" , " data" },
163+ .text = " Please choose return type for function `" + name + ' `' ,
164+ .filePath = std::nullopt ,
165+ .range = std::nullopt ,
166+ };
167+
168+ if (auto req_loc = get_location (loc)) {
169+ req.filePath = req_loc->filePath ;
170+ req.range = req_loc->range ;
171+ }
172+
173+ auto response = server.send_request (req);
174+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
175+ {
176+ return parse_type_name (result->value );
177+ }
178+ return pr::data_type::maybedata;
179+ }
180+
181+ pr::data_type ask_user_for_argument_type (
182+ vast::server::server_base &server, core::function_op_interface op, unsigned int idx
183+ ) {
184+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
185+ auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation ());
186+ VAST_ASSERT (sym);
187+ auto name = sym.getSymbolName ().str ();
188+
189+ vast::server::input_request req{
190+ .type = { " maybedata" , " nodata" , " data" },
191+ .text = " Please choose a type for argument " + std::to_string (idx)
192+ + " of function `" + name + ' `' ,
193+ .filePath = std::nullopt ,
194+ .range = std::nullopt ,
195+ };
196+
197+ if (idx < num_body_args) {
198+ auto arg = op.getArgument (idx);
199+ auto loc = arg.getLoc ();
200+ if (auto req_loc = get_location (loc)) {
201+ req.filePath = req_loc->filePath ;
202+ req.range = req_loc->range ;
203+ }
204+ }
205+
206+ auto response = server.send_request (req);
207+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
208+ {
209+ return parse_type_name (result->value );
210+ }
211+ return pr::data_type::maybedata;
212+ }
213+
214+ function_model
215+ ask_user_for_function_model (vast::server::server_base &server, core::function_op_interface op) {
216+ function_model model;
217+ model.return_type = ask_user_for_return_type (server, op);
218+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
219+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
220+ }
221+ model.category = ask_user_for_category (server, op);
222+ return model;
223+ }
224+
78225} // namespace vast::conv
79226
80227LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +277,28 @@ namespace vast::conv {
130277 using base = base_conversion_config;
131278
132279 parser_conversion_config (
133- rewrite_pattern_set patterns, conversion_target target,
134- const function_models &models
280+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
281+ vast::server::server_base *server
135282 )
136- : base(std::move(patterns), std::move(target)), models(models)
137- {}
283+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138284
139285 template < typename pattern >
140286 void add_pattern () {
141287 auto ctx = patterns.getContext ();
142288 if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143289 patterns.template add < pattern >(ctx);
144- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145- patterns.template add < pattern >(ctx, models);
290+ } else if constexpr (std::is_constructible_v<
291+ pattern, mcontext_t *, function_models &,
292+ vast::server::server_base * >)
293+ {
294+ patterns.template add < pattern >(ctx, models, server);
146295 } else {
147296 static_assert (false , " pattern does not have a valid constructor" );
148297 }
149298 }
150299
151- const function_models ⊧
300+ function_models ⊧
301+ vast::server::server_base *server;
152302 };
153303
154304 struct function_type_converter
@@ -277,26 +427,35 @@ namespace vast::conv {
277427 {
278428 using base = mlir::OpConversionPattern< op_t >;
279429
280- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
281- : base(mctx), models(models)
282- {}
430+ parser_conversion_pattern_base (
431+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
432+ )
433+ : base(mctx), models(models), server(server) {}
283434
284- static std::optional< function_model >
285- get_model (const function_models &models, core::function_op_interface op) {
435+ static std::optional< function_model > get_model (
436+ function_models &models, core::function_op_interface op, vast::server::server_base *server
437+ ) {
286438 auto sym = mlir::dyn_cast<core::SymbolOpInterface>(op.getOperation ());
287439 VAST_ASSERT (sym);
288440 if (auto kv = models.find (sym.getSymbolName ()); kv != models.end ()) {
289441 return kv->second ;
290442 }
291443
444+ if (server) {
445+ auto model = ask_user_for_function_model (*server, op);
446+ models[sym.getSymbolName ()] = model;
447+ return model;
448+ }
449+
292450 return std::nullopt ;
293451 }
294452
295453 std::optional< function_model > get_model (core::function_op_interface op) const {
296- return get_model (models, op);
454+ return get_model (models, op, server );
297455 }
298456
299- const function_models ⊧
457+ function_models ⊧
458+ vast::server::server_base *server;
300459 };
301460
302461 //
@@ -543,10 +702,13 @@ namespace vast::conv {
543702 return mlir::failure ();
544703 }
545704
546- static void legalize (parser_conversion_config &cfg) {
705+ static void
706+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
547707 cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
548- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg.models ](op_t op) {
549- return function_type_converter (*op.getContext (), get_model (models, op))
708+ cfg.target .addDynamicallyLegalOp < op_t >([&cfg, server](op_t op) {
709+ return function_type_converter (
710+ *op.getContext (), get_model (cfg.models , op, server)
711+ )
550712 .isLegal (op.getFunctionType ());
551713 });
552714 }
@@ -724,6 +886,9 @@ namespace vast::conv {
724886 {
725887 using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
726888
889+ struct server_handler
890+ {};
891+
727892 static conversion_target create_conversion_target (mcontext_t &mctx) {
728893 return conversion_target (mctx);
729894 }
@@ -738,6 +903,12 @@ namespace vast::conv {
738903 if (!config.empty ()) {
739904 load_and_parse (config);
740905 }
906+
907+ if (!socket.empty ()) {
908+ server = std::make_shared< vast::server::server< server_handler > >(
909+ vast::server::sock_adapter::create_unix_socket (socket)
910+ );
911+ }
741912 }
742913
743914 void load_and_parse (string_ref config) {
@@ -764,10 +935,12 @@ namespace vast::conv {
764935
765936 parser_conversion_config make_config () {
766937 auto &ctx = getContext ();
767- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
938+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
939+ server.get () };
768940 }
769941
770942 function_models models;
943+ std::shared_ptr< vast::server::server< server_handler > > server;
771944 };
772945
773946} // namespace vast::conv
0 commit comments