Skip to content

Commit fa11ca2

Browse files
authored
Sink CodeInfo transformation into transform_result_for_cache, continued (JuliaLang#57375)
Continuation of the work started in JuliaLang#56897. Co-authored-by: Cédric Belmant <[email protected]>
1 parent 256680a commit fa11ca2

File tree

8 files changed

+152
-75
lines changed

8 files changed

+152
-75
lines changed

extras/CompilerDevTools/src/CompilerDevTools.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747

4848
function Compiler.optimize(interp::SplitCacheInterp, opt::Compiler.OptimizationState, caller::Compiler.InferenceResult)
4949
@invoke Compiler.optimize(interp::Compiler.AbstractInterpreter, opt::Compiler.OptimizationState, caller::Compiler.InferenceResult)
50-
ir = opt.ir::Compiler.IRCode
50+
ir = opt.result.ir::Compiler.IRCode
5151
override = GlobalRef(@__MODULE__(), :with_new_compiler)
5252
for inst in ir.stmts
5353
stmt = inst[:stmt]

src/optimize.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ is_declared_noinline(@nospecialize src::MaybeCompressed) =
129129
# return whether this src should be inlined. If so, retrieve_ir_for_inlining must return an IRCode from it
130130
function src_inlining_policy(interp::AbstractInterpreter,
131131
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt32)
132+
isa(src, OptimizationState) && (src = src.src)
132133
if isa(src, MaybeCompressed)
133134
src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src)
134135
return src_inlineable
@@ -154,10 +155,20 @@ end
154155
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
155156
code_cache(state::InliningState) = WorldView(code_cache(state.interp), state.world)
156157

158+
mutable struct OptimizationResult
159+
ir::IRCode
160+
simplified::Bool # indicates whether the IR was processed with `cfg_simplify!`
161+
end
162+
163+
function simplify_ir!(result::OptimizationResult)
164+
result.ir = cfg_simplify!(result.ir)
165+
result.simplified = true
166+
end
167+
157168
mutable struct OptimizationState{Interp<:AbstractInterpreter}
158169
linfo::MethodInstance
159170
src::CodeInfo
160-
ir::Union{Nothing, IRCode}
171+
result::Union{Nothing, OptimizationResult}
161172
stmt_info::Vector{CallInfo}
162173
mod::Module
163174
sptypes::Vector{VarState}
@@ -226,10 +237,13 @@ include("ssair/passes.jl")
226237
include("ssair/irinterp.jl")
227238

228239
function ir_to_codeinf!(opt::OptimizationState)
229-
(; linfo, src) = opt
230-
src = ir_to_codeinf!(src, opt.ir::IRCode)
231-
src.edges = Core.svec(opt.inlining.edges...)
232-
opt.ir = nothing
240+
(; linfo, src, result) = opt
241+
if result === nothing
242+
return src
243+
end
244+
src = ir_to_codeinf!(src, result.ir)
245+
opt.result = nothing
246+
opt.src = src
233247
maybe_validate_code(linfo, src, "optimized")
234248
return src
235249
end
@@ -489,7 +503,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
489503
@assert !(result isa LimitedAccuracy)
490504
result = widenslotwrapper(result)
491505

492-
opt.ir = ir
506+
opt.result = OptimizationResult(ir, false)
493507

494508
# determine and cache inlineability
495509
if !force_noinline

src/ssair/inlining.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -975,6 +975,14 @@ function retrieve_ir_for_inlining(mi::MethodInstance, ir::IRCode, preserve_local
975975
ir.debuginfo.def = mi
976976
return ir, spec_info, DebugInfo(ir.debuginfo, length(ir.stmts))
977977
end
978+
function retrieve_ir_for_inlining(mi::MethodInstance, opt::OptimizationState, preserve_local_sources::Bool)
979+
result = opt.result
980+
if result !== nothing
981+
!result.simplified && simplify_ir!(result)
982+
return retrieve_ir_for_inlining(mi, result.ir, preserve_local_sources)
983+
end
984+
retrieve_ir_for_inlining(mi, opt.src, preserve_local_sources)
985+
end
978986

979987
function handle_single_case!(todo::Vector{Pair{Int,Any}},
980988
ir::IRCode, idx::Int, stmt::Expr, @nospecialize(case),

src/ssair/ir.jl

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,39 +1265,48 @@ function process_phinode_values(old_values::Vector{Any}, late_fixup::Vector{Int}
12651265
values = Vector{Any}(undef, length(old_values))
12661266
for i = 1:length(old_values)
12671267
isassigned(old_values, i) || continue
1268-
val = old_values[i]
1269-
if isa(val, SSAValue)
1270-
if do_rename_ssa
1271-
if !already_inserted(i, OldSSAValue(val.id))
1272-
push!(late_fixup, result_idx)
1273-
val = OldSSAValue(val.id)
1274-
else
1275-
val = renumber_ssa2(val, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa, mark_refined!)
1276-
end
1277-
else
1278-
used_ssas[val.id] += 1
1279-
end
1280-
elseif isa(val, OldSSAValue)
1281-
if !already_inserted(i, val)
1268+
values[i] = process_phinode_value(old_values, i, late_fixup, already_inserted, result_idx, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa, mark_refined!)
1269+
end
1270+
return values
1271+
end
1272+
1273+
function process_phinode_value(old_values::Vector{Any}, i::Int, late_fixup::Vector{Int},
1274+
already_inserted, result_idx::Int,
1275+
ssa_rename::Vector{Any}, used_ssas::Vector{Int},
1276+
new_new_used_ssas::Vector{Int},
1277+
do_rename_ssa::Bool,
1278+
mark_refined!::Union{Refiner, Nothing})
1279+
val = old_values[i]
1280+
if isa(val, SSAValue)
1281+
if do_rename_ssa
1282+
if !already_inserted(i, OldSSAValue(val.id))
12821283
push!(late_fixup, result_idx)
1284+
val = OldSSAValue(val.id)
12831285
else
1284-
# Always renumber these. do_rename_ssa applies only to actual SSAValues
1285-
val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, new_new_used_ssas, true, mark_refined!)
1286-
end
1287-
elseif isa(val, NewSSAValue)
1288-
if val.id < 0
1289-
new_new_used_ssas[-val.id] += 1
1290-
else
1291-
@assert do_rename_ssa
1292-
val = SSAValue(val.id)
1286+
val = renumber_ssa2(val, ssa_rename, used_ssas, new_new_used_ssas, do_rename_ssa, mark_refined!)
12931287
end
1288+
else
1289+
used_ssas[val.id] += 1
12941290
end
1295-
if isa(val, NewSSAValue)
1291+
elseif isa(val, OldSSAValue)
1292+
if !already_inserted(i, val)
12961293
push!(late_fixup, result_idx)
1294+
else
1295+
# Always renumber these. do_rename_ssa applies only to actual SSAValues
1296+
val = renumber_ssa2(SSAValue(val.id), ssa_rename, used_ssas, new_new_used_ssas, true, mark_refined!)
1297+
end
1298+
elseif isa(val, NewSSAValue)
1299+
if val.id < 0
1300+
new_new_used_ssas[-val.id] += 1
1301+
else
1302+
@assert do_rename_ssa
1303+
val = SSAValue(val.id)
12971304
end
1298-
values[i] = val
12991305
end
1300-
return values
1306+
if isa(val, NewSSAValue)
1307+
push!(late_fixup, result_idx)
1308+
end
1309+
return val
13011310
end
13021311

13031312
function renumber_ssa2(val::SSAValue, ssanums::Vector{Any}, used_ssas::Vector{Int},

src/ssair/passes.jl

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2574,51 +2574,50 @@ function cfg_simplify!(ir::IRCode)
25742574
values = phi.values
25752575
(; ssa_rename, late_fixup, used_ssas, new_new_used_ssas) = compact
25762576
ssa_rename[i] = SSAValue(compact.result_idx)
2577-
already_inserted = function (i::Int, val::OldSSAValue)
2577+
already_inserted = function (branch::Int, val::OldSSAValue)
25782578
if val.id in old_bb_stmts
25792579
return val.id <= i
25802580
end
2581-
return bb_rename_pred[phi.edges[i]] < idx
2581+
return 0 < bb_rename_pred[phi.edges[branch]] < idx
25822582
end
2583-
renamed_values = process_phinode_values(values, late_fixup, already_inserted, compact.result_idx, ssa_rename, used_ssas, new_new_used_ssas, true, nothing)
25842583
edges = Int32[]
25852584
values = Any[]
2586-
sizehint!(edges, length(phi.edges)); sizehint!(values, length(renamed_values))
2585+
sizehint!(edges, length(phi.edges)); sizehint!(values, length(phi.values))
25872586
for old_index in 1:length(phi.edges)
25882587
old_edge = phi.edges[old_index]
25892588
new_edge = bb_rename_pred[old_edge]
25902589
if new_edge > 0
25912590
push!(edges, new_edge)
2592-
if isassigned(renamed_values, old_index)
2593-
push!(values, renamed_values[old_index])
2591+
if isassigned(phi.values, old_index)
2592+
val = process_phinode_value(phi.values, old_index, late_fixup, already_inserted, compact.result_idx, ssa_rename, used_ssas, new_new_used_ssas, true, nothing)
2593+
push!(values, val)
25942594
else
25952595
resize!(values, length(values)+1)
25962596
end
25972597
elseif new_edge == -1
25982598
@assert length(phi.edges) == 1
2599-
if isassigned(renamed_values, old_index)
2599+
if isassigned(phi.values, old_index)
2600+
val = process_phinode_value(phi.values, old_index, late_fixup, already_inserted, compact.result_idx, ssa_rename, used_ssas, new_new_used_ssas, true, nothing)
26002601
push!(edges, -1)
2601-
push!(values, renamed_values[old_index])
2602+
push!(values, val)
26022603
end
26032604
elseif new_edge == -3
26042605
# Multiple predecessors, we need to expand out this phi
26052606
all_new_preds = Int32[]
26062607
add_preds!(all_new_preds, bbs, bb_rename_pred, old_edge)
26072608
append!(edges, all_new_preds)
2608-
if isassigned(renamed_values, old_index)
2609-
val = renamed_values[old_index]
2610-
for _ in 1:length(all_new_preds)
2611-
push!(values, val)
2609+
np = length(all_new_preds)
2610+
if np > 0
2611+
if isassigned(phi.values, old_index)
2612+
val = process_phinode_value(phi.values, old_index, late_fixup, already_inserted, compact.result_idx, ssa_rename, used_ssas, new_new_used_ssas, true, nothing)
2613+
for p in 1:np
2614+
push!(values, val)
2615+
p > 2 && count_added_node!(compact, val)
2616+
end
2617+
else
2618+
resize!(values, length(values)+np)
26122619
end
2613-
length(all_new_preds) == 0 && kill_current_use!(compact, val)
2614-
for _ in 2:length(all_new_preds)
2615-
count_added_node!(compact, val)
2616-
end
2617-
else
2618-
resize!(values, length(values)+length(all_new_preds))
26192620
end
2620-
else
2621-
isassigned(renamed_values, old_index) && kill_current_use!(compact, renamed_values[old_index])
26222621
end
26232622
end
26242623
if length(edges) == 0 || (length(edges) == 1 && !isassigned(values, 1))

src/typeinfer.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -93,19 +93,21 @@ If set to `true`, record per-method-instance timings within type inference in th
9393
__set_measure_typeinf(onoff::Bool) = __measure_typeinf__[] = onoff
9494
const __measure_typeinf__ = RefValue{Bool}(false)
9595

96-
function finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
96+
function result_edges(interp::AbstractInterpreter, caller::InferenceState)
9797
result = caller.result
9898
opt = result.src
99-
if opt isa OptimizationState
100-
src = ir_to_codeinf!(opt)
101-
edges = src.edges::SimpleVector
102-
caller.src = result.src = src
99+
if isa(opt, OptimizationState)
100+
return Core.svec(opt.inlining.edges...)
103101
else
104-
edges = Core.svec(caller.edges...)
105-
caller.src.edges = edges
102+
return Core.svec(caller.edges...)
106103
end
104+
end
105+
106+
function finish!(interp::AbstractInterpreter, caller::InferenceState, validation_world::UInt)
107+
result = caller.result
107108
#@assert last(result.valid_worlds) <= get_world_counter() || isempty(caller.edges)
108109
if isdefined(result, :ci)
110+
edges = result_edges(interp, caller)
109111
ci = result.ci
110112
# if we aren't cached, we don't need this edge
111113
# but our caller might, so let's just make it anyways
@@ -119,9 +121,10 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
119121
const_flag = is_result_constabi_eligible(result)
120122
discard_src = caller.cache_mode === CACHE_MODE_NULL || const_flag
121123
if !discard_src
122-
inferred_result = transform_result_for_cache(interp, result)
124+
inferred_result = transform_result_for_cache(interp, result, edges)
123125
# TODO: do we want to augment edges here with any :invoke targets that we got from inlining (such that we didn't have a direct edge to it already)?
124126
if inferred_result isa CodeInfo
127+
result.src = inferred_result
125128
if may_compress(interp)
126129
nslots = length(inferred_result.slotflags)
127130
resize!(inferred_result.slottypes::Vector{Any}, nslots)
@@ -278,7 +281,16 @@ function is_result_constabi_eligible(result::InferenceResult)
278281
return isa(result_type, Const) && is_foldable_nothrow(result.ipo_effects) && is_inlineable_constant(result_type.val)
279282
end
280283

281-
transform_result_for_cache(::AbstractInterpreter, result::InferenceResult) = result.src
284+
function transform_result_for_cache(::AbstractInterpreter, result::InferenceResult, edges::SimpleVector)
285+
src = result.src
286+
if isa(src, OptimizationState)
287+
src = ir_to_codeinf!(src)
288+
end
289+
if isa(src, CodeInfo)
290+
src.edges = edges
291+
end
292+
return src
293+
end
282294

283295
function maybe_compress_codeinfo(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInfo)
284296
def = mi.def
@@ -1064,6 +1076,7 @@ function typeinf_frame(interp::AbstractInterpreter, mi::MethodInstance, run_opti
10641076
opt = OptimizationState(frame, interp)
10651077
optimize(interp, opt, frame.result)
10661078
src = ir_to_codeinf!(opt)
1079+
src.edges = Core.svec(opt.inlining.edges...)
10671080
end
10681081
result.src = frame.src = src
10691082
end

test/AbstractInterpreter.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ Compiler.may_optimize(::AbsIntOnlyInterp1) = false
1414
# it should work even if the interpreter discards inferred source entirely
1515
@newinterp AbsIntOnlyInterp2
1616
Compiler.may_optimize(::AbsIntOnlyInterp2) = false
17-
Compiler.transform_result_for_cache(::AbsIntOnlyInterp2, ::Compiler.InferenceResult) = nothing
17+
Compiler.transform_result_for_cache(::AbsIntOnlyInterp2, ::Compiler.InferenceResult, edges::Core.SimpleVector) = nothing
1818
@test Base.infer_return_type(Base.init_stdio, (Ptr{Cvoid},); interp=AbsIntOnlyInterp2()) >: IO
1919

2020
# OverlayMethodTable
@@ -493,9 +493,9 @@ struct CustomData
493493
inferred
494494
CustomData(@nospecialize inferred) = new(inferred)
495495
end
496-
function Compiler.transform_result_for_cache(interp::CustomDataInterp, result::Compiler.InferenceResult)
496+
function Compiler.transform_result_for_cache(interp::CustomDataInterp, result::Compiler.InferenceResult, edges::Core.SimpleVector)
497497
inferred_result = @invoke Compiler.transform_result_for_cache(
498-
interp::Compiler.AbstractInterpreter, result::Compiler.InferenceResult)
498+
interp::Compiler.AbstractInterpreter, result::Compiler.InferenceResult, edges::Core.SimpleVector)
499499
return CustomData(inferred_result)
500500
end
501501
function Compiler.src_inlining_policy(interp::CustomDataInterp, @nospecialize(src),

0 commit comments

Comments
 (0)