Skip to content

Commit 534047e

Browse files
badgerbrochDrew Broch
andauthored
[Wave] Added atan2 operation (iree-org#867)
Added a Binary operation for computing atan2. Test added to lit_tests/kernel/wave/codegen.py --------- Co-authored-by: Drew Broch <badgerbroch@gmailcom>
1 parent 2bc1de3 commit 534047e

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ def minimum(lhs: "Register", rhs: "Register") -> "Register":
182182
...
183183

184184

185+
def atan2(lhs: "Register", rhs: "Register") -> "Register":
186+
...
187+
188+
185189
def broadcast(
186190
arg: "Register", target_shape: Optional[Sequence[IndexExpr | int]] = None
187191
) -> "Register":
@@ -853,6 +857,7 @@ def infer_shape(self) -> Any:
853857
@define_py_op(operator.truediv)
854858
@define_interface_op("maximum")
855859
@define_interface_op("minimum")
860+
@define_interface_op("atan2")
856861
@dataclass
857862
class BinaryPyOp(BinaryOpBase, ABC):
858863
def infer_type(self):

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
lt,
6868
maximum,
6969
minimum,
70+
atan2,
7071
mma,
7172
permute,
7273
reciprocal,
@@ -650,6 +651,17 @@ def handle_minimum(lhs: Value, rhs: Value, options: WaveCompileOptions) -> OpRes
650651
return result
651652

652653

654+
@handle_binary_op(atan2)
655+
def handle_atan2(lhs: Value, rhs: Value, options: WaveCompileOptions) -> OpResult:
656+
element_type = get_type_or_element_type(lhs.type)
657+
658+
if _is_float_type(element_type):
659+
result = math_d.atan2(lhs, rhs, fastmath=get_fast_math_flags(options))
660+
else:
661+
raise ValidationError(f"Found unhandled operand type for atan2: {element_type}")
662+
return result
663+
664+
653665
###############################################################################
654666
# Unary math Ops
655667
###############################################################################

lit_tests/kernel/wave/codegen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,6 +1916,7 @@ def binary_lowerings(
19161916
res = res * a_reg
19171917
res = res / b_reg
19181918
res = tkw.minimum(a_reg, b_reg)
1919+
res = tkw.atan2(res, a_reg)
19191920
tkw.write(res, a, elements_per_thread=4)
19201921

19211922
binary_lowerings = wave_compile(get_wave_compile_options(), binary_lowerings)
@@ -1926,6 +1927,7 @@ def binary_lowerings(
19261927
# CHECK: %[[MUL:.+]] = arith.mulf %[[SUB]]
19271928
# CHECK: %[[DIV:.+]] = arith.divf %[[MUL]]
19281929
# CHECK: %[[MINIMUM:.+]] = arith.minimumf
1930+
# CHECK: %[[ATAN2:.+]] = math.atan2 %[[MINIMUM]]
19291931

19301932

19311933
@run_test

0 commit comments

Comments
 (0)