Skip to content
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "StaticCompiler"
uuid = "81625895-6c0f-48fc-b932-11a18313743c"
authors = ["Tom Short and contributors"]
version = "0.7.2"
version = "0.7.3"


[deps]
Expand All @@ -18,11 +18,11 @@ StaticTools = "86c06d3c-3f03-46de-9781-57580aa96d0a"

[compat]
CodeInfoTools = "0.3"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 1.5, 1"
LLVM = "6, 7, 8, 9"
MacroTools = "0.5"
StaticTools = "0.8"
julia = "1.8, 1.9"
julia = "1.8, 1.9, 1.10, 1.11"

[extras]
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
Expand Down
33 changes: 25 additions & 8 deletions src/StaticCompiler.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#__precompile__(false)

module StaticCompiler
using InteractiveUtils
using GPUCompiler: GPUCompiler
Expand All @@ -19,10 +21,10 @@ export static_code_llvm, static_code_typed, static_llvm_module, static_code_nati
export @device_override, @print_and_throw
export StaticTarget

include("quirks.jl")
include("interpreter.jl")
include("target.jl")
include("pointer_warning.jl")
include("quirks.jl")
include("dllexport.jl")

"""
    fix_name(f::Function)

Take the name of `f` as a string and forward it to the string method of `fix_name`.
"""
function fix_name(f::Function)
    fname = string(nameof(f))
    return fix_name(fname)
end
Expand Down Expand Up @@ -450,9 +452,9 @@ function static_llvm_module(f, tt, name=fix_name(f); demangle=true, target::Stat
if !demangle
name = "julia_"*name
end
job, kwargs = static_job(f, tt; name, target, kwargs...)
job, kwargs = static_job(f, tt; name, target, strip=true, only_entry=false, validate=false, libraries=false, kwargs...)
m = GPUCompiler.JuliaContext() do context
m, _ = GPUCompiler.codegen(:llvm, job; strip=true, only_entry=false, validate=false, libraries=false)
m, _ = GPUCompiler.compile(:llvm, job; kwargs...)
locate_pointers_and_runtime_calls(m)
m
end
Expand All @@ -467,17 +469,17 @@ function static_llvm_module(funcs::Union{Array,Tuple}; demangle=true, target::St
if !demangle
name_f = "julia_"*name_f
end
job, kwargs = static_job(f, tt; name = name_f, target, kwargs...)
mod,_ = GPUCompiler.codegen(:llvm, job; strip=true, only_entry=false, validate=false, libraries=false)
job, kwargs = static_job(f, tt; name = name_f, target, strip=true, only_entry=false, validate=false, libraries=false, kwargs...)
mod,_ = GPUCompiler.compile(:llvm, job; kwargs...)
if length(funcs) > 1
for func in funcs[2:end]
f,tt = func
name_f = fix_name(f)
if !demangle
name_f = "julia_"*name_f
end
job, kwargs = static_job(f, tt; name = name_f, target, kwargs...)
tmod,_ = GPUCompiler.codegen(:llvm, job; strip=true, only_entry=false, validate=false, libraries=false)
job, kwargs = static_job(f, tt; name = name_f, target, strip=true, only_entry=false, validate=false, libraries=false, kwargs...)
tmod,_ = GPUCompiler.compile(:llvm, job; kwargs...)
link!(mod,tmod)
end
end
Expand All @@ -493,10 +495,17 @@ function static_llvm_module(funcs::Union{Array,Tuple}; demangle=true, target::St
name!(modfunc,fname[d:end])
end
end
@dispose pb = NewPMPassBuilder(merge_functions=true) begin
add!(pb, NewPMModulePassManager()) do pass_manager
run!(pb, mod)
end
end
#=
LLVM.ModulePassManager() do pass_manager #remove duplicate functions
LLVM.merge_functions!(pass_manager)
LLVM.run!(pass_manager, mod)
end
=#
return mod
end

Expand Down Expand Up @@ -587,8 +596,16 @@ function generate_obj(funcs::Union{Array,Tuple}, path::String = tempname(), file
obj_path = joinpath(path, "$filenamebase.o")
obj = GPUCompiler.JuliaContext() do ctx
fakejob, _ = static_job(f, tt; target, kwargs...)
@static if VERSION < v"1.9"
obj, _ = GPUCompiler.emit_asm(fakejob, mod; strip=strip_asm, validate=false, format=LLVM.API.LLVMObjectFile)
obj
else
@static if pkgversion(GPUCompiler) < v"1.3.0"
obj, _ = GPUCompiler.emit_asm(fakejob, mod; strip=strip_asm, validate=false, format=LLVM.API.LLVMObjectFile)
else
obj, _ = GPUCompiler.emit_asm(fakejob, mod, LLVM.API.LLVMObjectFile)
end
end
obj
end
open(obj_path, "w") do io
write(io, obj)
Expand Down
100 changes: 56 additions & 44 deletions src/interpreter.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,68 @@
## interpreter

using Core.Compiler:
AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, MethodInstance, OptimizationParams, WorldView, get_world_counter
AbstractInterpreter, InferenceResult, InferenceParams, InferenceState, MethodInstance, OptimizationParams, WorldView
using GPUCompiler:
@safe_debug, AbstractCompilerParams, CodeCache, CompilerJob, methodinstance
@safe_debug, AbstractCompilerParams, CompilerJob, methodinstance, CodeInstance, inference_params, optimization_params, get_inference_world
using CodeInfoTools
using CodeInfoTools: resolve


# Mirror GPUCompiler's cache-mode flag so the definitions below can branch on it.
const HAS_INTEGRATED_CACHE = GPUCompiler.HAS_INTEGRATED_CACHE

# `CodeCache` is used as a type annotation further down, so it must name something
# in both modes: GPUCompiler's own cache type when that is in use, otherwise a
# `Nothing` placeholder.
@static if !HAS_INTEGRATED_CACHE
    using GPUCompiler: CodeCache
else
    const CodeCache = Nothing
end

# Modeled on GPUCompiler's `GPUInterpreter <: CC.AbstractInterpreter` (src/jlgen.jl).
"""
    StaticInterpreter(world, mt, token_or_cache, ip, op)

Abstract interpreter used for static compilation.

# Arguments
- `world::UInt`: world age to work in; must not exceed `Base.get_world_counter()`.
- `mt::Union{Nothing,Core.MethodTable}`: method table stored in `method_table`.
- `token_or_cache`: stored in the `token` field when `HAS_INTEGRATED_CACHE` is true,
  otherwise in the `code_cache` field (and must then be a `CodeCache`).
- `ip::InferenceParams`, `op::OptimizationParams`: inference/optimization parameters.
"""
struct StaticInterpreter <: AbstractInterpreter
    # The world age we're working inside of
    world::UInt
    method_table::Union{Nothing,Core.MethodTable}

    # Exactly one of these two fields exists, selected when the struct is defined.
    @static if HAS_INTEGRATED_CACHE
        token::Any
    else
        code_cache::CodeCache # global cache
    end

    # Cache of inference results for this particular interpreter
    local_cache::Vector{InferenceResult}

    # Parameters for inference and optimization
    inf_params::InferenceParams
    opt_params::OptimizationParams

    # token_or_cache = token::Any, code_cache::CodeCache
    function StaticInterpreter(world::UInt, mt::Union{Nothing,Core.MethodTable}, token_or_cache, ip::InferenceParams, op::OptimizationParams)
        @assert world <= Base.get_world_counter()

        local_cache = Vector{InferenceResult}() # Initially empty cache
        return new(world, mt, token_or_cache, local_cache, ip, op)
    end
end


# Accessor methods for `StaticInterpreter`; each one simply exposes a stored field.
Core.Compiler.InferenceParams(interp::StaticInterpreter) = interp.inf_params
Core.Compiler.OptimizationParams(interp::StaticInterpreter) = interp.opt_params
GPUCompiler.get_inference_world(interp::StaticInterpreter) = interp.world
Core.Compiler.get_inference_cache(interp::StaticInterpreter) = interp.local_cache
# Which cache accessor is defined follows the field that exists on the struct:
# `token` with the integrated cache, `code_cache` (viewed at our world age) otherwise.
@static if HAS_INTEGRATED_CACHE
    Core.Compiler.cache_owner(interp::StaticInterpreter) = interp.token
else
    Core.Compiler.code_cache(interp::StaticInterpreter) = WorldView(interp.code_cache, interp.world)
end

# No need to do any locking since we're not putting our results into the runtime cache
Core.Compiler.lock_mi_inference(interp::StaticInterpreter, mi::MethodInstance) = nothing
Core.Compiler.unlock_mi_inference(interp::StaticInterpreter, mi::MethodInstance) = nothing

# Report an inference remark for the method instance being inferred (`sv.linfo`)
# through `@safe_debug`. The remark is emitted exactly once.
function Core.Compiler.add_remark!(interp::StaticInterpreter, sv::InferenceState, msg)
    @safe_debug "Inference remark during static compilation of $(sv.linfo): $msg"
end


#####
##### Pre-inference
#####
Expand All @@ -77,24 +84,28 @@ function custom_pass!(interp::StaticInterpreter, result::InferenceResult, mi::Co
end

# Build the `InferenceState` for `result` under a `StaticInterpreter`.
#
# The lowered code for `result.linfo` is retrieved (at the interpreter's world on
# Julia versions that take a world argument), run through `custom_pass!`, and then
# validated before inference starts. Returns `nothing` when `custom_pass!` yields
# no source, otherwise the constructed `InferenceState`.
function Core.Compiler.InferenceState(result::InferenceResult, cache::Symbol, interp::StaticInterpreter)
    world = get_inference_world(interp)
    src = @static if VERSION >= v"1.10.0-DEV.873"
        Core.Compiler.retrieve_code_info(result.linfo, world)
    else
        Core.Compiler.retrieve_code_info(result.linfo)
    end
    mi = result.linfo
    src = custom_pass!(interp, result, mi, src)

    # Nothing to infer: bail out *before* validation, which needs real source.
    src === nothing && return nothing

    # The validation entry point has a different name from Julia 1.11 onwards.
    @static if VERSION < v"1.11"
        Core.Compiler.validate_code_in_debug_mode(result.linfo, src, "lowered")
    else
        Core.Compiler.maybe_validate_code(result.linfo, src, "lowered")
    end
    return InferenceState(result, src, cache, interp)
end

# Compiler policy switches for `StaticInterpreter`.
Core.Compiler.may_optimize(interp::StaticInterpreter) = true
Core.Compiler.may_compress(interp::StaticInterpreter) = true
Core.Compiler.may_discard_trees(interp::StaticInterpreter) = true

# Only extend `verbose_stmt_info` when Core.Compiler actually provides it.
# The symbol checked here must match the function extended below.
if isdefined(Core.Compiler, :verbose_stmt_info)
    Core.Compiler.verbose_stmt_info(interp::StaticInterpreter) = false
end

if isdefined(Base.Experimental, Symbol("@overlay"))
using Core.Compiler: OverlayMethodTable
Expand All @@ -112,13 +123,13 @@ end

# semi-concrete interpretation is broken with overlays (JuliaLang/julia#47349)
@static if VERSION >= v"1.9.0-DEV.1248"
    function Core.Compiler.concrete_eval_eligible(
        interp::StaticInterpreter,
        @nospecialize(f),
        result::Core.Compiler.MethodCallResult,
        arginfo::Core.Compiler.ArgInfo
    )
        # Ask the generic `AbstractInterpreter` method for its verdict first.
        eligibility = @invoke Core.Compiler.concrete_eval_eligible(
            interp::AbstractInterpreter,
            f::Any,
            result::Core.Compiler.MethodCallResult,
            arginfo::Core.Compiler.ArgInfo
        )
        # An explicit `false` is turned into `nothing`; anything else passes through.
        return eligibility === false ? nothing : eligibility
    end
end

struct StaticCompilerParams <: AbstractCompilerParams
Expand All @@ -127,8 +138,9 @@ struct StaticCompilerParams <: AbstractCompilerParams
cache::CodeCache
end

"""
    StaticCompilerParams(; opt=false, optlevel=Base.JLOptions().opt_level, cache=CodeCache())

Keyword constructor that forwards `opt`, `optlevel` and `cache`, in that order, to
the positional `StaticCompilerParams` constructor.
"""
function StaticCompilerParams(;
    opt=false,
    optlevel=Base.JLOptions().opt_level,
    cache=CodeCache()
)
    params = StaticCompilerParams(opt, optlevel, cache)
    return params
end
3 changes: 2 additions & 1 deletion src/pointer_warning.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ function locate_pointers_and_runtime_calls(mod)
end
end
if warned
lines = split(string(func),"\n")
@warn("LLVM function generated warnings due to raw pointers embedded in the code. This will likely cause errors or undefined behaviour.",
func = func)
func = join(lines[1:min(20, end)], "\n")) # just print the first 20 lines
end
end
end
Expand Down
44 changes: 39 additions & 5 deletions src/quirks.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,41 @@
# Method tables used for overlay methods. When `Base.Experimental.@overlay` is
# unavailable both names fall back to `nothing` so they are always defined.
# (`libcexit` is defined once, next to `@print_and_throw`, further down this file.)
@static if isdefined(Base.Experimental, Symbol("@overlay"))
    Base.Experimental.@MethodTable(method_table)
    Base.Experimental.@MethodTable(empty_table)
else
    const method_table = nothing
    const empty_table = nothing
end

"""
```julia
@device_override old_bad_method(arg1::Type1, arg2::Type2) = new_good_method(arg1, arg2)
```
Override a non-static-compilable method (e.g. `old_bad_method(::Type1, ::Type2)`)
with a more compileable replacement.
### Examples
```
@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
@print_and_throw c"Inexact conversion"
```
"""
macro device_override(ex)
    # Expand nested macros (e.g. `@noinline`) first so the overlay sees a plain definition.
    ex = macroexpand(__module__, ex)
    # A bare call has no method body to register; reject it with an explanation
    # rather than evaluating it.
    if Meta.isexpr(ex, :call)
        error("@device_override expects a method definition, got the call expression `$ex`")
    end
    # Register the definition in StaticCompiler's overlay method table.
    code = quote
        $Base.Experimental.@overlay($StaticCompiler.method_table, $ex)
    end
    return esc(code)
end

"""
    @print_and_throw err

Expand to code that prints `err` once with `printf` and then calls
`libcexit(Int32(1))`.
"""
macro print_and_throw(err)
    quote
        printf($err)
        libcexit(Int32(1))
    end
end

# Call the C `exit` symbol with status `x` via `@symbolcall`.
libcexit(x::Int32) = @symbolcall exit(x::Int32)::Nothing

# math.jl
@device_override @noinline Base.Math.throw_complex_domainerror(f::Symbol, x) =
Expand Down Expand Up @@ -37,9 +68,12 @@ end
@device_override @noinline Core.throw_inexacterror(f::Symbol, ::Type{T}, val) where {T} =
    @print_and_throw c"Inexact conversion"

# abstractarray.jl
# Base.throw_boundserror is removed since v1.11, so the override is only
# installed on older Julia versions.
if VERSION < v"1.11"
    @device_override @noinline Base.throw_boundserror(A, I) =
        @print_and_throw c"Out-of-bounds array access"
end

# trig.jl
@device_override @noinline Base.Math.sincos_domain_error(x) =
Expand Down
Loading
Loading