Skip to content

Commit ec4527a

Browse files
authored
Merge pull request #734 from wado-lang/claude/optimize-json-twitter-WBQFl
Skip HFS write-back before inlined labeled block breaks
2 parents 2a7152e + a853c0d commit ec4527a

32 files changed

+779
-574
lines changed

wado-compiler/src/optimize/field_scalarize.rs

Lines changed: 206 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,22 @@ fn scalarize_loop(
931931
}
932932

933933
// Step 5: Replace field accesses in the loop body, and insert write-back/re-read
934-
// around function calls (only for fields the callee actually accesses)
935-
replace_in_block(loop_body, &candidates, local_types, type_table, cache);
934+
// around function calls (only for fields the callee actually accesses).
935+
// Collect all labels defined within the loop body: breaks targeting these labels
936+
// stay within the HFS scope and don't need write-back.
937+
let inner_labels = {
938+
let mut labels = IndexSet::default();
939+
collect_inner_labels(loop_body, &mut labels);
940+
labels
941+
};
942+
replace_in_block(
943+
loop_body,
944+
&candidates,
945+
local_types,
946+
type_table,
947+
cache,
948+
&inner_labels,
949+
);
936950

937951
ScalarizeResult {
938952
pre_stmts,
@@ -1311,6 +1325,135 @@ struct ReplaceCtx<'a> {
13111325
local_types: &'a [TypeId],
13121326
type_table: &'a TypeTable,
13131327
cache: &'a FieldUsageCache,
1328+
inner_labels: &'a IndexSet<String>,
1329+
}
1330+
1331+
/// Collect all labels defined within a block (at any nesting depth).
1332+
/// Used to determine whether a `break` targets an inner scope (no write-back
1333+
/// needed) or escapes the current block (write-back needed).
1334+
fn collect_inner_labels(block: &TirBlock, labels: &mut IndexSet<String>) {
1335+
for stmt in &block.stmts {
1336+
collect_inner_labels_in_stmt(stmt, labels);
1337+
}
1338+
}
1339+
1340+
fn collect_inner_labels_in_stmt(stmt: &TirStmt, labels: &mut IndexSet<String>) {
1341+
match &stmt.kind {
1342+
TirStmtKind::LabeledBlock { label, block } => {
1343+
labels.insert(label.clone());
1344+
collect_inner_labels(block, labels);
1345+
}
1346+
TirStmtKind::If {
1347+
then_block,
1348+
else_block,
1349+
..
1350+
} => {
1351+
collect_inner_labels(then_block, labels);
1352+
if let Some(eb) = else_block {
1353+
collect_inner_labels(eb, labels);
1354+
}
1355+
}
1356+
TirStmtKind::IfLet {
1357+
then_block,
1358+
else_block,
1359+
..
1360+
} => {
1361+
collect_inner_labels(then_block, labels);
1362+
if let Some(eb) = else_block {
1363+
collect_inner_labels(eb, labels);
1364+
}
1365+
}
1366+
TirStmtKind::Loop { body } => {
1367+
collect_inner_labels(body, labels);
1368+
}
1369+
TirStmtKind::Expr(expr) => {
1370+
collect_inner_labels_in_expr(expr, labels);
1371+
}
1372+
TirStmtKind::Let { value, .. } => {
1373+
collect_inner_labels_in_expr(value, labels);
1374+
}
1375+
_ => {}
1376+
}
1377+
}
1378+
1379+
fn collect_inner_labels_in_expr(expr: &TirExpr, labels: &mut IndexSet<String>) {
1380+
match &expr.kind {
1381+
TirExprKind::LabeledBlock { label, block, .. } => {
1382+
labels.insert(label.clone());
1383+
collect_inner_labels(block, labels);
1384+
}
1385+
TirExprKind::Cast { expr: inner, .. }
1386+
| TirExprKind::Unary { expr: inner, .. }
1387+
| TirExprKind::FieldAccess { expr: inner, .. } => {
1388+
collect_inner_labels_in_expr(inner, labels);
1389+
}
1390+
TirExprKind::Binary { left, right, .. }
1391+
| TirExprKind::Assign {
1392+
target: left,
1393+
value: right,
1394+
} => {
1395+
collect_inner_labels_in_expr(left, labels);
1396+
collect_inner_labels_in_expr(right, labels);
1397+
}
1398+
TirExprKind::Call { args, .. } => {
1399+
for arg in args {
1400+
collect_inner_labels_in_expr(&arg.expr, labels);
1401+
}
1402+
}
1403+
TirExprKind::MethodCall { receiver, args, .. } => {
1404+
collect_inner_labels_in_expr(receiver, labels);
1405+
for arg in args {
1406+
collect_inner_labels_in_expr(&arg.expr, labels);
1407+
}
1408+
}
1409+
TirExprKind::If {
1410+
condition,
1411+
then_branch,
1412+
else_branch,
1413+
} => {
1414+
collect_inner_labels_in_expr(condition, labels);
1415+
collect_inner_labels(then_branch, labels);
1416+
if let Some(eb) = else_branch {
1417+
collect_inner_labels(eb, labels);
1418+
}
1419+
}
1420+
TirExprKind::Block(block) => {
1421+
collect_inner_labels(block, labels);
1422+
}
1423+
TirExprKind::Index { expr: e, index } => {
1424+
collect_inner_labels_in_expr(e, labels);
1425+
collect_inner_labels_in_expr(index, labels);
1426+
}
1427+
TirExprKind::StructLiteral { fields, .. } => {
1428+
for field in fields {
1429+
collect_inner_labels_in_expr(&field.value, labels);
1430+
}
1431+
}
1432+
TirExprKind::TupleLiteral { elements } => {
1433+
for elem in elements {
1434+
collect_inner_labels_in_expr(elem, labels);
1435+
}
1436+
}
1437+
TirExprKind::Switch {
1438+
scrutinee,
1439+
arms,
1440+
default,
1441+
..
1442+
} => {
1443+
collect_inner_labels_in_expr(scrutinee, labels);
1444+
for arm in arms {
1445+
collect_inner_labels(arm, labels);
1446+
}
1447+
collect_inner_labels(default, labels);
1448+
}
1449+
TirExprKind::Match { expr: e, arms } => {
1450+
collect_inner_labels_in_expr(e, labels);
1451+
for arm in arms {
1452+
collect_inner_labels_in_expr(&arm.body, labels);
1453+
}
1454+
}
1455+
_ => {}
1456+
}
13141457
}
13151458

13161459
fn replace_in_block(
@@ -1319,11 +1462,13 @@ fn replace_in_block(
13191462
local_types: &[TypeId],
13201463
type_table: &TypeTable,
13211464
cache: &FieldUsageCache,
1465+
inner_labels: &IndexSet<String>,
13221466
) {
13231467
let ctx = ReplaceCtx {
13241468
local_types,
13251469
type_table,
13261470
cache,
1471+
inner_labels,
13271472
};
13281473
let span = crate::token::Span::new(0, 0, 0, 0);
13291474
let mut new_stmts = Vec::new();
@@ -1361,9 +1506,16 @@ fn replace_in_block(
13611506
}
13621507
}
13631508
replace_in_expr(condition, candidates, &ctx);
1364-
replace_in_block(then_block, candidates, local_types, type_table, cache);
1509+
replace_in_block(
1510+
then_block,
1511+
candidates,
1512+
local_types,
1513+
type_table,
1514+
cache,
1515+
inner_labels,
1516+
);
13651517
if let Some(eb) = else_block {
1366-
replace_in_block(eb, candidates, local_types, type_table, cache);
1518+
replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels);
13671519
}
13681520
new_stmts.push(stmt);
13691521
for c in candidates {
@@ -1400,9 +1552,16 @@ fn replace_in_block(
14001552
}
14011553
}
14021554
replace_in_expr(scrutinee, candidates, &ctx);
1403-
replace_in_block(then_block, candidates, local_types, type_table, cache);
1555+
replace_in_block(
1556+
then_block,
1557+
candidates,
1558+
local_types,
1559+
type_table,
1560+
cache,
1561+
inner_labels,
1562+
);
14041563
if let Some(eb) = else_block {
1405-
replace_in_block(eb, candidates, local_types, type_table, cache);
1564+
replace_in_block(eb, candidates, local_types, type_table, cache, inner_labels);
14061565
}
14071566
new_stmts.push(stmt);
14081567
for c in candidates {
@@ -1413,12 +1572,26 @@ fn replace_in_block(
14131572
continue;
14141573
}
14151574
TirStmtKind::Loop { body } => {
1416-
replace_in_block(body, candidates, local_types, type_table, cache);
1575+
replace_in_block(
1576+
body,
1577+
candidates,
1578+
local_types,
1579+
type_table,
1580+
cache,
1581+
inner_labels,
1582+
);
14171583
new_stmts.push(stmt);
14181584
continue;
14191585
}
14201586
TirStmtKind::LabeledBlock { block: inner, .. } => {
1421-
replace_in_block(inner, candidates, local_types, type_table, cache);
1587+
replace_in_block(
1588+
inner,
1589+
candidates,
1590+
local_types,
1591+
type_table,
1592+
cache,
1593+
inner_labels,
1594+
);
14221595
new_stmts.push(stmt);
14231596
continue;
14241597
}
@@ -1436,25 +1609,38 @@ fn replace_in_block(
14361609
{
14371610
replace_in_expr(scrutinee, candidates, &ctx);
14381611
for arm in arms {
1439-
replace_in_block(arm, candidates, local_types, type_table, cache);
1612+
replace_in_block(
1613+
arm,
1614+
candidates,
1615+
local_types,
1616+
type_table,
1617+
cache,
1618+
inner_labels,
1619+
);
14401620
}
1441-
replace_in_block(default, candidates, local_types, type_table, cache);
1621+
replace_in_block(
1622+
default,
1623+
candidates,
1624+
local_types,
1625+
type_table,
1626+
cache,
1627+
inner_labels,
1628+
);
14421629
}
14431630
new_stmts.push(stmt);
14441631
continue;
14451632
}
14461633
_ => {}
14471634
}
14481635

1449-
// Insert write-back before return/break statements, since they exit the
1450-
// current scope and would skip the post-loop write-back.
1451-
//
1452-
// `break <label>` in a for-loop desugars to a labeled block exit that
1453-
// skips post-loop write-back statements placed inside the labeled block.
1454-
if matches!(
1455-
stmt.kind,
1456-
TirStmtKind::Return { .. } | TirStmtKind::Break { .. }
1457-
) {
1636+
// Insert write-back before return/break statements that escape this
1637+
// block's scope. Breaks targeting labels defined *within* this block
1638+
// stay in scope (the block's own exit handles write-back), so they
1639+
// are excluded.
1640+
if matches!(stmt.kind, TirStmtKind::Return { .. })
1641+
|| matches!(&stmt.kind, TirStmtKind::Break { label, .. }
1642+
if !label.as_ref().is_some_and(|l| inner_labels.contains(l.as_str())))
1643+
{
14581644
replace_in_stmt(&mut stmt, candidates, &ctx);
14591645
new_stmts.extend(make_write_back_stmts(candidates, span));
14601646
new_stmts.push(stmt);
@@ -2157,6 +2343,7 @@ fn replace_in_expr(expr: &mut TirExpr, candidates: &[ScalarizeCandidate], ctx: &
21572343
ctx.local_types,
21582344
ctx.type_table,
21592345
ctx.cache,
2346+
ctx.inner_labels,
21602347
);
21612348
}
21622349
TirExprKind::GlobalVarSet { value, .. } => {

wado-compiler/tests/fixtures.golden/base64_decode.wir.wado

Lines changed: 0 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

wado-compiler/tests/fixtures.golden/base64_encode.wir.wado

Lines changed: 0 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)