11module m_cuda_backend
2+ use iso_fortran_env, only: stderr = > error_unit
23 use cudafor
34
45 use m_allocator, only: allocator_t, field_t
56 use m_base_backend, only: base_backend_t
6- use m_common, only: dp, globs_t
7+ use m_common, only: dp, globs_t, RDR_X2Y, RDR_X2Z, RDR_Y2X, RDR_Y2Z, RDR_Z2Y
78 use m_tdsops, only: dirps_t, tdsops_t
89
910 use m_cuda_allocator, only: cuda_allocator_t, cuda_field_t
@@ -12,6 +13,8 @@ module m_cuda_backend
1213 use m_cuda_sendrecv, only: sendrecv_fields, sendrecv_3fields
1314 use m_cuda_tdsops, only: cuda_tdsops_t
1415 use m_cuda_kernels_dist, only: transeq_3fused_dist, transeq_3fused_subs
16+ use m_cuda_kernels_reorder, only: reorder_x2y, reorder_x2z, reorder_y2x, &
17+ reorder_y2z, reorder_z2y
1518
1619 implicit none
1720
@@ -32,11 +35,7 @@ module m_cuda_backend
3235 procedure :: transeq_y = > transeq_y_cuda
3336 procedure :: transeq_z = > transeq_z_cuda
3437 procedure :: tds_solve = > tds_solve_cuda
35- procedure :: trans_x2y = > trans_x2y_cuda
36- procedure :: trans_x2z = > trans_x2z_cuda
37- procedure :: trans_y2z = > trans_y2z_cuda
38- procedure :: trans_z2y = > trans_z2y_cuda
39- procedure :: trans_y2x = > trans_y2x_cuda
38+ procedure :: reorder = > reorder_cuda
4039 procedure :: sum_yzintox = > sum_yzintox_cuda
4140 procedure :: vecadd = > vecadd_cuda
4241 procedure :: set_fields = > set_fields_cuda
@@ -74,6 +73,10 @@ function init(globs, allocator) result(backend)
7473 backend% zthreads = dim3(SZ, 1 , 1 )
7574 backend% zblocks = dim3(globs% n_groups_z, 1 , 1 )
7675
76+ backend% nx_loc = globs% nx_loc
77+ backend% ny_loc = globs% ny_loc
78+ backend% nz_loc = globs% nz_loc
79+
7780 n_halo = 4
7881 n_block = globs% n_groups_x
7982
@@ -415,50 +418,48 @@ subroutine tds_solve_dist(self, du, u, dirps, tdsops, blocks, threads)
415418
416419 end subroutine tds_solve_dist
417420
418- subroutine trans_x2y_cuda (self , u_y , v_y , w_y , u , v , w )
419- implicit none
420-
421- class(cuda_backend_t) :: self
422- class(field_t), intent (inout ) :: u_y, v_y, w_y
423- class(field_t), intent (in ) :: u, v, w
424-
425- end subroutine trans_x2y_cuda
426-
427- subroutine trans_x2z_cuda (self , u_z , v_z , w_z , u , v , w )
!> Launch the reorder kernel selected by `direction`, writing into `u_o`
!! from `u_i`.
!!
!! The device arrays behind both fields are resolved first, then a launch
!! configuration (blocks/threads) is built from the backend's local sizes
!! and the matching kernel from m_cuda_kernels_reorder is launched.
!!
!! Arguments:
!!   self      - backend instance; supplies nx_loc, ny_loc and nz_loc
!!   u_o       - destination field; must be a cuda_field_t
!!   u_i       - source field; must be a cuda_field_t
!!   direction - one of RDR_X2Y, RDR_X2Z, RDR_Y2X, RDR_Y2Z, RDR_Z2Y
!!
!! Stops with `error stop` when either field is not a cuda_field_t or when
!! `direction` is not one of the values listed above.
!!
!! NOTE(review): the grid sizes use integer division by SZ, so any
!! remainder is dropped; presumably the local sizes are multiples of SZ -
!! confirm against the kernels and the callers.
subroutine reorder_cuda(self, u_o, u_i, direction)
  implicit none

  class(cuda_backend_t) :: self
  class(field_t), intent(inout) :: u_o
  class(field_t), intent(in) :: u_i
  integer, intent(in) :: direction

  real(dp), device, pointer, dimension(:, :, :) :: u_o_d, u_i_d
  type(dim3) :: blocks, threads

  ! Resolve the device arrays. Without the class default branches a field
  ! of any other dynamic type would leave the pointer undefined and the
  ! kernel launch below would still receive it.
  select type (u_o)
  type is (cuda_field_t)
    u_o_d => u_o%data_d
  class default
    error stop 'reorder_cuda: u_o is not a cuda_field_t.'
  end select

  select type (u_i)
  type is (cuda_field_t)
    u_i_d => u_i%data_d
  class default
    error stop 'reorder_cuda: u_i is not a cuda_field_t.'
  end select

  select case (direction)
  case (RDR_X2Y) ! x2y
    ! NOTE(review): this grid is ordered (x, z, y) unlike the other
    ! SZ-by-SZ launches below; presumably that matches the indexing inside
    ! reorder_x2y - confirm against the kernel.
    blocks = dim3(self%nx_loc/SZ, self%nz_loc, self%ny_loc/SZ)
    threads = dim3(SZ, SZ, 1)
    call reorder_x2y<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (RDR_X2Z) ! x2z
    blocks = dim3(self%nx_loc, self%ny_loc/SZ, 1)
    threads = dim3(SZ, 1, 1)
    call reorder_x2z<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (RDR_Y2X) ! y2x
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call reorder_y2x<<<blocks, threads>>>(u_o_d, u_i_d, self%nz_loc)
  case (RDR_Y2Z) ! y2z
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call reorder_y2z<<<blocks, threads>>>(u_o_d, u_i_d, &
                                          self%nx_loc, self%nz_loc)
  case (RDR_Z2Y) ! z2y
    blocks = dim3(self%nx_loc/SZ, self%ny_loc/SZ, self%nz_loc)
    threads = dim3(SZ, SZ, 1)
    call reorder_z2y<<<blocks, threads>>>(u_o_d, u_i_d, &
                                          self%nx_loc, self%nz_loc)
  case default
    error stop ' Reorder direction is undefined.'
  end select

end subroutine reorder_cuda
462463
463464 subroutine sum_yzintox_cuda (self , du , dv , dw , &
464465 du_y , dv_y , dw_y , du_z , dv_z , dw_z )
0 commit comments