|
314 | 314 | function raw_copy!(c::AbstractContainer, from::Int, destination::Int) |
315 | 315 | raw_copy!(c, c, from, from, destination) |
316 | 316 | end |
| 317 | + |
| 318 | +# Trixi storage types must implement these two Adapt.jl methods |
| 319 | +function Adapt.adapt_structure(to, c::AbstractContainer) |
| 320 | + error("Interface: Must implement Adapt.adapt_structure(to, ::$(typeof(c)))") |
| 321 | +end |
| 322 | + |
| 323 | +function Adapt.parent_type(C::Type{<:AbstractContainer}) |
| 324 | + error("Interface: Must implement Adapt.parent_type(::Type{$C}") |
| 325 | +end |
| 326 | + |
| 327 | +function Adapt.unwrap_type(C::Type{<:AbstractContainer}) |
| 328 | + return Adapt.unwrap_type(Adapt.parent_type(C)) |
| 329 | +end |
| 330 | + |
| 331 | +# TODO: Upstream to Adapt |
| 332 | +function storage_type(x) |
| 333 | + return storage_type(typeof(x)) |
| 334 | +end |
| 335 | + |
| 336 | +function storage_type(T::Type) |
| 337 | + error("Interface: Must implement storage_type(::Type{$T}") |
| 338 | +end |
| 339 | + |
| 340 | +function storage_type(::Type{<:Array}) |
| 341 | + Array |
| 342 | +end |
| 343 | + |
| 344 | +function storage_type(C::Type{<:AbstractContainer}) |
| 345 | + return storage_type(Adapt.unwrap_type(C)) |
| 346 | +end |
| 347 | + |
| 348 | +# For some storage backends like CUDA.jl, empty arrays do seem to simply be |
| 349 | +# null pointers which can cause `unsafe_wrap` to fail when calling |
| 350 | +# Adapt.adapt (ArgumentError, see |
| 351 | +# https://github.com/JuliaGPU/CUDA.jl/blob/v5.4.2/src/array.jl#L212-L229). |
| 352 | +# To circumvent this, on length zero arrays this allocates |
| 353 | +# a separate empty array instead of wrapping. |
| 354 | +# However, since zero length arrays are not used in calculations, |
| 355 | +# it should be okay if the underlying storage vectors and wrapped arrays |
| 356 | +# are not the same as long as they are properly wrapped when `resize!`d etc. |
| 357 | +function unsafe_wrap_or_alloc(to, vector, size) |
| 358 | + if length(vector) == 0 |
| 359 | + return similar(vector, size) |
| 360 | + else |
| 361 | + return unsafe_wrap(to, pointer(vector), size) |
| 362 | + end |
| 363 | +end |
| 364 | + |
| 365 | +struct TrixiAdaptor{Storage, Real} end |
| 366 | + |
| 367 | +function trixi_adapt(storage, real, x) |
| 368 | + adapt(TrixiAdaptor{storage, real}(), x) |
| 369 | +end |
| 370 | + |
| 371 | +# Custom rules |
| 372 | +# 1. handling of StaticArrays |
| 373 | +function Adapt.adapt_storage(::TrixiAdaptor{<:Any, Real}, |
| 374 | + x::StaticArrays.StaticArray{S, T, N}) where {Real, S, T, N} |
| 375 | + StaticArrays.similar_type(x, Real)(x) |
| 376 | +end |
| 377 | + |
| 378 | +# 2. Handling of Arrays |
| 379 | +function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, |
| 380 | + x::AbstractArray{T}) where {Storage, Real, |
| 381 | + T <: AbstractFloat} |
| 382 | + adapt(Storage{Real}, x) |
| 383 | +end |
| 384 | + |
| 385 | +function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, |
| 386 | + x::AbstractArray{T}) where {Storage, Real, |
| 387 | + T <: StaticArrays.StaticArray} |
| 388 | + adapt(Storage{StaticArrays.similar_type(T, Real)}, x) |
| 389 | +end |
| 390 | + |
| 391 | +function Adapt.adapt_storage(::TrixiAdaptor{Storage, Real}, |
| 392 | + x::AbstractArray) where {Storage, Real} |
| 393 | + adapt(Storage, x) |
| 394 | +end |
| 395 | + |
| 396 | +# 3. TODO: Should we have a fallback? But that would imply implementing things for NamedTuple again |
| 397 | + |
| 398 | +function unsafe_wrap_or_alloc(::TrixiAdaptor{Storage}, vec, size) where {Storage} |
| 399 | + return unsafe_wrap_or_alloc(Storage, vec, size) |
| 400 | +end |
317 | 401 | end # @muladd |
0 commit comments