Skip to content

Commit baba4b9

Browse files
mandelstarpit
authored andcommitted
feat: update rust Repeat AST to use Expr for for attr (IBM#904)
Signed-off-by: Louis Mandel <[email protected]> Signed-off-by: Nick Mitchell <[email protected]>
1 parent a0855d3 commit baba4b9

File tree

5 files changed

+117
-104
lines changed

5 files changed

+117
-104
lines changed

pdl-live-react/src-tauri/src/compile/beeai.rs

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ use serde_json::{Map, Value, from_reader, json, to_string};
1212
use tempfile::Builder;
1313

1414
use crate::pdl::ast::{
15-
ArrayBlockBuilder, CallBlock, EvalsTo, FunctionBlock, ListOrString, MessageBlock,
16-
MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock, PdlOptionalType,
17-
PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role, TextBlock, TextBlockBuilder,
15+
ArrayBlockBuilder, Block::*, CallBlock, EvalsTo, Expr, FunctionBlock, ListOrString,
16+
MessageBlock, MetadataBuilder, ModelBlockBuilder, ObjectBlock, PdlBaseType, PdlBlock,
17+
PdlBlock::Advanced, PdlOptionalType, PdlParser, PdlType, PythonCodeBlock, RepeatBlock, Role,
18+
TextBlock, TextBlockBuilder,
1819
};
1920
use crate::pdl::pip::pip_install_if_needed;
2021
use crate::pdl::requirements::BEEAI_FRAMEWORK;
@@ -190,7 +191,7 @@ fn with_tools(
190191
}
191192

192193
fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
193-
let repeat = PdlBlock::Text(TextBlock {
194+
let repeat = Advanced(Text(TextBlock {
194195
metadata: Some(
195196
MetadataBuilder::default()
196197
.description("Calling tool ${ tool.function.name }".to_string())
@@ -199,19 +200,19 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
199200
),
200201
role: None,
201202
parser: None,
202-
text: vec![PdlBlock::Model(
203+
text: vec![Advanced(Model(
203204
ModelBlockBuilder::default()
204205
.model(model.as_str())
205206
.parameters(strip_nulls(parameters))
206-
.input(PdlBlock::Array(
207+
.input(Advanced(Array(
207208
ArrayBlockBuilder::default()
208-
.array(vec![PdlBlock::Message(MessageBlock {
209+
.array(vec![Advanced(Message(MessageBlock {
209210
metadata: None,
210211
role: Role::Tool,
211212
defsite: None,
212213
name: Some("${ tool.function.name }".to_string()),
213214
tool_call_id: Some("${ tool.id }".to_string()),
214-
content: Box::new(PdlBlock::Call(CallBlock {
215+
content: Box::new(Advanced(Call(CallBlock {
215216
metadata: Some(
216217
MetadataBuilder::default()
217218
.defs(json_loads(
@@ -226,28 +227,33 @@ fn call_tools(model: &String, parameters: &HashMap<String, Value>) -> PdlBlock {
226227
"${ pdl__tools[tool.function.name] }".to_string(),
227228
), // look up tool in tool_declarations def (see below)
228229
args: Some("${ args }".into()), // invoke with arguments as specified by the model
229-
})),
230-
})])
230+
}))),
231+
}))])
231232
.build()
232233
.unwrap(),
233-
))
234+
)))
234235
.build()
235236
.unwrap(),
236-
)],
237-
});
237+
))],
238+
}));
238239

239240
let mut for_ = HashMap::new();
240241
for_.insert(
241242
"tool".to_string(),
242-
ListOrString::String("${ response.choices[0].message.tool_calls }".to_string()),
243+
EvalsTo::Expr(Expr {
244+
pdl_expr: ListOrString::String(
245+
"${ response.choices[0].message.tool_calls }".to_string(),
246+
),
247+
pdl_result: None,
248+
}),
243249
);
244250

245251
// response.choices[0].message.tool_calls
246-
PdlBlock::Repeat(RepeatBlock {
252+
Advanced(Repeat(RepeatBlock {
247253
metadata: None,
248254
for_: for_,
249255
repeat: Box::new(repeat),
250-
})
256+
}))
251257
}
252258

253259
fn json_loads(
@@ -258,7 +264,7 @@ fn json_loads(
258264
let mut m = indexmap::IndexMap::new();
259265
m.insert(
260266
outer_name.to_owned(),
261-
PdlBlock::Text(
267+
Advanced(Text(
262268
TextBlockBuilder::default()
263269
.text(vec![PdlBlock::String(format!(
264270
"{{\"{}\": {}}}",
@@ -273,7 +279,7 @@ fn json_loads(
273279
.parser(PdlParser::Json)
274280
.build()
275281
.unwrap(),
276-
),
282+
)),
277283
);
278284
m
279285
}
@@ -422,7 +428,7 @@ pub fn compile(source_file_path: &str, debug: bool) -> Result<PdlBlock, Box<dyn
422428
tool_name.clone(),
423429
PdlBlock::Function(FunctionBlock {
424430
function: schema,
425-
return_: Box::new(PdlBlock::PythonCode(PythonCodeBlock {
431+
return_: Box::new(Advanced(PythonCode(PythonCodeBlock {
426432
// tool function definition
427433
metadata: None,
428434
lang: "python".to_string(),
@@ -456,7 +462,7 @@ asyncio.run(invoke())
456462
"".to_string()
457463
}
458464
),
459-
})),
465+
}))),
460466
}),
461467
)
462468
})
@@ -485,7 +491,7 @@ asyncio.run(invoke())
485491
let model = format!("{}/{}", provider, model);
486492

487493
if let Some(instructions) = instructions {
488-
model_call.push(PdlBlock::Text(TextBlock {
494+
model_call.push(Advanced(Text(TextBlock {
489495
role: Some(Role::System),
490496
text: vec![PdlBlock::String(instructions)],
491497
metadata: Some(
@@ -495,7 +501,7 @@ asyncio.run(invoke())
495501
.unwrap(),
496502
),
497503
parser: None,
498-
}));
504+
})));
499505
}
500506

501507
let mut model_builder = ModelBlockBuilder::default();
@@ -518,7 +524,7 @@ asyncio.run(invoke())
518524
}
519525
}
520526

521-
model_call.push(PdlBlock::Model(model_builder.build().unwrap()));
527+
model_call.push(Advanced(Model(model_builder.build().unwrap())));
522528

523529
if let Some(tools) = tools {
524530
if tools.len() > 0 {
@@ -532,7 +538,7 @@ asyncio.run(invoke())
532538
closure_name.clone(),
533539
PdlBlock::Function(FunctionBlock {
534540
function: HashMap::new(),
535-
return_: Box::new(PdlBlock::Text(TextBlock {
541+
return_: Box::new(Advanced(Text(TextBlock {
536542
metadata: Some(
537543
MetadataBuilder::default()
538544
.description(format!("Model call {}", &model))
@@ -542,10 +548,10 @@ asyncio.run(invoke())
542548
role: None,
543549
parser: None,
544550
text: model_call,
545-
})),
551+
}))),
546552
}),
547553
);
548-
PdlBlock::Text(TextBlock {
554+
Advanced(Text(TextBlock {
549555
metadata: Some(
550556
MetadataBuilder::default()
551557
.description("Model call wrapper".to_string())
@@ -555,11 +561,11 @@ asyncio.run(invoke())
555561
),
556562
role: None,
557563
parser: None,
558-
text: vec![PdlBlock::Call(CallBlock::new(format!(
564+
text: vec![Advanced(Call(CallBlock::new(format!(
559565
"${{ {} }}",
560566
closure_name
561-
)))],
562-
})
567+
))))],
568+
}))
563569
},
564570
)
565571
.collect::<Vec<_>>();
@@ -574,19 +580,19 @@ asyncio.run(invoke())
574580
let mut defs = indexmap::IndexMap::new();
575581
defs.insert(
576582
"pdl__tools".to_string(),
577-
PdlBlock::Object(ObjectBlock {
583+
Advanced(Object(ObjectBlock {
578584
object: tool_declarations,
579-
}),
585+
})),
580586
);
581587
metadata.defs(defs);
582588
}
583589

584-
let pdl: PdlBlock = PdlBlock::Text(TextBlock {
590+
let pdl: PdlBlock = Advanced(Text(TextBlock {
585591
metadata: Some(metadata.build().unwrap()),
586592
role: None,
587593
parser: None,
588594
text: body,
589-
});
595+
}));
590596

591597
Ok(pdl)
592598
}

pdl-live-react/src-tauri/src/pdl/ast.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl SequencingBlock for LastOfBlock {
183183
&self.parser
184184
}
185185
fn to_block(&self) -> PdlBlock {
186-
PdlBlock::LastOf(self.clone())
186+
PdlBlock::Advanced(Block::LastOf(self.clone()))
187187
}
188188
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
189189
match output_results.last() {
@@ -243,7 +243,7 @@ impl SequencingBlock for TextBlock {
243243
&self.parser
244244
}
245245
fn to_block(&self) -> PdlBlock {
246-
PdlBlock::Text(self.clone())
246+
PdlBlock::Advanced(Block::Text(self.clone()))
247247
}
248248
fn result_for(&self, output_results: Vec<PdlResult>) -> PdlResult {
249249
PdlResult::String(
@@ -344,7 +344,7 @@ pub struct RepeatBlock {
344344

345345
/// Arrays to iterate over
346346
#[serde(rename = "for")]
347-
pub for_: HashMap<String, ListOrString>,
347+
pub for_: HashMap<String, EvalsTo<ListOrString, Vec<PdlResult>>>,
348348

349349
/// Body of the loop
350350
pub repeat: Box<PdlBlock>,
@@ -596,8 +596,15 @@ pub enum PdlBlock {
596596
Number(Number),
597597
String(String),
598598
Function(FunctionBlock),
599+
Advanced(Block),
600+
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
601+
Empty(EmptyBlock),
602+
}
599603

600-
// the rest have Metadata; TODO refactor to make this more explicit
604+
/// A PDL block that has structure and metadata
605+
#[derive(Serialize, Deserialize, Debug, Clone)]
606+
#[serde(untagged)]
607+
pub enum Block {
601608
If(IfBlock),
602609
Import(ImportBlock),
603610
Include(IncludeBlock),
@@ -612,9 +619,6 @@ pub enum PdlBlock {
612619
Model(ModelBlock),
613620
LastOf(LastOfBlock),
614621
Text(TextBlock),
615-
616-
// must be last to prevent serde from aggressively matching on it, since other block types also (may) have a `defs`
617-
Empty(EmptyBlock),
618622
}
619623

620624
impl From<bool> for PdlBlock {

pdl-live-react/src-tauri/src/pdl/extract.rs

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::pdl::ast::{Metadata, PdlBlock};
1+
use crate::pdl::ast::{Block::*, Metadata, PdlBlock, PdlBlock::Advanced};
22

33
/// Extract models referenced by the programs
44
pub fn extract_models(program: &PdlBlock) -> Vec<String> {
@@ -20,18 +20,26 @@ pub fn extract_values(program: &PdlBlock, field: &str) -> Vec<String> {
2020
/// Take one Yaml fragment and produce a vector of the string-valued entries of the given field
2121
fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>) {
2222
match program {
23-
PdlBlock::Model(b) => values.push(b.model.clone()),
24-
PdlBlock::Repeat(b) => {
23+
PdlBlock::Empty(b) => {
24+
b.defs
25+
.values()
26+
.for_each(|p| extract_values_iter(p, field, values));
27+
}
28+
PdlBlock::Function(b) => {
29+
extract_values_iter(&b.return_, field, values);
30+
}
31+
Advanced(Model(b)) => values.push(b.model.clone()),
32+
Advanced(Repeat(b)) => {
2533
extract_values_iter(&b.repeat, field, values);
2634
}
27-
PdlBlock::Message(b) => {
35+
Advanced(Message(b)) => {
2836
extract_values_iter(&b.content, field, values);
2937
}
30-
PdlBlock::Array(b) => b
38+
Advanced(Array(b)) => b
3139
.array
3240
.iter()
3341
.for_each(|p| extract_values_iter(p, field, values)),
34-
PdlBlock::Text(b) => {
42+
Advanced(Text(b)) => {
3543
b.text
3644
.iter()
3745
.for_each(|p| extract_values_iter(p, field, values));
@@ -43,7 +51,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
4351
.for_each(|p| extract_values_iter(p, field, values));
4452
}
4553
}
46-
PdlBlock::LastOf(b) => {
54+
Advanced(LastOf(b)) => {
4755
b.last_of
4856
.iter()
4957
.for_each(|p| extract_values_iter(p, field, values));
@@ -55,7 +63,7 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
5563
.for_each(|p| extract_values_iter(p, field, values));
5664
}
5765
}
58-
PdlBlock::If(b) => {
66+
Advanced(If(b)) => {
5967
extract_values_iter(&b.then, field, values);
6068
if let Some(else_) = &b.else_ {
6169
extract_values_iter(else_, field, values);
@@ -68,20 +76,11 @@ fn extract_values_iter(program: &PdlBlock, field: &str, values: &mut Vec<String>
6876
.for_each(|p| extract_values_iter(p, field, values));
6977
}
7078
}
71-
PdlBlock::Empty(b) => {
72-
b.defs
73-
.values()
74-
.for_each(|p| extract_values_iter(p, field, values));
75-
}
76-
PdlBlock::Object(b) => b
79+
Advanced(Object(b)) => b
7780
.object
7881
.values()
7982
.for_each(|p| extract_values_iter(p, field, values)),
8083

81-
PdlBlock::Function(b) => {
82-
extract_values_iter(&b.return_, field, values);
83-
}
84-
8584
_ => {}
8685
}
8786
}

0 commit comments

Comments
 (0)