Support for heterogeneous computing is currently being worked on.
[Adapt.jl](https://github.com/JuliaGPU/Adapt.jl) is a package in the
[JuliaGPU](https://github.com/JuliaGPU) family that allows for
the translation of nested data structures. The primary goal is to allow the substitution of `Array`
at the storage level with a GPU array like `CuArray` from [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl).
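
For background, here is a minimal, self-contained sketch of this translation mechanism. The `Nested` type is a hypothetical stand-in for Trixi.jl's actual containers:

```julia
using Adapt

struct Nested{A}
    data::A
end

# Let `adapt` rebuild `Nested` around its translated field
Adapt.@adapt_structure Nested

n = Nested(Nested([1.0, 2.0, 3.0]))
adapt(Array{Float32}, n)  # both layers are rebuilt; the leaf becomes a `Vector{Float32}`
```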

To facilitate this substitution, data structures must be parameterized, so instead of:

```julia
struct Container <: Trixi.AbstractContainer
    data::Vector{Float64}
end
```

[...]

```julia
Container{CuArray{Float32, 1, CUDA.DeviceMemory}}(Float32[0.0, 0.0, 0.0])
```

!!! note
    `adapt(Array{Float32}, C)` is tempting but will do the wrong thing in the presence of `StaticArrays`.

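
To see why (an illustrative snippet, not Trixi.jl code): a static array is itself an `AbstractArray`, so a bare `adapt` call with an `Array` target converts the static array itself into heap storage:

```julia
using Adapt, StaticArrays

v = SVector(1.0, 2.0)
adapt(Array{Float32}, v)  # Float32[1.0, 2.0]: a plain `Vector`, no longer an `SVector`
```
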
## Writing GPU kernels

Offloading computations to the GPU is done with
[KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl),
allowing for vendor-agnostic GPU code.
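
As a self-contained illustration of the programming model: the kernel and wrapper below are made up for this sketch and are not part of Trixi.jl, but they run unchanged on the CPU and, given device arrays, on GPUs:

```julia
using KernelAbstractions

# A vendor-agnostic kernel: one work-item per array element
@kernel function scale_kernel!(y, @Const(x), a)
    i = @index(Global)
    @inbounds y[i] = a * x[i]
end

function scale!(y, x, a)
    backend = get_backend(y)  # CPU(), or e.g. CUDABackend() from CUDA.jl
    kernel! = scale_kernel!(backend)
    kernel!(y, x, a; ndrange = length(y))
    KernelAbstractions.synchronize(backend)  # wait for the kernel before using `y`
    return y
end

scale!(zeros(4), ones(4), 2.0)
```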

### Example

Given the following Trixi.jl code, which would typically be called from within `rhs!`:

```julia
function trixi_rhs_fct(mesh, equations, solver, cache, args)
    @threaded for element in eachelement(solver, cache)
        # code
    end
end
```

1. Put the inner code in a new function `rhs_fct_per_element`. Besides the index
   `element`, pass all required fields as arguments, but make sure to `@unpack` them from
   their structs in advance (see the combined sketch after this list).

2. Where `trixi_rhs_fct` is called, get the backend, i.e., the hardware we are currently
   running on, via `trixi_backend(x)`.
   This will, e.g., work with `u_ode`. Internally, `KernelAbstractions.jl`'s `get_backend`
   will be called, i.e., `KernelAbstractions.jl` has to know the type of `x`.

   ```julia
   backend = trixi_backend(u_ode)
   ```

3. Add a new argument `backend` to `trixi_rhs_fct`, used for dispatch.
   When `backend` is `nothing`, the legacy implementation should be used:

   ```julia
   function trixi_rhs_fct(backend::Nothing, mesh, equations, solver, cache, args)
       @unpack unpacked_args = cache
       @threaded for element in eachelement(solver, cache)
           rhs_fct_per_element(element, unpacked_args, args)
       end
   end
   ```

4. When `backend` is a `Backend` (a type defined by `KernelAbstractions.jl`), write a
   `KernelAbstractions.jl` kernel:

   ```julia
   function trixi_rhs_fct(backend::Backend, mesh, equations, solver, cache, args)
       nelements(solver, cache) == 0 && return nothing  # return early when there are no elements
       @unpack unpacked_args = cache
       kernel! = rhs_fct_kernel!(backend)
       kernel!(unpacked_args, args,
               ndrange = nelements(solver, cache))
       return nothing
   end

   @kernel function rhs_fct_kernel!(unpacked_args, args)
       element = @index(Global)
       rhs_fct_per_element(element, unpacked_args, args)
   end
   ```
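
Taken together, the pieces from steps 1-4 are used like this (a sketch reusing the names from above; `rhs_fct_per_element` is the function extracted in step 1):

```julia
# Step 1: the per-element work, free of any backend specifics
function rhs_fct_per_element(element, unpacked_args, args)
    # code
end

# Steps 2-4 at the call site: dispatch on `backend` selects the implementation
backend = trixi_backend(u_ode)  # `nothing` -> legacy path, a KA `Backend` -> kernel
trixi_rhs_fct(backend, mesh, equations, solver, cache, args)
```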