Commit 10aff4b

benegeevchuravy authored and committed

something on KA

1 parent fded2cc

File tree

1 file changed: +64 −3 lines changed


docs/src/heterogeneous.md

Lines changed: 64 additions & 3 deletions
@@ -7,9 +7,9 @@ Support for heterogeneous computing is currently being worked on.
 [Adapt.jl](https://github.com/JuliaGPU/Adapt.jl) is a package in the
 [JuliaGPU](https://github.com/JuliaGPU) family that allows for
 the translation of nested data structures. The primary goal is to allow the substitution of `Array`
-at the storage leaves with a GPU array like `CuArray` from [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl).
+at the storage level with a GPU array like `CuArray` from [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl).
 
-To facilitate this data structures must be parameterized, so instead of:
+To facilitate this, data structures must be parameterized, so instead of:
 
 ```julia
 struct Container <: Trixi.AbstractContainer
@@ -98,4 +98,65 @@ Container{CuArray{Float32, 1, CUDA.DeviceMemory}}(Float32[0.0, 0.0, 0.0])
 ```
 
 !!! note
-    `adapt(Array{Float32}, C)` is tempting but will do the wrong thing in the presence of `StaticArrays`.
+    `adapt(Array{Float32}, C)` is tempting but will do the wrong thing in the presence of `StaticArrays`.
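The parameterization this hunk describes can be sketched outside Trixi.jl. The following is a minimal, hypothetical example; the `DemoContainer` name and its single field are assumptions, not Trixi.jl's actual container:

```julia
using Adapt

# Hypothetical container, parameterized on its storage type so that
# `Array` can later be swapped for a GPU array such as `CuArray`.
struct DemoContainer{A <: AbstractVector}
    data::A
end

# Tell Adapt.jl how to rebuild the struct around adapted storage;
# `adapt` then recurses through nested structures automatically.
Adapt.adapt_structure(to, c::DemoContainer) = DemoContainer(adapt(to, c.data))

c = DemoContainer(zeros(3))
# With CUDA.jl loaded, `adapt(CuArray, c)` would return a
# DemoContainer wrapping a CuVector. On the CPU the recursion is still
# visible when adapting the element type (safe here: no StaticArrays
# at the storage leaves, which is what the note above warns about).
c32 = adapt(Array{Float32}, c)
```

Defining `adapt_structure` for each container is what makes a single `adapt` call convert a whole nested hierarchy.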
+
+## Writing GPU kernels
+
+Offloading computations to the GPU is done with
+[KernelAbstractions.jl](https://github.com/JuliaGPU/KernelAbstractions.jl),
+allowing for vendor-agnostic GPU code.
+
+### Example
+
+Given the following Trixi.jl code, which would typically be called from within `rhs!`:
+
+```julia
+function trixi_rhs_fct(mesh, equations, solver, cache, args)
+    @threaded for element in eachelement(solver, cache)
+        # code
+    end
+end
+```
+
+1. Put the inner code in a new function `rhs_fct_per_element`. Besides the index
+   `element`, pass all required fields as arguments, but make sure to `@unpack` them from
+   their structs in advance.
+
+2. Where `trixi_rhs_fct` is called, get the backend, i.e., the hardware we are currently
+   running on, via `trixi_backend(x)`.
+   This will, e.g., work with `u_ode`. Internally, KernelAbstractions.jl's `get_backend`
+   will be called, i.e., KernelAbstractions.jl has to know the type of `x`.
+
+   ```julia
+   backend = trixi_backend(u_ode)
+   ```
+
+3. Add a new argument `backend` to `trixi_rhs_fct` that is used for dispatch.
+   When `backend` is `nothing`, the legacy implementation should be used:
+
+   ```julia
+   function trixi_rhs_fct(backend::Nothing, mesh, equations, solver, cache, args)
+       @unpack unpacked_args = cache
+       @threaded for element in eachelement(solver, cache)
+           rhs_fct_per_element(element, unpacked_args, args)
+       end
+   end
+   ```
+
+4. When `backend` is a `Backend` (a type defined by KernelAbstractions.jl), write a
+   KernelAbstractions.jl kernel:
+
+   ```julia
+   function trixi_rhs_fct(backend::Backend, mesh, equations, solver, cache, args)
+       nelements(solver, cache) == 0 && return nothing # return early when there are no elements
+       @unpack unpacked_args = cache
+       kernel! = rhs_fct_kernel!(backend)
+       kernel!(unpacked_args, args,
+               ndrange = nelements(solver, cache))
+       return nothing
+   end
+
+   @kernel function rhs_fct_kernel!(unpacked_args, args)
+       element = @index(Global)
+       rhs_fct_per_element(element, unpacked_args, args)
+   end
+   ```
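Taken together, the four steps above can be condensed into a self-contained program. This is an illustrative stand-in, not Trixi.jl code: it runs on KernelAbstractions.jl's `CPU` backend so no GPU is required, and the names `rhs_demo!`, `rhs_demo_per_element!`, and `rhs_demo_kernel!` are assumptions:

```julia
using KernelAbstractions

# Step 1: the per-element work, shared by both code paths.
rhs_demo_per_element!(element, du, u) = (du[element] = 2 * u[element])

# Step 3: legacy path, dispatched on `backend::Nothing`
# (a plain loop stands in for Trixi.jl's `@threaded`).
function rhs_demo!(backend::Nothing, du, u)
    for element in eachindex(u)
        rhs_demo_per_element!(element, du, u)
    end
    return nothing
end

# Step 4: KernelAbstractions.jl path, dispatched on `backend::Backend`.
@kernel function rhs_demo_kernel!(du, u)
    element = @index(Global)
    rhs_demo_per_element!(element, du, u)
end

function rhs_demo!(backend::Backend, du, u)
    length(u) == 0 && return nothing # return early when there are no elements
    kernel! = rhs_demo_kernel!(backend)
    kernel!(du, u, ndrange = length(u))
    KernelAbstractions.synchronize(backend)
    return nothing
end

# Step 2: obtain the backend from the data; `get_backend` returns
# `CPU()` for a plain `Array` and a GPU backend for device arrays.
u = ones(Float32, 8)
du = similar(u)
rhs_demo!(get_backend(u), du, u)
```

Because both methods call the same `rhs_demo_per_element!`, the numerical code is written once and only the launch mechanism differs between the legacy and the GPU path.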
