diff --git a/implants/lib/eldritch/eldritch-core/src/interpreter/core.rs b/implants/lib/eldritch/eldritch-core/src/interpreter/core.rs index 3e20729e3..a1d77719b 100644 --- a/implants/lib/eldritch/eldritch-core/src/interpreter/core.rs +++ b/implants/lib/eldritch/eldritch-core/src/interpreter/core.rs @@ -202,6 +202,10 @@ impl Interpreter { self.call_stack.clear(); self.current_func_name = "".to_string(); + if let Err(e) = exec::hoist_functions(self, &stmts) { + return Err(self.format_error(input, e)); + } + for stmt in stmts { match &stmt.kind { // Special case: if top-level statement is an expression, return its value diff --git a/implants/lib/eldritch/eldritch-core/src/interpreter/exec.rs b/implants/lib/eldritch/eldritch-core/src/interpreter/exec.rs index f3271246b..3a1fe7378 100644 --- a/implants/lib/eldritch/eldritch-core/src/interpreter/exec.rs +++ b/implants/lib/eldritch/eldritch-core/src/interpreter/exec.rs @@ -155,6 +155,46 @@ pub fn execute(interp: &mut Interpreter, stmt: &Stmt) -> Result<(), EldritchErro Ok(()) } +pub fn hoist_functions(interp: &mut Interpreter, stmts: &[Stmt]) -> Result<(), EldritchError> { + // Collect functions to hoist so we don't hold read locks while evaluating default params + // Only hoist the *first* definition of a given name in this block, to allow forward references, + // while sequential execution will properly overwrite later definitions. + let mut to_hoist = Vec::new(); + let mut seen = BTreeSet::new(); + + for stmt in stmts { + if let StmtKind::Def(name, params, _return_annotation, body) = &stmt.kind { + if !seen.contains(name) { + seen.insert(name.clone()); + to_hoist.push((name, params, body)); + } + } + } + + for (name, params, body) in to_hoist { + let mut runtime_params = Vec::new(); + for param in params { + match param { + Param::Normal(n, _) => runtime_params.push(RuntimeParam::Normal(n.clone())), + Param::Star(n, _) => runtime_params.push(RuntimeParam::Star(n.clone())), + Param::StarStar(n, _) => runtime_params.push(RuntimeParam::StarStar(n.clone())), + Param::WithDefault(n, _, _) => { + runtime_params.push(RuntimeParam::WithDefault(n.clone(), Value::None)); + } + } + } + + let func = Value::Function(Function { + name: name.clone(), + params: runtime_params, + body: body.clone(), + closure: interp.env.clone(), + }); + interp.env.write().values.insert(name.clone(), func); + } + Ok(()) +} + pub fn execute_stmts(interp: &mut Interpreter, stmts: &[Stmt]) -> Result<(), EldritchError> { for stmt in stmts { execute(interp, stmt)?; diff --git a/implants/lib/eldritch/eldritch-core/tests/regression_forward_ref.rs b/implants/lib/eldritch/eldritch-core/tests/regression_forward_ref.rs new file mode 100644 index 000000000..17b7a1b0b --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/regression_forward_ref.rs @@ -0,0 +1,50 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_forward_reference() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + r#" +def b(): + a() + +def a(): + print("A") + +b() +"#, + ); + if let Err(e) = res { + panic!("Failed with: {}", e); + } +} + +#[test] +fn test_redefine() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + r#" +def b(): + return a() + +def a(): + return "A" + +res1 = b() + +def a(): + return "A2" + +res2 = b() +"#, + ); + if let Err(e) = res { + panic!("Failed with: {}", e); + } + // Verify values + let env = interp.env.read(); + let res1 = env.values.get("res1").unwrap(); + let res2 = env.values.get("res2").unwrap(); + assert_eq!(res1, &eldritch_core::Value::String("A".to_string())); + assert_eq!(res2, &eldritch_core::Value::String("A2".to_string())); +} diff --git a/implants/lib/eldritch/eldritch-core/tests/test_eval.rs b/implants/lib/eldritch/eldritch-core/tests/test_eval.rs new file mode 100644 index 000000000..27a215548 --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/test_eval.rs @@ -0,0 +1,18 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_manual() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + " +def b(): + a() + +def a(): + print(\"A\") + +b() +", + ); + println!("result is {:?}", res); +} diff --git a/implants/lib/eldritch/eldritch-core/tests/test_eval_2.rs b/implants/lib/eldritch/eldritch-core/tests/test_eval_2.rs new file mode 100644 index 000000000..845e841df --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/test_eval_2.rs @@ -0,0 +1,19 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_manual() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + " +def b(): + a() + +def a(): + print(\"A\") + +b() +", + ); + println!("result is {:?}", res); + assert!(res.is_ok()); +} diff --git a/implants/lib/eldritch/eldritch-core/tests/test_eval_split.rs b/implants/lib/eldritch/eldritch-core/tests/test_eval_split.rs new file mode 100644 index 000000000..73d10f0d2 --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/test_eval_split.rs @@ -0,0 +1,14 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_manual() { + let mut interp = Interpreter::new(); + let res = interp.interpret("def b():\n a()"); + println!("def b: {:?}", res); + + let res = interp.interpret("def a():\n print(\"A\")"); + println!("def a: {:?}", res); + + let res = interp.interpret("b()"); + println!("result is {:?}", res); +} diff --git a/implants/lib/eldritch/eldritch-core/tests/test_hoist.rs b/implants/lib/eldritch/eldritch-core/tests/test_hoist.rs new file mode 100644 index 000000000..34d268ea3 --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/test_hoist.rs @@ -0,0 +1,19 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_hoist() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + " +def b(): + a() + +b() + +def a(): + print(\"A\") +", + ); + println!("result is {:?}", res); + assert!(res.is_ok()); +} diff --git a/implants/lib/eldritch/eldritch-core/tests/test_nested.rs b/implants/lib/eldritch/eldritch-core/tests/test_nested.rs new file mode 100644 index 000000000..dab6f11f5 --- /dev/null +++ b/implants/lib/eldritch/eldritch-core/tests/test_nested.rs @@ -0,0 +1,21 @@ +use eldritch_core::Interpreter; + +#[test] +fn test_nested() { + let mut interp = Interpreter::new(); + let res = interp.interpret( + " +def c(): + def b(): + a() + def a(): + print(\"A\") + b() + +c() +", + ); + if let Err(e) = res { + panic!("Failed with: {}", e); + } +}