Skip to content

Commit 98d5f0a

Browse files
committed
xegpu: add xegpu transform op python bindinds
1 parent 85e6478 commit 98d5f0a

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

mlir/lib/Dialect/XeGPU/TransformOps/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ add_mlir_dialect_library(MLIRXeGPUTransformOps
88
MLIRXeGPUTransformOpsIncGen
99

1010
LINK_LIBS PUBLIC
11+
MLIRXeGPUDialect
12+
MLIRXeGPUTransforms
1113
MLIRIR
1214
MLIRTransformDialect
1315
MLIRFuncDialect

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,15 @@ declare_mlir_dialect_extension_python_bindings(
303303
"../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
304304
)
305305

306+
declare_mlir_dialect_extension_python_bindings(
307+
ADD_TO_PARENT MLIRPythonSources.Dialects
308+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
309+
TD_FILE dialects/XeGPUTransformOps.td
310+
SOURCES
311+
dialects/transform/xegpu.py
312+
DIALECT_NAME transform
313+
EXTENSION_NAME xegpu_transform)
314+
306315
declare_mlir_dialect_python_bindings(
307316
ADD_TO_PARENT MLIRPythonSources.Dialects
308317
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===---- XeGPUTransformOps.td -----------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the Python bindings generator for the XeGPU transform ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
14+
#ifndef PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
15+
#define PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
16+
17+
include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td"
18+
19+
#endif // PYTHON_BINDINGS_XEGPU_TRANSFORM_OPS
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._xegpu_transform_ops_gen import *
6+
from .._xegpu_transform_ops_gen import _Dialect
7+
8+
try:
9+
from ...ir import *
10+
from ...dialects import transform
11+
from .._ods_common import _cext as _ods_cext
12+
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
13+
except ImportError as e:
14+
raise RuntimeError("Error loading imports from extension module") from e
15+
16+
from typing import Optional, Sequence, Union, overload
17+
18+
19+
@_ods_cext.register_operation(_Dialect, replace=True)
20+
class XeGPUSetDPASLayoutOp(XeGPUSetDPASLayoutOp):
21+
"""Specialization for XeGPUSetDPASLayoutOp class."""
22+
23+
def __init__(
24+
self,
25+
dpas_op: Union[Operation, Value],
26+
tile_index: Union[int, Attribute],
27+
sg_layout: Union[Sequence[int], Attribute],
28+
sg_data: Union[Sequence[int], Attribute],
29+
inst_data: Union[Sequence[int], Attribute],
30+
*,
31+
load_data: Optional[Union[Sequence[int], Attribute]] = None,
32+
loc=None,
33+
ip=None,
34+
):
35+
super().__init__(
36+
dpas_op,
37+
tile_index,
38+
sg_layout,
39+
sg_data,
40+
inst_data,
41+
loadData=load_data,
42+
loc=loc,
43+
ip=ip
44+
)
45+
46+
47+
@_ods_cext.register_operation(_Dialect, replace=True)
48+
class XeGPUInsertPrefetchOp(XeGPUInsertPrefetchOp):
49+
"""Specialization for XeGPUInsertPrefetchOp class."""
50+
51+
def __init__(
52+
self,
53+
dpas_op: Union[Operation, Value],
54+
loop_op: Union[Operation, Value],
55+
tile_index: Union[int, Attribute],
56+
sg_layout: Union[Sequence[int], Attribute],
57+
sg_data: Union[Sequence[int], Attribute],
58+
loc=None,
59+
ip=None,
60+
):
61+
# results = get_op_result_or_op_results(dpas_op, loop_op)
62+
transformed_dpas_type = transform.AnyOpType.get()
63+
transformed_loop_type = transform.AnyOpType.get()
64+
super().__init__(
65+
transformed_dpas_type,
66+
transformed_loop_type,
67+
_get_op_result_or_value(dpas_op),
68+
_get_op_result_or_value(loop_op),
69+
tile_index,
70+
sg_layout,
71+
sg_data,
72+
loc=loc,
73+
ip=ip
74+
)
75+
76+
77+
@_ods_cext.register_operation(_Dialect, replace=True)
78+
class XeGPUHoistDescOp(XeGPUHoistDescOp):
79+
"""Specialization for XeGPUHoistDescOp class."""
80+
81+
def __init__(
82+
self,
83+
loop_op: Union[Operation, Value],
84+
loc=None,
85+
ip=None,
86+
):
87+
transformed_loop_type = transform.AnyOpType.get()
88+
super().__init__(
89+
transformed_loop_type,
90+
_get_op_result_or_value(loop_op),
91+
loc=loc,
92+
ip=ip
93+
)
94+
95+
96+
@_ods_cext.register_operation(_Dialect, replace=True)
97+
class XeGPUSetGPULaunchThreadsOp(XeGPUSetGPULaunchThreadsOp):
98+
"""Specialization for XeGPUSetGPULaunchThreadsOp class."""
99+
100+
def __init__(
101+
self,
102+
launch_op: Union[Operation, Value],
103+
threads: Union[int, Attribute],
104+
loc=None,
105+
ip=None,
106+
):
107+
super().__init__(
108+
_get_op_result_or_value(launch_op),
109+
threads,
110+
loc=loc,
111+
ip=ip
112+
)

0 commit comments

Comments
 (0)