Skip to content

Commit 8b11bfd

Browse files
committed
xegpu: add tests transform op python bindings
1 parent a45730c commit 8b11bfd

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import xegpu
6+
from mlir.dialects.transform import structured
7+
8+
9+
def run(f):
10+
with Context(), Location.unknown():
11+
module = Module.create()
12+
with InsertionPoint(module.body):
13+
print("\nTEST:", f.__name__)
14+
f()
15+
print(module)
16+
return f
17+
18+
19+
@run
20+
def setOperandLayout():
21+
sequence = transform.SequenceOp(
22+
transform.FailurePropagationMode.Propagate,
23+
[],
24+
transform.OperationType.get("xegpu.dpas"),
25+
)
26+
with InsertionPoint(sequence.body):
27+
xegpu.SetOperandLayoutOp(
28+
sequence.bodyTarget,
29+
index=0,
30+
sg_layout=[6, 4],
31+
sg_data=[32, 16],
32+
inst_data=[8, 16]
33+
)
34+
transform.YieldOp()
35+
# CHECK-LABEL: TEST: setOperandLayout
36+
# CHECK: transform.xegpu.set_operand_layout %
37+
# CHECK: index = 0
38+
# CHECK: sg_layout = [6, 4]
39+
# CHECK: sg_data = [32, 16]
40+
# CHECK: inst_data = [8, 16]
41+
42+
43+
@run
44+
def insertPrefetch():
45+
sequence = transform.SequenceOp(
46+
transform.FailurePropagationMode.Propagate,
47+
[],
48+
transform.AnyOpType.get(),
49+
)
50+
with InsertionPoint(sequence.body):
51+
for_op = structured.MatchOp.match_op_names(sequence.bodyTarget, ["scf.for"])
52+
dpas_op = structured.MatchOp.match_op_names(for_op, ["xegpu.dpas"])
53+
xegpu.InsertPrefetchOp(
54+
dpas_op,
55+
for_op,
56+
index=0,
57+
sg_layout=[6, 4],
58+
sg_data=[32, 16],
59+
)
60+
transform.YieldOp()
61+
# CHECK-LABEL: TEST: insertPrefetch
62+
# CHECK: %[[FOR_OP:.*]] = transform.structured.match
63+
# CHECK: %[[DPAS_OP:.*]] = transform.structured.match
64+
# CHECK: transform.xegpu.insert_prefetch %[[DPAS_OP]] %[[FOR_OP]]
65+
# CHECK: index = 0
66+
# CHECK: sg_layout = [6, 4]
67+
# CHECK: sg_data = [32, 16]
68+
69+
70+
@run
71+
def hoistDescOp():
72+
sequence = transform.SequenceOp(
73+
transform.FailurePropagationMode.Propagate,
74+
[],
75+
transform.OperationType.get("scf.for"),
76+
)
77+
with InsertionPoint(sequence.body):
78+
xegpu.HoistDescOp(sequence.bodyTarget)
79+
transform.YieldOp()
80+
# CHECK-LABEL: TEST: hoistDescOp
81+
# CHECK: transform.xegpu.hoist_desc_ops
82+
83+
84+
@run
85+
def setGPULaunchThreadsOp():
86+
sequence = transform.SequenceOp(
87+
transform.FailurePropagationMode.Propagate,
88+
[],
89+
transform.OperationType.get("gpu.lauch"),
90+
)
91+
with InsertionPoint(sequence.body):
92+
xegpu.SetGPULaunchThreadsOp(
93+
sequence.bodyTarget,
94+
threads=[8, 4, 1]
95+
)
96+
transform.YieldOp()
97+
# CHECK-LABEL: TEST: setGPULaunchThreadsOp
98+
# CHECK: transform.xegpu.set_gpu_launch_threads
99+
# CHECK: threads = [8, 4, 1]

0 commit comments

Comments
 (0)