Skip to content

Commit 2bc1de3

Browse files
authored
[Wave] Added sin op (iree-org#866)
Added sine op for tkw Signed-off-by: Sourish Wawdhane <[email protected]>
1 parent eaa895c commit 2bc1de3

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,10 @@ def roundeven(src: "Register") -> "Register":
170170
...
171171

172172

173+
def sin(src: "Register") -> "Register":
174+
...
175+
176+
173177
def maximum(lhs: "Register", rhs: "Register") -> "Register":
174178
...
175179

@@ -884,6 +888,7 @@ def infer_type(self):
884888
@define_interface_op("log2")
885889
@define_interface_op("reciprocal")
886890
@define_interface_op("roundeven")
891+
@define_interface_op("sin")
887892
@define_interface_op("tanh")
888893
@define_interface_op("tanh_approx")
889894
@define_py_op(operator.neg)

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
scalar,
7575
reshape,
7676
roundeven,
77+
sin,
7778
scheduling_barrier,
7879
scheduling_group_barrier,
7980
self_index,
@@ -920,6 +921,16 @@ def handle_roundeven(source: Value, options: WaveCompileOptions) -> OpResult:
920921
return roundeven
921922

922923

924+
@handle_unary_op(sin)
925+
def handle_sin(source: Value, options: WaveCompileOptions) -> OpResult:
926+
element_type = get_type_or_element_type(source.type)
927+
if _is_float_type(element_type):
928+
sine_of_source = math_d.sin(source)
929+
else:
930+
raise ValidationError(f"Found unhandled operand type for sine: {element_type}")
931+
return sine_of_source
932+
933+
923934
###############################################################################
924935
# Control Flow ops
925936
###############################################################################

lit_tests/kernel/wave/codegen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,7 @@ def test(
895895
res = tkw.tanh_approx(res)
896896
res = tkw.softsign(res, logit_cap=30.0, apply_scaling=True, head_dim=128)
897897
res = tkw.roundeven(res)
898+
res = tkw.sin(res)
898899
tkw.write(res, a, elements_per_thread=4)
899900
tkw.write(res_b, b, elements_per_thread=4)
900901

@@ -944,6 +945,9 @@ def test(
944945
# Tests roundeven
945946
# CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[SOFTSIGN]]
946947

948+
# Tests sin
949+
# CHECK: %[[SIN:.+]] = math.sin %[[ROUNDEVEN]]
950+
947951

948952
# Important to check lowering of scheduling/barrier ops.
949953
@run_test

0 commit comments

Comments
 (0)