diff --git a/wado-compiler/src/optimize/field_scalarize.rs b/wado-compiler/src/optimize/field_scalarize.rs index 29eeef1a2..284e86f20 100644 --- a/wado-compiler/src/optimize/field_scalarize.rs +++ b/wado-compiler/src/optimize/field_scalarize.rs @@ -931,8 +931,22 @@ fn scalarize_loop( } // Step 5: Replace field accesses in the loop body, and insert write-back/re-read - // around function calls (only for fields the callee actually accesses) - replace_in_block(loop_body, &candidates, local_types, type_table, cache); + // around function calls (only for fields the callee actually accesses). + // Collect all labels defined within the loop body: breaks targeting these labels + // stay within the HFS scope and don't need write-back. + let inner_labels = { + let mut labels = IndexSet::default(); + collect_inner_labels(loop_body, &mut labels); + labels + }; + replace_in_block( + loop_body, + &candidates, + local_types, + type_table, + cache, + &inner_labels, + ); ScalarizeResult { pre_stmts, @@ -1311,6 +1325,135 @@ struct ReplaceCtx<'a> { local_types: &'a [TypeId], type_table: &'a TypeTable, cache: &'a FieldUsageCache, + inner_labels: &'a IndexSet, +} + +/// Collect all labels defined within a block (at any nesting depth). +/// Used to determine whether a `break` targets an inner scope (no write-back +/// needed) or escapes the current block (write-back needed). +fn collect_inner_labels(block: &TirBlock, labels: &mut IndexSet) { + for stmt in &block.stmts { + collect_inner_labels_in_stmt(stmt, labels); + } +} + +fn collect_inner_labels_in_stmt(stmt: &TirStmt, labels: &mut IndexSet) { + match &stmt.kind { + TirStmtKind::LabeledBlock { label, block } => { + labels.insert(label.clone()); + collect_inner_labels(block, labels); + } + TirStmtKind::If { + then_block, + else_block, + .. + } => { + collect_inner_labels(then_block, labels); + if let Some(eb) = else_block { + collect_inner_labels(eb, labels); + } + } + TirStmtKind::IfLet { + then_block, + else_block, + .. + } => { + collect_inner_labels(then_block, labels); + if let Some(eb) = else_block { + collect_inner_labels(eb, labels); + } + } + TirStmtKind::Loop { body } => { + collect_inner_labels(body, labels); + } + TirStmtKind::Expr(expr) => { + collect_inner_labels_in_expr(expr, labels); + } + TirStmtKind::Let { value, .. } => { + collect_inner_labels_in_expr(value, labels); + } + _ => {} + } +} + +fn collect_inner_labels_in_expr(expr: &TirExpr, labels: &mut IndexSet) { + match &expr.kind { + TirExprKind::LabeledBlock { label, block, .. } => { + labels.insert(label.clone()); + collect_inner_labels(block, labels); + } + TirExprKind::Cast { expr: inner, .. } + | TirExprKind::Unary { expr: inner, .. } + | TirExprKind::FieldAccess { expr: inner, .. } => { + collect_inner_labels_in_expr(inner, labels); + } + TirExprKind::Binary { left, right, .. } + | TirExprKind::Assign { + target: left, + value: right, + } => { + collect_inner_labels_in_expr(left, labels); + collect_inner_labels_in_expr(right, labels); + } + TirExprKind::Call { args, .. } => { + for arg in args { + collect_inner_labels_in_expr(&arg.expr, labels); + } + } + TirExprKind::MethodCall { receiver, args, .. } => { + collect_inner_labels_in_expr(receiver, labels); + for arg in args { + collect_inner_labels_in_expr(&arg.expr, labels); + } + } + TirExprKind::If { + condition, + then_branch, + else_branch, + } => { + collect_inner_labels_in_expr(condition, labels); + collect_inner_labels(then_branch, labels); + if let Some(eb) = else_branch { + collect_inner_labels(eb, labels); + } + } + TirExprKind::Block(block) => { + collect_inner_labels(block, labels); + } + TirExprKind::Index { expr: e, index } => { + collect_inner_labels_in_expr(e, labels); + collect_inner_labels_in_expr(index, labels); + } + TirExprKind::StructLiteral { fields, .. } => { + for field in fields { + collect_inner_labels_in_expr(&field.value, labels); + } + } + TirExprKind::TupleLiteral { elements } => { + for elem in elements { + collect_inner_labels_in_expr(elem, labels); + } + } + TirExprKind::Switch { + scrutinee, + arms, + default, + .. + } => { + collect_inner_labels_in_expr(scrutinee, labels); + for arm in arms { + collect_inner_labels(arm, labels); + } + collect_inner_labels(default, labels); + } + TirExprKind::Match { expr: e, arms } => { + collect_inner_labels_in_expr(e, labels); + for arm in arms { + collect_inner_labels_in_expr(&arm.body, labels); + } + } + _ => {} + } } fn replace_in_block( @@ -1319,11 +1462,13 @@ fn replace_in_block( local_types: &[TypeId], type_table: &TypeTable, cache: &FieldUsageCache, + inner_labels: &IndexSet, ) { let ctx = ReplaceCtx { local_types, type_table, cache, + inner_labels, }; let span = crate::token::Span::new(0, 0, 0, 0); let mut new_stmts = Vec::new(); @@ -1361,9 +1506,16 @@ fn replace_in_block( } } replace_in_expr(condition, candidates, &ctx); - replace_in_block(then_block, candidates, local_types, type_table, cache); + replace_in_block( + then_block, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); if let Some(eb) = else_block { - replace_in_block(eb, candidates, local_types, type_table, cache); + replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels); } new_stmts.push(stmt); for c in candidates { @@ -1400,9 +1552,16 @@ fn replace_in_block( } } replace_in_expr(scrutinee, candidates, &ctx); - replace_in_block(then_block, candidates, local_types, type_table, cache); + replace_in_block( + then_block, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); if let Some(eb) = else_block { - replace_in_block(eb, candidates, local_types, type_table, cache); + replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels); } new_stmts.push(stmt); for c in candidates { @@ -1413,12 +1572,26 @@ fn replace_in_block( continue; } TirStmtKind::Loop { body } => { - replace_in_block(body, candidates, local_types, type_table, cache); + replace_in_block( + body, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); new_stmts.push(stmt); continue; } TirStmtKind::LabeledBlock { block: inner, .. } => { - replace_in_block(inner, candidates, local_types, type_table, cache); + replace_in_block( + inner, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); new_stmts.push(stmt); continue; } @@ -1436,9 +1609,23 @@ fn replace_in_block( { replace_in_expr(scrutinee, candidates, &ctx); for arm in arms { - replace_in_block(arm, candidates, local_types, type_table, cache); + replace_in_block( + arm, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); } - replace_in_block(default, candidates, local_types, type_table, cache); + replace_in_block( + default, + candidates, + local_types, + type_table, + cache, + inner_labels, + ); } new_stmts.push(stmt); continue; @@ -1446,15 +1633,14 @@ fn replace_in_block( _ => {} } - // Insert write-back before return/break statements, since they exit the - // current scope and would skip the post-loop write-back. - // - // `break