Skip to content

Commit 0ab032a

Browse files
authored
optimize: revise inlining costs (JuliaLang#51599)
Add a bonus for Intrinsics called with mostly constant arguments. We know that simple expressions like `x*1 + 0` will be optimized later by LLVM, and will likely also fold into surrounding expressions, so try to reflect that in the inlining cost estimate, which is computed before LLVM runs. Additionally, rebalance some of the other costs to more accurately reflect what the operations cost in assembly.
1 parent f919e8f commit 0ab032a

File tree

3 files changed

+62
-34
lines changed

3 files changed

+62
-34
lines changed

base/compiler/optimize.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ function stmt_effect_flags(𝕃ₒ::AbstractLattice, @nospecialize(stmt), @nospe
309309
nothrow = _builtin_nothrow(𝕃ₒ, f, argtypes, rt)
310310
return (true, nothrow, nothrow)
311311
end
312-
if f === Intrinsics.cglobal
312+
if f === Intrinsics.cglobal || f === Intrinsics.llvmcall
313313
# TODO: these are not yet linearized
314314
return (false, false, false)
315315
end
@@ -1031,11 +1031,36 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
10311031
f = singleton_type(ftyp)
10321032
if isa(f, IntrinsicFunction)
10331033
iidx = Int(reinterpret(Int32, f::IntrinsicFunction)) + 1
1034-
if !isassigned(T_IFUNC_COST, iidx)
1035-
# unknown/unhandled intrinsic
1036-
return params.inline_nonleaf_penalty
1034+
if isassigned(T_IFUNC, iidx)
1035+
minarg, maxarg, = T_IFUNC[iidx]
1036+
nargs = length(ex.args)
1037+
if minarg + 1 <= nargs <= maxarg + 1
1038+
# With mostly constant arguments, all Intrinsics tend to become very cheap
1039+
# and are likely to combine with the operations around them,
1040+
# so reduce their cost by roughly half: `(cost - 1) ÷ 2`.
1041+
cost = T_IFUNC_COST[iidx]
1042+
if cost == 0 || nargs < 3 ||
1043+
(f === Intrinsics.cglobal || f === Intrinsics.llvmcall) # these hold malformed IR, so argextype will crash on them
1044+
return cost
1045+
end
1046+
aty2 = widenconditional(argextype(ex.args[2], src, sptypes))
1047+
nconst = Int(aty2 isa Const)
1048+
for i = 3:nargs
1049+
aty = widenconditional(argextype(ex.args[i], src, sptypes))
1050+
if widenconst(aty) != widenconst(aty2)
1051+
nconst = 0
1052+
break
1053+
end
1054+
nconst += aty isa Const
1055+
end
1056+
if nconst + 2 >= nargs
1057+
cost = (cost - 1) ÷ 2
1058+
end
1059+
return cost
1060+
end
10371061
end
1038-
return T_IFUNC_COST[iidx]
1062+
# unknown/unhandled intrinsic, or a call whose argument count is outside the registered range
1063+
return params.inline_nonleaf_penalty
10391064
end
10401065
if isa(f, Builtin) && f !== invoke
10411066
# The efficiency of operations like a[i] and s.b
@@ -1046,9 +1071,12 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
10461071
# tuple iteration/destructuring makes that impossible
10471072
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
10481073
return 0
1049-
elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3
1074+
elseif (f === Core.arrayref || f === Core.const_arrayref) && length(ex.args) >= 3
10501075
atyp = argextype(ex.args[3], src, sptypes)
10511076
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
1077+
elseif f === Core.arrayset && length(ex.args) >= 3
1078+
atyp = argextype(ex.args[2], src, sptypes)
1079+
return isknowntype(atyp) ? 8 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
10521080
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
10531081
return 1
10541082
end

base/compiler/tfuncs.jl

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ end
153153
@nospecs conversion_tfunc(𝕃::AbstractLattice, t, x) = conversion_tfunc(widenlattice(𝕃), t, x)
154154
@nospecs conversion_tfunc(::JLTypeLattice, t, x) = instanceof_tfunc(t, true)[1]
155155

156-
add_tfunc(bitcast, 2, 2, bitcast_tfunc, 1)
157-
add_tfunc(sext_int, 2, 2, conversion_tfunc, 1)
158-
add_tfunc(zext_int, 2, 2, conversion_tfunc, 1)
159-
add_tfunc(trunc_int, 2, 2, conversion_tfunc, 1)
156+
add_tfunc(bitcast, 2, 2, bitcast_tfunc, 0)
157+
add_tfunc(sext_int, 2, 2, conversion_tfunc, 0)
158+
add_tfunc(zext_int, 2, 2, conversion_tfunc, 0)
159+
add_tfunc(trunc_int, 2, 2, conversion_tfunc, 0)
160160
add_tfunc(fptoui, 2, 2, conversion_tfunc, 1)
161161
add_tfunc(fptosi, 2, 2, conversion_tfunc, 1)
162162
add_tfunc(uitofp, 2, 2, conversion_tfunc, 1)
@@ -170,30 +170,30 @@ add_tfunc(fpext, 2, 2, conversion_tfunc, 1)
170170
@nospecs math_tfunc(𝕃::AbstractLattice, args...) = math_tfunc(widenlattice(𝕃), args...)
171171
@nospecs math_tfunc(::JLTypeLattice, x, xs...) = widenconst(x)
172172

173-
add_tfunc(neg_int, 1, 1, math_tfunc, 1)
173+
add_tfunc(neg_int, 1, 1, math_tfunc, 0)
174174
add_tfunc(add_int, 2, 2, math_tfunc, 1)
175175
add_tfunc(sub_int, 2, 2, math_tfunc, 1)
176-
add_tfunc(mul_int, 2, 2, math_tfunc, 4)
177-
add_tfunc(sdiv_int, 2, 2, math_tfunc, 30)
178-
add_tfunc(udiv_int, 2, 2, math_tfunc, 30)
179-
add_tfunc(srem_int, 2, 2, math_tfunc, 30)
180-
add_tfunc(urem_int, 2, 2, math_tfunc, 30)
176+
add_tfunc(mul_int, 2, 2, math_tfunc, 3)
177+
add_tfunc(sdiv_int, 2, 2, math_tfunc, 20)
178+
add_tfunc(udiv_int, 2, 2, math_tfunc, 20)
179+
add_tfunc(srem_int, 2, 2, math_tfunc, 20)
180+
add_tfunc(urem_int, 2, 2, math_tfunc, 20)
181181
add_tfunc(add_ptr, 2, 2, math_tfunc, 1)
182182
add_tfunc(sub_ptr, 2, 2, math_tfunc, 1)
183183
add_tfunc(neg_float, 1, 1, math_tfunc, 1)
184-
add_tfunc(add_float, 2, 2, math_tfunc, 1)
185-
add_tfunc(sub_float, 2, 2, math_tfunc, 1)
186-
add_tfunc(mul_float, 2, 2, math_tfunc, 4)
187-
add_tfunc(div_float, 2, 2, math_tfunc, 4)
188-
add_tfunc(fma_float, 3, 3, math_tfunc, 5)
189-
add_tfunc(muladd_float, 3, 3, math_tfunc, 5)
184+
add_tfunc(add_float, 2, 2, math_tfunc, 2)
185+
add_tfunc(sub_float, 2, 2, math_tfunc, 2)
186+
add_tfunc(mul_float, 2, 2, math_tfunc, 8)
187+
add_tfunc(div_float, 2, 2, math_tfunc, 10)
188+
add_tfunc(fma_float, 3, 3, math_tfunc, 8)
189+
add_tfunc(muladd_float, 3, 3, math_tfunc, 8)
190190

191191
# fast arithmetic
192192
add_tfunc(neg_float_fast, 1, 1, math_tfunc, 1)
193-
add_tfunc(add_float_fast, 2, 2, math_tfunc, 1)
194-
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 1)
195-
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 2)
196-
add_tfunc(div_float_fast, 2, 2, math_tfunc, 2)
193+
add_tfunc(add_float_fast, 2, 2, math_tfunc, 2)
194+
add_tfunc(sub_float_fast, 2, 2, math_tfunc, 2)
195+
add_tfunc(mul_float_fast, 2, 2, math_tfunc, 8)
196+
add_tfunc(div_float_fast, 2, 2, math_tfunc, 10)
197197

198198
# bitwise operators
199199
# -----------------
@@ -280,12 +280,12 @@ add_tfunc(le_float_fast, 2, 2, cmp_tfunc, 1)
280280
@nospecs chk_tfunc(𝕃::AbstractLattice, x, y) = chk_tfunc(widenlattice(𝕃), x, y)
281281
@nospecs chk_tfunc(::JLTypeLattice, x, y) = Tuple{widenconst(x), Bool}
282282

283-
add_tfunc(checked_sadd_int, 2, 2, chk_tfunc, 10)
284-
add_tfunc(checked_uadd_int, 2, 2, chk_tfunc, 10)
285-
add_tfunc(checked_ssub_int, 2, 2, chk_tfunc, 10)
286-
add_tfunc(checked_usub_int, 2, 2, chk_tfunc, 10)
287-
add_tfunc(checked_smul_int, 2, 2, chk_tfunc, 10)
288-
add_tfunc(checked_umul_int, 2, 2, chk_tfunc, 10)
283+
add_tfunc(checked_sadd_int, 2, 2, chk_tfunc, 2)
284+
add_tfunc(checked_uadd_int, 2, 2, chk_tfunc, 2)
285+
add_tfunc(checked_ssub_int, 2, 2, chk_tfunc, 2)
286+
add_tfunc(checked_usub_int, 2, 2, chk_tfunc, 2)
287+
add_tfunc(checked_smul_int, 2, 2, chk_tfunc, 5)
288+
add_tfunc(checked_umul_int, 2, 2, chk_tfunc, 5)
289289

290290
# other, misc
291291
# -----------

doc/src/devdocs/inference.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ Each statement gets analyzed for its total cost in a function called
9696
as follows:
9797
```jldoctest; filter=r"tuple.jl:\d+"
9898
julia> Base.print_statement_costs(stdout, map, (typeof(sqrt), Tuple{Int},)) # map(sqrt, (2,))
99-
map(f, t::Tuple{Any}) @ Base tuple.jl:291
99+
map(f, t::Tuple{Any}) @ Base tuple.jl:281
100100
0 1 ─ %1 = $(Expr(:boundscheck, true))::Bool
101101
0 │ %2 = Base.getfield(_3, 1, %1)::Int64
102102
1 │ %3 = Base.sitofp(Float64, %2)::Float64
103-
2 │ %4 = Base.lt_float(%3, 0.0)::Bool
103+
0 │ %4 = Base.lt_float(%3, 0.0)::Bool
104104
0 └── goto #3 if not %4
105105
0 2 ─ invoke Base.Math.throw_complex_domainerror(:sqrt::Symbol, %3::Float64)::Union{}
106106
0 └── unreachable

0 commit comments

Comments
 (0)