Skip to content

Commit 6f1453d

Browse files
authored
ZJIT: Support optional keyword arguments in direct send (ruby#15873)
This fills in constants when unspecified optional keyword args have static default values. For complex defaults we calculate the kw_bits and utilize the checkkeyword logic we already had. The following benchmarks used to register param_kw_opt. Some of them (like graphql*) just trade that for some other complexity, or "too_many_args_for_lir". Notable improvements include activerecord where the previous param_kw_opt count has a corresponding drop in complex args and dynamic_send_count and a nearly equal rise in optimized_send_count. The gains are similar but not as complete in hexapdf, liquid-render, lobsters, railsbench, shipit. | Benchmark | param_kw_opt | Δ one_or_more_complex | Δ too_many_args | Δ dynamic_send | Δ optimized_send | |-----------|-------------:|----------------------:|----------------:|---------------:|-----------------:| | activerecord | 6,307,141 | -6,253,823 | +4,084 | -6,306,223 | +6,279,766 | | blurhash | 21 | -21 | +0 | -23 | +20 | | chunky-png | 813,604 | -813,604 | +0 | -813,616 | +813,556 | | erubi-rails | 1,590,395 | -590,274 | +35,578 | -552,914 | +550,826 | | fluentd | 4,906 | -4,854 | +21 | -5,745 | +5,080 | | graphql | 1,610,439 | -1,610,432 | +1,605,751 | -4,688 | +4,628 | | graphql-native | 16,332,386 | -16,332,375 | +16,309,681 | -22,701 | +22,638 | | hexapdf | 9,165,465 | -9,124,509 | +203,754 | -8,920,727 | +8,839,295 | | liquid-compile | 14,817 | -14,792 | +0 | -14,705 | +15,045 | | liquid-render | 3,994,905 | -3,994,901 | +0 | -3,994,868 | +3,020,779 | | lobsters | 2,467,510 | -2,297,298 | +205,610 | -2,216,583 | +1,694,092 | | protoboeuf | 11,521 | -11,521 | +0 | -11,523 | +11,520 | | psych-load | 77,612 | -77,609 | +29,942 | -77,613 | -12,242 | | rack | 2,743 | -2,742 | +0 | -2,750 | +2,668 | | railsbench | 3,579,778 | -2,517,615 | +432,575 | -2,084,480 | +1,882,928 | | ruby-lsp | 287,171 | -379,716 | +37 | -409,368 | -267,248 | | rubyboy | 5,993,004 | -5,993,003 | +0 | -5,993,006 | +5,992,993 | | sequel | 182,652 | -182,631 | +0 | -182,563 | +122,687 | | shipit | 3,289,456 | -2,778,419 | +306,867 | -3,201,395 | +1,068,505 | | tinygql | 2,732 | -2,732 | +1 | -2,734 | +2,729 |
1 parent 01984fa commit 6f1453d

File tree

6 files changed

+589
-138
lines changed

6 files changed

+589
-138
lines changed

test/ruby/test_zjit.rb

Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,61 @@ def entry
833833
}, call_threshold: 2
834834
end
835835

836+
def test_pos_optional_with_maybe_too_many_args
837+
assert_compiles '[[1, 2, 3, 4, 5, 6], [10, 20, 30, 4, 5, 6], [10, 20, 30, 40, 50, 60]]', %q{
838+
def target(a = 1, b = 2, c = 3, d = 4, e = 5, f:) = [a, b, c, d, e, f]
839+
def test = [target(f: 6), target(10, 20, 30, f: 6), target(10, 20, 30, 40, 50, f: 60)]
840+
test
841+
test
842+
}, call_threshold: 2
843+
end
844+
845+
def test_send_kwarg_partial_optional
846+
assert_compiles '[[1, 2, 3], [1, 20, 3], [10, 2, 30]]', %q{
847+
def test(a: 1, b: 2, c: 3) = [a, b, c]
848+
def entry = [test, test(b: 20), test(c: 30, a: 10)]
849+
entry
850+
entry
851+
}, call_threshold: 2
852+
end
853+
854+
def test_send_kwarg_optional_a_lot
855+
assert_compiles '[[1, 2, 3, 4, 5, 6], [1, 2, 3, 7, 8, 9], [2, 4, 6, 8, 10, 12]]', %q{
856+
def test(a: 1, b: 2, c: 3, d: 4, e: 5, f: 6) = [a, b, c, d, e, f]
857+
def entry = [test, test(d: 7, f: 9, e: 8), test(f: 12, e: 10, d: 8, c: 6, b: 4, a: 2)]
858+
entry
859+
entry
860+
}, call_threshold: 2
861+
end
862+
863+
def test_send_kwarg_non_constant_default
864+
assert_compiles '[[1, 2], [10, 2]]', %q{
865+
def make_default = 2
866+
def test(a: 1, b: make_default) = [a, b]
867+
def entry = [test, test(a: 10)]
868+
entry
869+
entry
870+
}, call_threshold: 2
871+
end
872+
873+
def test_send_kwarg_optional_static_with_side_exit
874+
# verify frame reconstruction with synthesized keyword defaults is correct
875+
assert_compiles '[10, 2, 10]', %q{
876+
def callee(a: 1, b: 2)
877+
# use binding to force side-exit
878+
x = binding.local_variable_get(:a)
879+
[a, b, x]
880+
end
881+
882+
def entry
883+
callee(a: 10) # b should get default value
884+
end
885+
886+
entry
887+
entry
888+
}, call_threshold: 2
889+
end
890+
836891
def test_send_all_arg_types
837892
assert_compiles '[:req, :opt, :post, :kwr, :kwo, true]', %q{
838893
def test(a, b = :opt, c, d:, e: :kwo) = [a, b, c, d, e, block_given?]
@@ -1388,6 +1443,190 @@ def test
13881443
}, call_threshold: 2
13891444
end
13901445

1446+
def test_invokesuper_with_optional_keyword_args
1447+
assert_compiles '[1, 2, 3]', %q{
1448+
class Parent
1449+
def foo(a, b: 2, c: 3) = [a, b, c]
1450+
end
1451+
1452+
class Child < Parent
1453+
def foo(a) = super(a)
1454+
end
1455+
1456+
def test = Child.new.foo(1)
1457+
1458+
test
1459+
test
1460+
}, call_threshold: 2
1461+
end
1462+
1463+
def test_send_with_non_constant_keyword_default
1464+
assert_compiles '[[2, 4, 16], [10, 4, 16], [2, 20, 16], [2, 4, 30], [10, 20, 30]]', %q{
1465+
def dbl(x = 1) = x * 2
1466+
1467+
def foo(a: dbl, b: dbl(2), c: dbl(2 ** 3))
1468+
[a, b, c]
1469+
end
1470+
1471+
def test
1472+
[
1473+
foo,
1474+
foo(a: 10),
1475+
foo(b: 20),
1476+
foo(c: 30),
1477+
foo(a: 10, b: 20, c: 30)
1478+
]
1479+
end
1480+
1481+
test
1482+
test
1483+
}, call_threshold: 2
1484+
end
1485+
1486+
def test_send_with_non_constant_keyword_default_not_evaluated_when_provided
1487+
assert_compiles '[1, 2, 3]', %q{
1488+
def foo(a: raise, b: raise, c: raise)
1489+
[a, b, c]
1490+
end
1491+
1492+
def test
1493+
foo(a: 1, b: 2, c: 3)
1494+
end
1495+
1496+
test
1497+
test
1498+
}, call_threshold: 2
1499+
end
1500+
1501+
def test_send_with_non_constant_keyword_default_evaluated_when_not_provided
1502+
assert_compiles '["a", "b", "c"]', %q{
1503+
def raise_a = raise "a"
1504+
def raise_b = raise "b"
1505+
def raise_c = raise "c"
1506+
1507+
def foo(a: raise_a, b: raise_b, c: raise_c)
1508+
[a, b, c]
1509+
end
1510+
1511+
def test_a
1512+
foo(b: 2, c: 3)
1513+
rescue RuntimeError => e
1514+
e.message
1515+
end
1516+
1517+
def test_b
1518+
foo(a: 1, c: 3)
1519+
rescue RuntimeError => e
1520+
e.message
1521+
end
1522+
1523+
def test_c
1524+
foo(a: 1, b: 2)
1525+
rescue RuntimeError => e
1526+
e.message
1527+
end
1528+
1529+
def test
1530+
[test_a, test_b, test_c]
1531+
end
1532+
1533+
test
1534+
test
1535+
}, call_threshold: 2
1536+
end
1537+
1538+
def test_send_with_non_constant_keyword_default_jit_to_jit
1539+
# Test that kw_bits passing works correctly in JIT-to-JIT calls
1540+
assert_compiles '[2, 4, 6]', %q{
1541+
def make_default(x) = x * 2
1542+
1543+
def callee(a: make_default(1), b: make_default(2), c: make_default(3))
1544+
[a, b, c]
1545+
end
1546+
1547+
def caller_method
1548+
callee
1549+
end
1550+
1551+
# Warm up callee first so it gets JITted
1552+
callee
1553+
callee
1554+
1555+
# Now warm up caller - this creates JIT-to-JIT call
1556+
caller_method
1557+
caller_method
1558+
}, call_threshold: 2
1559+
end
1560+
1561+
def test_send_with_non_constant_keyword_default_side_exit
1562+
# Verify frame reconstruction includes correct values for non-constant defaults
1563+
assert_compiles '[10, 2, 30]', %q{
1564+
def make_b = 2
1565+
1566+
def callee(a: 1, b: make_b, c: 3)
1567+
x = binding.local_variable_get(:a)
1568+
y = binding.local_variable_get(:b)
1569+
z = binding.local_variable_get(:c)
1570+
[x, y, z]
1571+
end
1572+
1573+
def test
1574+
callee(a: 10, c: 30)
1575+
end
1576+
1577+
test
1578+
test
1579+
}, call_threshold: 2
1580+
end
1581+
1582+
def test_send_with_non_constant_keyword_default_evaluation_order
1583+
# Verify defaults are evaluated left-to-right and only when not provided
1584+
assert_compiles '[["a", "b", "c"], ["b", "c"], ["a", "c"], ["a", "b"]]', %q{
1585+
def log(x)
1586+
$order << x
1587+
x
1588+
end
1589+
1590+
def foo(a: log("a"), b: log("b"), c: log("c"))
1591+
[a, b, c]
1592+
end
1593+
1594+
def test
1595+
results = []
1596+
1597+
$order = []
1598+
foo
1599+
results << $order.dup
1600+
1601+
$order = []
1602+
foo(a: "A")
1603+
results << $order.dup
1604+
1605+
$order = []
1606+
foo(b: "B")
1607+
results << $order.dup
1608+
1609+
$order = []
1610+
foo(c: "C")
1611+
results << $order.dup
1612+
1613+
results
1614+
end
1615+
1616+
test
1617+
test
1618+
}, call_threshold: 2
1619+
end
1620+
1621+
def test_send_with_too_many_non_constant_keyword_defaults
1622+
assert_compiles '35', %q{
1623+
def many_kwargs( k1: 1, k2: 2, k3: 3, k4: 4, k5: 5, k6: 6, k7: 7, k8: 8, k9: 9, k10: 10, k11: 11, k12: 12, k13: 13, k14: 14, k15: 15, k16: 16, k17: 17, k18: 18, k19: 19, k20: 20, k21: 21, k22: 22, k23: 23, k24: 24, k25: 25, k26: 26, k27: 27, k28: 28, k29: 29, k30: 30, k31: 31, k32: 32, k33: 33, k34: k33 + 1) = k1 + k34
1624+
def t = many_kwargs
1625+
t
1626+
t
1627+
}, call_threshold: 2
1628+
end
1629+
13911630
def test_invokebuiltin
13921631
# Not using assert_compiles due to register spill
13931632
assert_runs '["."]', %q{

zjit/src/codegen.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio
401401
&Insn::Send { cd, blockiseq, state, reason, .. } => gen_send(jit, asm, cd, blockiseq, &function.frame_state(state), reason),
402402
&Insn::SendForward { cd, blockiseq, state, reason, .. } => gen_send_forward(jit, asm, cd, blockiseq, &function.frame_state(state), reason),
403403
&Insn::SendWithoutBlock { cd, state, reason, .. } => gen_send_without_block(jit, asm, cd, &function.frame_state(state), reason),
404-
Insn::SendWithoutBlockDirect { cme, iseq, recv, args, state, .. } => gen_send_iseq_direct(cb, jit, asm, *cme, *iseq, opnd!(recv), opnds!(args), &function.frame_state(*state), None),
404+
Insn::SendWithoutBlockDirect { cme, iseq, recv, args, kw_bits, state, .. } => gen_send_iseq_direct(cb, jit, asm, *cme, *iseq, opnd!(recv), opnds!(args), *kw_bits, &function.frame_state(*state), None),
405405
&Insn::InvokeSuper { cd, blockiseq, state, reason, .. } => gen_invokesuper(jit, asm, cd, blockiseq, &function.frame_state(state), reason),
406406
&Insn::InvokeBlock { cd, state, reason, .. } => gen_invokeblock(jit, asm, cd, &function.frame_state(state), reason),
407407
Insn::InvokeProc { recv, args, state, kw_splat } => gen_invokeproc(jit, asm, opnd!(recv), opnds!(args), *kw_splat, &function.frame_state(*state)),
@@ -1358,6 +1358,7 @@ fn gen_send_iseq_direct(
13581358
iseq: IseqPtr,
13591359
recv: Opnd,
13601360
args: Vec<Opnd>,
1361+
kw_bits: u32,
13611362
state: &FrameState,
13621363
block_handler: Option<Opnd>,
13631364
) -> lir::Opnd {
@@ -1404,12 +1405,13 @@ fn gen_send_iseq_direct(
14041405
// Write "keyword_bits" to the callee's frame if the callee accepts keywords.
14051406
// This is a synthetic local/parameter that the callee reads via checkkeyword to determine
14061407
// which optional keyword arguments need their defaults evaluated.
1408+
// We write this to the local table slot at bits_start so that:
1409+
// 1. The interpreter can read it via checkkeyword if we side-exit
1410+
// 2. The JIT entry can read it via GetLocal
14071411
if unsafe { rb_get_iseq_flags_has_kw(iseq) } {
14081412
let keyword = unsafe { rb_get_iseq_body_param_keyword(iseq) };
14091413
let bits_start = unsafe { (*keyword).bits_start } as usize;
1410-
// Currently we only support required keywords, so all bits are 0 (all keywords specified).
1411-
// TODO: When supporting optional keywords, calculate actual unspecified_bits here.
1412-
let unspecified_bits = VALUE::fixnum_from_usize(0);
1414+
let unspecified_bits = VALUE::fixnum_from_usize(kw_bits as usize);
14131415
let bits_offset = (state.stack().len() - args.len() + bits_start) * SIZEOF_VALUE;
14141416
asm_comment!(asm, "write keyword bits to callee frame");
14151417
asm.store(Opnd::mem(64, SP, bits_offset as i32), unspecified_bits.into());
@@ -1435,10 +1437,11 @@ fn gen_send_iseq_direct(
14351437
let lead_num = params.lead_num as u32;
14361438
let opt_num = params.opt_num as u32;
14371439
let keyword = params.keyword;
1438-
let kw_req_num = if keyword.is_null() { 0 } else { unsafe { (*keyword).required_num } } as u32;
1439-
let req_num = lead_num + kw_req_num;
1440-
assert!(args.len() as u32 <= req_num + opt_num);
1441-
let num_optionals_passed = args.len() as u32 - req_num;
1440+
let kw_total_num = if keyword.is_null() { 0 } else { unsafe { (*keyword).num } } as u32;
1441+
assert!(args.len() as u32 <= lead_num + opt_num + kw_total_num);
1442+
// For computing optional positional entry point, only count positional args
1443+
let positional_argc = args.len() as u32 - kw_total_num;
1444+
let num_optionals_passed = positional_argc.saturating_sub(lead_num);
14421445
num_optionals_passed
14431446
} else {
14441447
0

0 commit comments

Comments
 (0)