Skip to content

Commit 08e6244

Browse files
authored
[Wave] added cos op (iree-org#869)
1 parent 534047e commit 08e6244

File tree

3 files changed

+19
-0
lines changed

3 files changed

+19
-0
lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,10 @@ def tanh(src: "Register") -> "Register":
166166
...
167167

168168

169+
def cos(src: "Register") -> "Register":
170+
...
171+
172+
169173
def roundeven(src: "Register") -> "Register":
170174
...
171175

@@ -896,6 +900,7 @@ def infer_type(self):
896900
@define_interface_op("sin")
897901
@define_interface_op("tanh")
898902
@define_interface_op("tanh_approx")
903+
@define_interface_op("cos")
899904
@define_py_op(operator.neg)
900905
@define_py_op(operator.invert)
901906
@dataclass

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
broadcast,
5454
cast,
5555
conditional,
56+
cos,
5657
eq,
5758
exp2,
5859
extract,
@@ -943,6 +944,16 @@ def handle_sin(source: Value, options: WaveCompileOptions) -> OpResult:
943944
return sine_of_source
944945

945946

947+
@handle_unary_op(cos)
948+
def handle_cos(source: Value, options: WaveCompileOptions) -> OpResult:
949+
element_type = get_type_or_element_type(source.type)
950+
if _is_float_type(element_type):
951+
res = math_d.cos(source)
952+
else:
953+
raise ValidationError(f"Found unhandled operand type for cos: {element_type}")
954+
return res
955+
956+
946957
###############################################################################
947958
# Control Flow ops
948959
###############################################################################

lit_tests/kernel/wave/codegen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ def test(
896896
res = tkw.softsign(res, logit_cap=30.0, apply_scaling=True, head_dim=128)
897897
res = tkw.roundeven(res)
898898
res = tkw.sin(res)
899+
res = tkw.cos(res)
899900
tkw.write(res, a, elements_per_thread=4)
900901
tkw.write(res_b, b, elements_per_thread=4)
901902

@@ -947,6 +948,8 @@ def test(
947948

948949
# Tests sin
949950
# CHECK: %[[SIN:.+]] = math.sin %[[ROUNDEVEN]]
951+
# Tests cos
952+
# CHECK: %[[COS:.+]] = math.cos %[[SIN]]
950953

951954

952955
# Important to check lowering of scheduling/barrier ops.

0 commit comments

Comments
 (0)