Skip to content

Commit c379f7c

Browse files
authored
[MLIR][XeGPU] Add integration with XeGPU load / store ops to / from memref subview. (llvm#170385)
Add XeGPU integration test for missing usage case: base memory from memref subview.
1 parent 70dd63b commit c379f7c

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
// RUN: mlir-opt %s --gpu-lower-to-xevm-pipeline="xegpu-op-level=lane" \
2+
// RUN: | mlir-runner \
3+
// RUN: --shared-libs=%mlir_levelzero_runtime \
4+
// RUN: --shared-libs=%mlir_runner_utils \
5+
// RUN: --entry-point-result=void \
6+
// RUN: | FileCheck %s
7+
8+
module @subview attributes {gpu.container_module} {
9+
gpu.module @kernel {
10+
gpu.func @subview(%src: memref<256xf32>, %dst: memref<256xf32>) kernel {
11+
%src_subview = memref.subview %src[5] [251] [1] : memref<256xf32> to memref<251xf32, strided<[1], offset: 5>>
12+
%dst_subview = memref.subview %dst[10] [246] [1] : memref<256xf32> to memref<246xf32, strided<[1], offset: 10>>
13+
%lane_id = gpu.lane_id
14+
%mask = arith.constant 1 : i1
15+
%loaded = xegpu.load %src_subview[%lane_id], %mask : memref<251xf32, strided<[1], offset: 5>>, index, i1 -> f32
16+
xegpu.store %loaded, %dst_subview[%lane_id], %mask : f32, memref<246xf32, strided<[1], offset: 10>>, index, i1
17+
gpu.return
18+
}
19+
}
20+
func.func @test(%src: memref<256xf32>, %dst: memref<256xf32>) -> memref<256xf32> {
21+
%memref_src = gpu.alloc () : memref<256xf32>
22+
gpu.memcpy %memref_src, %src : memref<256xf32>, memref<256xf32>
23+
%memref_dst = gpu.alloc () : memref<256xf32>
24+
gpu.memcpy %memref_dst, %dst : memref<256xf32>, memref<256xf32>
25+
%c1 = arith.constant 1 : index
26+
%c16 = arith.constant 16 : index
27+
gpu.launch_func @kernel::@subview blocks in (%c1, %c1, %c1) threads in (%c16, %c1, %c1) args(%memref_src : memref<256xf32>, %memref_dst : memref<256xf32>)
28+
gpu.wait // Wait for the kernel to finish.
29+
gpu.memcpy %dst, %memref_dst : memref<256xf32>, memref<256xf32>
30+
gpu.dealloc %memref_src : memref<256xf32>
31+
gpu.dealloc %memref_dst : memref<256xf32>
32+
return %dst : memref<256xf32>
33+
}
34+
func.func @main() {
35+
%c0 = arith.constant 0 : index
36+
%c1 = arith.constant 1 : index
37+
%c256 = arith.constant 256 : index
38+
%memref_src = memref.alloc() : memref<256xf32>
39+
%memref_dst = memref.alloc() : memref<256xf32>
40+
// Initialize source memref
41+
scf.for %i = %c0 to %c256 step %c1 {
42+
%val = arith.index_cast %i : index to i32
43+
%val_float = arith.sitofp %val : i32 to f32
44+
memref.store %val_float, %memref_src[%i] : memref<256xf32>
45+
}
46+
// Initialize destination memref to zero
47+
scf.for %i = %c0 to %c256 step %c1 {
48+
%zero = arith.constant 0.0 : f32
49+
memref.store %zero, %memref_dst[%i] : memref<256xf32>
50+
}
51+
// Call test function
52+
%gpu_result = call @test(%memref_src, %memref_dst) : (memref<256xf32>, memref<256xf32>) -> memref<256xf32>
53+
%gpu_result_casted = memref.cast %gpu_result : memref<256xf32> to memref<*xf32>
54+
// CHECK: Unranked Memref base@ = 0x{{[0-9a-f]+}}
55+
// CHECK: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
56+
call @printMemrefF32(%gpu_result_casted) : (memref<*xf32>) -> ()
57+
// Deallocate memrefs
58+
memref.dealloc %memref_src : memref<256xf32>
59+
memref.dealloc %memref_dst : memref<256xf32>
60+
return
61+
}
62+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
63+
}

0 commit comments

Comments
 (0)