Skip to content

Commit 7747374

Browse files
committed
ZJIT: Get type information from branchif, branchunless, branchnil instructions
Do a sort of "partial static single information (SSI)" form that learns types of operands from branch instructions. A branchif, for example, tells us that in the truthy path, we know the operand is not nil, and not false. Similarly, in the falsy path, we know the operand is either nil or false. Add a RefineType instruction to attach this information. This PR does this in SSA construction because it's pretty straightforward, but we can also do a more aggressive version of this that can learn information about e.g. int ranges from other checks later in the optimization pipeline.
1 parent f3a5b0c commit 7747374

File tree

7 files changed

+324
-133
lines changed

7 files changed

+324
-133
lines changed

zjit/src/codegen.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
446446
&Insn::BoxFixnum { val, state } => gen_box_fixnum(jit, asm, opnd!(val), &function.frame_state(state)),
447447
&Insn::UnboxFixnum { val } => gen_unbox_fixnum(asm, opnd!(val)),
448448
Insn::Test { val } => gen_test(asm, opnd!(val)),
449+
Insn::RefineType { val, .. } => opnd!(val),
449450
Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
450451
Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)),
451452
&Insn::GuardBitEquals { val, expected, reason, state } => gen_guard_bit_equals(jit, asm, opnd!(val), expected, reason, &function.frame_state(state)),

zjit/src/hir.rs

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,10 @@ pub enum Insn {
990990
ObjToString { val: InsnId, cd: *const rb_call_data, state: InsnId },
991991
AnyToString { val: InsnId, str: InsnId, state: InsnId },
992992

993+
/// Refine the known type information of with additional type information.
994+
/// Computes the intersection of the existing type and the new type.
995+
RefineType { val: InsnId, new_type: Type },
996+
993997
/// Side-exit if val doesn't have the expected type.
994998
GuardType { val: InsnId, guard_type: Type, state: InsnId },
995999
GuardTypeNot { val: InsnId, guard_type: Type, state: InsnId },
@@ -1207,6 +1211,7 @@ impl Insn {
12071211
Insn::IncrCounterPtr { .. } => effects::Any,
12081212
Insn::CheckInterrupts { .. } => effects::Any,
12091213
Insn::InvokeProc { .. } => effects::Any,
1214+
Insn::RefineType { .. } => effects::Empty,
12101215
}
12111216
}
12121217

@@ -1502,6 +1507,7 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> {
15021507
Insn::FixnumLShift { left, right, .. } => { write!(f, "FixnumLShift {left}, {right}") },
15031508
Insn::FixnumRShift { left, right, .. } => { write!(f, "FixnumRShift {left}, {right}") },
15041509
Insn::GuardType { val, guard_type, .. } => { write!(f, "GuardType {val}, {}", guard_type.print(self.ptr_map)) },
1510+
Insn::RefineType { val, new_type, .. } => { write!(f, "RefineType {val}, {}", new_type.print(self.ptr_map)) },
15051511
Insn::GuardTypeNot { val, guard_type, .. } => { write!(f, "GuardTypeNot {val}, {}", guard_type.print(self.ptr_map)) },
15061512
Insn::GuardBitEquals { val, expected, .. } => { write!(f, "GuardBitEquals {val}, {}", expected.print(self.ptr_map)) },
15071513
&Insn::GuardShape { val, shape, .. } => { write!(f, "GuardShape {val}, {:p}", self.ptr_map.map_shape(shape)) },
@@ -2164,6 +2170,7 @@ impl Function {
21642170
Jump(target) => Jump(find_branch_edge!(target)),
21652171
&IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) },
21662172
&IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) },
2173+
&RefineType { val, new_type } => RefineType { val: find!(val), new_type },
21672174
&GuardType { val, guard_type, state } => GuardType { val: find!(val), guard_type, state },
21682175
&GuardTypeNot { val, guard_type, state } => GuardTypeNot { val: find!(val), guard_type, state },
21692176
&GuardBitEquals { val, expected, reason, state } => GuardBitEquals { val: find!(val), expected, reason, state },
@@ -2412,6 +2419,7 @@ impl Function {
24122419
Insn::CCall { return_type, .. } => *return_type,
24132420
&Insn::CCallVariadic { return_type, .. } => return_type,
24142421
Insn::GuardType { val, guard_type, .. } => self.type_of(*val).intersection(*guard_type),
2422+
Insn::RefineType { val, new_type, .. } => self.type_of(*val).intersection(*new_type),
24152423
Insn::GuardTypeNot { .. } => types::BasicObject,
24162424
Insn::GuardBitEquals { val, expected, .. } => self.type_of(*val).intersection(Type::from_const(*expected)),
24172425
Insn::GuardShape { val, .. } => self.type_of(*val),
@@ -2582,6 +2590,7 @@ impl Function {
25822590
| Insn::GuardTypeNot { val, .. }
25832591
| Insn::GuardShape { val, .. }
25842592
| Insn::GuardBitEquals { val, .. } => self.chase_insn(val),
2593+
| Insn::RefineType { val, .. } => self.chase_insn(val),
25852594
_ => id,
25862595
}
25872596
}
@@ -4425,6 +4434,7 @@ impl Function {
44254434
worklist.extend(values);
44264435
worklist.push_back(state);
44274436
}
4437+
| &Insn::RefineType { val, .. }
44284438
| &Insn::Return { val }
44294439
| &Insn::Test { val }
44304440
| &Insn::SetLocal { val, .. }
@@ -5342,6 +5352,7 @@ impl Function {
53425352
self.assert_subtype(insn_id, val, types::BasicObject)?;
53435353
self.assert_subtype(insn_id, class, types::Class)
53445354
}
5355+
Insn::RefineType { .. } => Ok(()),
53455356
}
53465357
}
53475358

@@ -5534,6 +5545,19 @@ impl FrameState {
55345545
state.stack.extend_from_slice(new_args);
55355546
state
55365547
}
5548+
5549+
fn replace(&mut self, old: InsnId, new: InsnId) {
5550+
for slot in &mut self.stack {
5551+
if *slot == old {
5552+
*slot = new;
5553+
}
5554+
}
5555+
for slot in &mut self.locals {
5556+
if *slot == old {
5557+
*slot = new;
5558+
}
5559+
}
5560+
}
55375561
}
55385562

55395563
/// Print adaptor for [`FrameState`]. See [`PtrPrintMap`].
@@ -6217,10 +6241,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62176241
let test_id = fun.push_insn(block, Insn::Test { val });
62186242
let target_idx = insn_idx_at_offset(insn_idx, offset);
62196243
let target = insn_idx_to_block[&target_idx];
6244+
let nil_false_type = types::NilClass.union(types::FalseClass);
6245+
let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type });
6246+
let mut iffalse_state = state.clone();
6247+
iffalse_state.replace(val, nil_false);
62206248
let _branch_id = fun.push_insn(block, Insn::IfFalse {
62216249
val: test_id,
6222-
target: BranchEdge { target, args: state.as_args(self_param) }
6250+
target: BranchEdge { target, args: iffalse_state.as_args(self_param) }
62236251
});
6252+
let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass);
6253+
let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type });
6254+
state.replace(val, not_nil_false);
62246255
queue.push_back((state.clone(), target, target_idx, local_inval));
62256256
}
62266257
YARVINSN_branchif => {
@@ -6230,10 +6261,17 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62306261
let test_id = fun.push_insn(block, Insn::Test { val });
62316262
let target_idx = insn_idx_at_offset(insn_idx, offset);
62326263
let target = insn_idx_to_block[&target_idx];
6264+
let not_nil_false_type = types::BasicObject.subtract(types::NilClass).subtract(types::FalseClass);
6265+
let not_nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: not_nil_false_type });
6266+
let mut iftrue_state = state.clone();
6267+
iftrue_state.replace(val, not_nil_false);
62336268
let _branch_id = fun.push_insn(block, Insn::IfTrue {
62346269
val: test_id,
6235-
target: BranchEdge { target, args: state.as_args(self_param) }
6270+
target: BranchEdge { target, args: iftrue_state.as_args(self_param) }
62366271
});
6272+
let nil_false_type = types::NilClass.union(types::FalseClass);
6273+
let nil_false = fun.push_insn(block, Insn::RefineType { val, new_type: nil_false_type });
6274+
state.replace(val, nil_false);
62376275
queue.push_back((state.clone(), target, target_idx, local_inval));
62386276
}
62396277
YARVINSN_branchnil => {
@@ -6243,10 +6281,16 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result<Function, ParseError> {
62436281
let test_id = fun.push_insn(block, Insn::IsNil { val });
62446282
let target_idx = insn_idx_at_offset(insn_idx, offset);
62456283
let target = insn_idx_to_block[&target_idx];
6284+
let nil = fun.push_insn(block, Insn::Const { val: Const::Value(Qnil) });
6285+
let mut iftrue_state = state.clone();
6286+
iftrue_state.replace(val, nil);
62466287
let _branch_id = fun.push_insn(block, Insn::IfTrue {
62476288
val: test_id,
6248-
target: BranchEdge { target, args: state.as_args(self_param) }
6289+
target: BranchEdge { target, args: iftrue_state.as_args(self_param) }
62496290
});
6291+
let new_type = types::BasicObject.subtract(types::NilClass);
6292+
let not_nil = fun.push_insn(block, Insn::RefineType { val, new_type });
6293+
state.replace(val, not_nil);
62506294
queue.push_back((state.clone(), target, target_idx, local_inval));
62516295
}
62526296
YARVINSN_opt_case_dispatch => {
@@ -7665,21 +7709,23 @@ mod graphviz_tests {
76657709
<TR><TD ALIGN="left" PORT="v12">PatchPoint NoTracePoint&nbsp;</TD></TR>
76667710
<TR><TD ALIGN="left" PORT="v14">CheckInterrupts&nbsp;</TD></TR>
76677711
<TR><TD ALIGN="left" PORT="v15">v15:CBool = Test v9&nbsp;</TD></TR>
7668-
<TR><TD ALIGN="left" PORT="v16">IfFalse v15, bb3(v8, v9)&nbsp;</TD></TR>
7669-
<TR><TD ALIGN="left" PORT="v18">PatchPoint NoTracePoint&nbsp;</TD></TR>
7670-
<TR><TD ALIGN="left" PORT="v19">v19:Fixnum[3] = Const Value(3)&nbsp;</TD></TR>
7671-
<TR><TD ALIGN="left" PORT="v21">PatchPoint NoTracePoint&nbsp;</TD></TR>
7672-
<TR><TD ALIGN="left" PORT="v22">CheckInterrupts&nbsp;</TD></TR>
7673-
<TR><TD ALIGN="left" PORT="v23">Return v19&nbsp;</TD></TR>
7712+
<TR><TD ALIGN="left" PORT="v16">v16:Falsy = RefineType v9, Falsy&nbsp;</TD></TR>
7713+
<TR><TD ALIGN="left" PORT="v17">IfFalse v15, bb3(v8, v16)&nbsp;</TD></TR>
7714+
<TR><TD ALIGN="left" PORT="v18">v18:Truthy = RefineType v9, Truthy&nbsp;</TD></TR>
7715+
<TR><TD ALIGN="left" PORT="v20">PatchPoint NoTracePoint&nbsp;</TD></TR>
7716+
<TR><TD ALIGN="left" PORT="v21">v21:Fixnum[3] = Const Value(3)&nbsp;</TD></TR>
7717+
<TR><TD ALIGN="left" PORT="v23">PatchPoint NoTracePoint&nbsp;</TD></TR>
7718+
<TR><TD ALIGN="left" PORT="v24">CheckInterrupts&nbsp;</TD></TR>
7719+
<TR><TD ALIGN="left" PORT="v25">Return v21&nbsp;</TD></TR>
76747720
</TABLE>>];
7675-
bb2:v16 -> bb3:params:n;
7721+
bb2:v17 -> bb3:params:n;
76767722
bb3 [label=<<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0">
7677-
<TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v24:BasicObject, v25:BasicObject)&nbsp;</TD></TR>
7678-
<TR><TD ALIGN="left" PORT="v28">PatchPoint NoTracePoint&nbsp;</TD></TR>
7679-
<TR><TD ALIGN="left" PORT="v29">v29:Fixnum[4] = Const Value(4)&nbsp;</TD></TR>
7680-
<TR><TD ALIGN="left" PORT="v31">PatchPoint NoTracePoint&nbsp;</TD></TR>
7681-
<TR><TD ALIGN="left" PORT="v32">CheckInterrupts&nbsp;</TD></TR>
7682-
<TR><TD ALIGN="left" PORT="v33">Return v29&nbsp;</TD></TR>
7723+
<TR><TD ALIGN="LEFT" PORT="params" BGCOLOR="gray">bb3(v26:BasicObject, v27:Falsy)&nbsp;</TD></TR>
7724+
<TR><TD ALIGN="left" PORT="v30">PatchPoint NoTracePoint&nbsp;</TD></TR>
7725+
<TR><TD ALIGN="left" PORT="v31">v31:Fixnum[4] = Const Value(4)&nbsp;</TD></TR>
7726+
<TR><TD ALIGN="left" PORT="v33">PatchPoint NoTracePoint&nbsp;</TD></TR>
7727+
<TR><TD ALIGN="left" PORT="v34">CheckInterrupts&nbsp;</TD></TR>
7728+
<TR><TD ALIGN="left" PORT="v35">Return v31&nbsp;</TD></TR>
76837729
</TABLE>>];
76847730
}
76857731
"#);

0 commit comments

Comments
 (0)