@@ -29,6 +29,9 @@ VAST_UNRELAX_WARNINGS
2929
3030#include " vast/Conversion/Parser/Config.hpp"
3131
32+ #include " vast/server/server.hpp"
33+ #include " vast/server/types.hpp"
34+
3235#include < ranges>
3336
3437namespace vast ::conv {
@@ -75,6 +78,141 @@ namespace vast::conv {
7578
7679 using function_models = llvm::StringMap< function_model >;
7780
81+ struct location
82+ {
83+ std::string filePath;
84+ server::range range;
85+ };
86+
87+ location get_location (file_loc_t loc) {
88+ return {
89+ .filePath = loc.getFilename ().str (),
90+ .range = {
91+ .start = { loc.getLine (), loc.getColumn (), },
92+ .end = { loc.getLine (), loc.getColumn (), },
93+ },
94+ };
95+ }
96+
97+ location get_location (name_loc_t loc) {
98+ return get_location (mlir::cast< file_loc_t >(loc.getChildLoc ()));
99+ }
100+
101+ std::optional< location > get_location (loc_t loc) {
102+ if (auto file_loc = mlir::dyn_cast< file_loc_t >(loc)) {
103+ return get_location (file_loc);
104+ } else if (auto name_loc = mlir::dyn_cast< name_loc_t >(loc)) {
105+ return get_location (name_loc);
106+ }
107+
108+ return std::nullopt ;
109+ }
110+
111+ pr::data_type parse_type_name (const std::string &name) {
112+ if (name == " data" ) {
113+ return pr::data_type::data;
114+ } else if (name == " nodata" ) {
115+ return pr::data_type::nodata;
116+ } else {
117+ return pr::data_type::maybedata;
118+ }
119+ }
120+
121+ function_category ask_user_for_category (vast::server::server_base &server, hl::FuncOp op) {
122+ auto loc = op.getLoc ();
123+
124+ vast::server::input_request req{
125+ .type = {" nonparser" , " sink" , " source" , " parser" ,},
126+ .text = " Please choose category for function `" + op.getSymName ().str () + ' `' ,
127+ .filePath = std::nullopt ,
128+ .range = std::nullopt ,
129+ };
130+
131+ if (auto req_loc = get_location (loc)) {
132+ req.filePath = req_loc->filePath ;
133+ req.range = req_loc->range ;
134+ }
135+
136+ auto response = server.send_request (req);
137+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
138+ {
139+ if (result->value == " nonparser" ) {
140+ return function_category::nonparser;
141+ } else if (result->value == " sink" ) {
142+ return function_category::sink;
143+ } else if (result->value == " source" ) {
144+ return function_category::source;
145+ } else if (result->value == " parser" ) {
146+ return function_category::parser;
147+ }
148+ }
149+ return function_category::nonparser;
150+ }
151+
152+ pr::data_type ask_user_for_return_type (vast::server::server_base &server, hl::FuncOp op) {
153+ auto loc = op.getLoc ();
154+
155+ vast::server::input_request req{
156+ .type = { " maybedata" , " nodata" , " data" },
157+ .text = " Please choose return type for function `" + op.getSymName ().str () + ' `' ,
158+ .filePath = std::nullopt ,
159+ .range = std::nullopt ,
160+ };
161+
162+ if (auto req_loc = get_location (loc)) {
163+ req.filePath = req_loc->filePath ;
164+ req.range = req_loc->range ;
165+ }
166+
167+ auto response = server.send_request (req);
168+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
169+ {
170+ return parse_type_name (result->value );
171+ }
172+ return pr::data_type::maybedata;
173+ }
174+
175+ pr::data_type ask_user_for_argument_type (
176+ vast::server::server_base &server, hl::FuncOp op, unsigned int idx
177+ ) {
178+ auto num_body_args = op.getFunctionBody ().getNumArguments ();
179+
180+ vast::server::input_request req{
181+ .type = { " maybedata" , " nodata" , " data" },
182+ .text = " Please choose a type for argument " + std::to_string (idx)
183+ + " of function `" + op.getSymName ().str () + ' `' ,
184+ .filePath = std::nullopt ,
185+ .range = std::nullopt ,
186+ };
187+
188+ if (idx < num_body_args) {
189+ auto arg = op.getArgument (idx);
190+ auto loc = arg.getLoc ();
191+ if (auto req_loc = get_location (loc)) {
192+ req.filePath = req_loc->filePath ;
193+ req.range = req_loc->range ;
194+ }
195+ }
196+
197+ auto response = server.send_request (req);
198+ if (auto result = std::get_if< vast::server::input_request::response_type >(&response))
199+ {
200+ return parse_type_name (result->value );
201+ }
202+ return pr::data_type::maybedata;
203+ }
204+
205+ function_model
206+ ask_user_for_function_model (vast::server::server_base &server, hl::FuncOp op) {
207+ function_model model;
208+ model.return_type = ask_user_for_return_type (server, op);
209+ for (unsigned int i = 0 ; i < op.getNumArguments (); ++i) {
210+ model.arguments .push_back (ask_user_for_argument_type (server, op, i));
211+ }
212+ model.category = ask_user_for_category (server, op);
213+ return model;
214+ }
215+
78216} // namespace vast::conv
79217
80218LLVM_YAML_IS_SEQUENCE_VECTOR (vast::pr::data_type);
@@ -130,25 +268,28 @@ namespace vast::conv {
130268 using base = base_conversion_config;
131269
132270 parser_conversion_config (
133- rewrite_pattern_set patterns, conversion_target target,
134- const function_models &models
271+ rewrite_pattern_set patterns, conversion_target target, function_models &models,
272+ vast::server::server_base *server
135273 )
136- : base(std::move(patterns), std::move(target)), models(models)
137- {}
274+ : base(std::move(patterns), std::move(target)), models(models), server(server) {}
138275
139276 template < typename pattern >
140277 void add_pattern () {
141278 auto ctx = patterns.getContext ();
142279 if constexpr (std::is_constructible_v< pattern, mcontext_t * >) {
143280 patterns.template add < pattern >(ctx);
144- } else if constexpr (std::is_constructible_v< pattern, mcontext_t *, const function_models & >) {
145- patterns.template add < pattern >(ctx, models);
281+ } else if constexpr (std::is_constructible_v<
282+ pattern, mcontext_t *, function_models &,
283+ vast::server::server_base * >)
284+ {
285+ patterns.template add < pattern >(ctx, models, server);
146286 } else {
147287 static_assert (false , " pattern does not have a valid constructor" );
148288 }
149289 }
150290
151- const function_models ⊧
291+ function_models ⊧
292+ vast::server::server_base *server;
152293 };
153294
154295 struct function_type_converter
@@ -277,24 +418,33 @@ namespace vast::conv {
277418 {
278419 using base = mlir::OpConversionPattern< op_t >;
279420
280- parser_conversion_pattern_base (mcontext_t *mctx, const function_models &models)
281- : base(mctx), models(models)
282- {}
421+ parser_conversion_pattern_base (
422+ mcontext_t *mctx, function_models &models, vast::server::server_base *server
423+ )
424+ : base(mctx), models(models), server(server) {}
283425
284- static std::optional< function_model >
285- get_model (const function_models &models, hl::FuncOp func) {
426+ static std::optional< function_model > get_model (
427+ function_models &models, hl::FuncOp func, vast::server::server_base *server
428+ ) {
286429 if (auto kv = models.find (func.getSymName ()); kv != models.end ()) {
287430 return kv->second ;
288431 }
289432
433+ if (server) {
434+ auto model = ask_user_for_function_model (*server, func);
435+ models[func.getSymName ()] = model;
436+ return model;
437+ }
438+
290439 return std::nullopt ;
291440 }
292441
293442 std::optional< function_model > get_model (hl::FuncOp func) const {
294- return get_model (models, func);
443+ return get_model (models, func, server );
295444 }
296445
297- const function_models ⊧
446+ function_models ⊧
447+ vast::server::server_base *server;
298448 };
299449
300450 //
@@ -541,10 +691,13 @@ namespace vast::conv {
541691 return mlir::failure ();
542692 }
543693
544- static void legalize (parser_conversion_config &cfg) {
694+ static void
695+ legalize (parser_conversion_config &cfg, vast::server::server_base *server) {
545696 cfg.target .addLegalOp < mlir::UnrealizedConversionCastOp >();
546- cfg.target .addDynamicallyLegalOp < op_t >([models = cfg.models ](op_t op) {
547- return function_type_converter (*op.getContext (), get_model (models, op))
697+ cfg.target .addDynamicallyLegalOp < op_t >([&cfg, server](op_t op) {
698+ return function_type_converter (
699+ *op.getContext (), get_model (cfg.models , op, server)
700+ )
548701 .isLegal (op.getFunctionType ());
549702 });
550703 }
@@ -708,6 +861,9 @@ namespace vast::conv {
708861 {
709862 using base = ConversionPassMixin< HLToParserPass, HLToParserBase >;
710863
864+ struct server_handler
865+ {};
866+
711867 static conversion_target create_conversion_target (mcontext_t &mctx) {
712868 return conversion_target (mctx);
713869 }
@@ -722,6 +878,12 @@ namespace vast::conv {
722878 if (!config.empty ()) {
723879 load_and_parse (config);
724880 }
881+
882+ if (!socket.empty ()) {
883+ server = std::make_shared< vast::server::server< server_handler > >(
884+ vast::server::sock_adapter::create_unix_socket (socket)
885+ );
886+ }
725887 }
726888
727889 void load_and_parse (string_ref config) {
@@ -748,10 +910,12 @@ namespace vast::conv {
748910
749911 parser_conversion_config make_config () {
750912 auto &ctx = getContext ();
751- return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models };
913+ return { rewrite_pattern_set (&ctx), create_conversion_target (ctx), models,
914+ server.get () };
752915 }
753916
754917 function_models models;
918+ std::shared_ptr< vast::server::server< server_handler > > server;
755919 };
756920
757921} // namespace vast::conv
0 commit comments