Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 206 additions & 19 deletions wado-compiler/src/optimize/field_scalarize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -931,8 +931,22 @@ fn scalarize_loop(
}

// Step 5: Replace field accesses in the loop body, and insert write-back/re-read
// around function calls (only for fields the callee actually accesses)
replace_in_block(loop_body, &candidates, local_types, type_table, cache);
// around function calls (only for fields the callee actually accesses).
// Collect all labels defined within the loop body: breaks targeting these labels
// stay within the HFS scope and don't need write-back.
let inner_labels = {
let mut labels = IndexSet::default();
collect_inner_labels(loop_body, &mut labels);
labels
};
replace_in_block(
loop_body,
&candidates,
local_types,
type_table,
cache,
&inner_labels,
);

ScalarizeResult {
pre_stmts,
Expand Down Expand Up @@ -1311,6 +1325,135 @@ struct ReplaceCtx<'a> {
local_types: &'a [TypeId],
type_table: &'a TypeTable,
cache: &'a FieldUsageCache,
inner_labels: &'a IndexSet<String>,
}

/// Collect all labels defined within a block (at any nesting depth).
/// Used to determine whether a `break` targets an inner scope (no write-back
/// needed) or escapes the current block (write-back needed).
fn collect_inner_labels(block: &TirBlock, labels: &mut IndexSet<String>) {
for stmt in &block.stmts {
collect_inner_labels_in_stmt(stmt, labels);
}
}

fn collect_inner_labels_in_stmt(stmt: &TirStmt, labels: &mut IndexSet<String>) {
match &stmt.kind {
TirStmtKind::LabeledBlock { label, block } => {
labels.insert(label.clone());
collect_inner_labels(block, labels);
}
TirStmtKind::If {
then_block,
else_block,
..
} => {
collect_inner_labels(then_block, labels);
if let Some(eb) = else_block {
collect_inner_labels(eb, labels);
}
}
TirStmtKind::IfLet {
then_block,
else_block,
..
} => {
collect_inner_labels(then_block, labels);
if let Some(eb) = else_block {
collect_inner_labels(eb, labels);
}
}
TirStmtKind::Loop { body } => {
collect_inner_labels(body, labels);
}
TirStmtKind::Expr(expr) => {
collect_inner_labels_in_expr(expr, labels);
}
TirStmtKind::Let { value, .. } => {
collect_inner_labels_in_expr(value, labels);
}
_ => {}
}
}

fn collect_inner_labels_in_expr(expr: &TirExpr, labels: &mut IndexSet<String>) {
match &expr.kind {
TirExprKind::LabeledBlock { label, block, .. } => {
labels.insert(label.clone());
collect_inner_labels(block, labels);
}
TirExprKind::Cast { expr: inner, .. }
| TirExprKind::Unary { expr: inner, .. }
| TirExprKind::FieldAccess { expr: inner, .. } => {
collect_inner_labels_in_expr(inner, labels);
}
TirExprKind::Binary { left, right, .. }
| TirExprKind::Assign {
target: left,
value: right,
} => {
collect_inner_labels_in_expr(left, labels);
collect_inner_labels_in_expr(right, labels);
}
TirExprKind::Call { args, .. } => {
for arg in args {
collect_inner_labels_in_expr(&arg.expr, labels);
}
}
TirExprKind::MethodCall { receiver, args, .. } => {
collect_inner_labels_in_expr(receiver, labels);
for arg in args {
collect_inner_labels_in_expr(&arg.expr, labels);
}
}
TirExprKind::If {
condition,
then_branch,
else_branch,
} => {
collect_inner_labels_in_expr(condition, labels);
collect_inner_labels(then_branch, labels);
if let Some(eb) = else_branch {
collect_inner_labels(eb, labels);
}
}
TirExprKind::Block(block) => {
collect_inner_labels(block, labels);
}
TirExprKind::Index { expr: e, index } => {
collect_inner_labels_in_expr(e, labels);
collect_inner_labels_in_expr(index, labels);
}
TirExprKind::StructLiteral { fields, .. } => {
for field in fields {
collect_inner_labels_in_expr(&field.value, labels);
}
}
TirExprKind::TupleLiteral { elements } => {
for elem in elements {
collect_inner_labels_in_expr(elem, labels);
}
}
TirExprKind::Switch {
scrutinee,
arms,
default,
..
} => {
collect_inner_labels_in_expr(scrutinee, labels);
for arm in arms {
collect_inner_labels(arm, labels);
}
collect_inner_labels(default, labels);
}
TirExprKind::Match { expr: e, arms } => {
collect_inner_labels_in_expr(e, labels);
for arm in arms {
collect_inner_labels_in_expr(&arm.body, labels);
}
}
_ => {}
}
}

fn replace_in_block(
Expand All @@ -1319,11 +1462,13 @@ fn replace_in_block(
local_types: &[TypeId],
type_table: &TypeTable,
cache: &FieldUsageCache,
inner_labels: &IndexSet<String>,
) {
let ctx = ReplaceCtx {
local_types,
type_table,
cache,
inner_labels,
};
let span = crate::token::Span::new(0, 0, 0, 0);
let mut new_stmts = Vec::new();
Expand Down Expand Up @@ -1361,9 +1506,16 @@ fn replace_in_block(
}
}
replace_in_expr(condition, candidates, &ctx);
replace_in_block(then_block, candidates, local_types, type_table, cache);
replace_in_block(
then_block,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
if let Some(eb) = else_block {
replace_in_block(eb, candidates, local_types, type_table, cache);
replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels);
}
new_stmts.push(stmt);
for c in candidates {
Expand Down Expand Up @@ -1400,9 +1552,16 @@ fn replace_in_block(
}
}
replace_in_expr(scrutinee, candidates, &ctx);
replace_in_block(then_block, candidates, local_types, type_table, cache);
replace_in_block(
then_block,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
if let Some(eb) = else_block {
replace_in_block(eb, candidates, local_types, type_table, cache);
replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels);
}
new_stmts.push(stmt);
for c in candidates {
Expand All @@ -1413,12 +1572,26 @@ fn replace_in_block(
continue;
}
TirStmtKind::Loop { body } => {
replace_in_block(body, candidates, local_types, type_table, cache);
replace_in_block(
body,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
new_stmts.push(stmt);
continue;
}
TirStmtKind::LabeledBlock { block: inner, .. } => {
replace_in_block(inner, candidates, local_types, type_table, cache);
replace_in_block(
inner,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
new_stmts.push(stmt);
continue;
}
Expand All @@ -1436,25 +1609,38 @@ fn replace_in_block(
{
replace_in_expr(scrutinee, candidates, &ctx);
for arm in arms {
replace_in_block(arm, candidates, local_types, type_table, cache);
replace_in_block(
arm,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
}
replace_in_block(default, candidates, local_types, type_table, cache);
replace_in_block(
default,
candidates,
local_types,
type_table,
cache,
inner_labels,
);
}
new_stmts.push(stmt);
continue;
}
_ => {}
}

// Insert write-back before return/break statements, since they exit the
// current scope and would skip the post-loop write-back.
//
// `break <label>` in a for-loop desugars to a labeled block exit that
// skips post-loop write-back statements placed inside the labeled block.
if matches!(
stmt.kind,
TirStmtKind::Return { .. } | TirStmtKind::Break { .. }
) {
// Insert write-back before return/break statements that escape this
// block's scope. Breaks targeting labels defined *within* this block
// stay in scope (the block's own exit handles write-back), so they
// are excluded.
if matches!(stmt.kind, TirStmtKind::Return { .. })
|| matches!(&stmt.kind, TirStmtKind::Break { label, .. }
if !label.as_ref().is_some_and(|l| inner_labels.contains(l.as_str())))
{
replace_in_stmt(&mut stmt, candidates, &ctx);
new_stmts.extend(make_write_back_stmts(candidates, span));
new_stmts.push(stmt);
Expand Down Expand Up @@ -2157,6 +2343,7 @@ fn replace_in_expr(expr: &mut TirExpr, candidates: &[ScalarizeCandidate], ctx: &
ctx.local_types,
ctx.type_table,
ctx.cache,
ctx.inner_labels,
);
}
TirExprKind::GlobalVarSet { value, .. } => {
Expand Down
2 changes: 0 additions & 2 deletions wado-compiler/tests/fixtures.golden/base64_decode.wir.wado

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions wado-compiler/tests/fixtures.golden/base64_encode.wir.wado

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading