Skip to content

Commit d8a3312

Browse files
authored
[Wave] Introduce softsign kernel to replace tanh_approx (iree-org#829)
Add a new `softsign` kernel variant that replaces the previous `tanh_approx` implementation. Benchmarks show a 10–15% speedup over `tanh_approx`, but with a modest accuracy drop (e.g., Grok accuracy falls from ~83% to ~79%). This commit provides the softsign implementation alongside the existing approximations so users can choose the appropriate speed/accuracy tradeoff.
1 parent a91324c commit d8a3312

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

iree/turbine/kernel/ops/wave_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,15 @@ def abs(src: "Register") -> "Register":
149149
...
150150

151151

152+
def softsign(
153+
src: "Register",
154+
logit_cap: float = 30.0,
155+
apply_scaling: bool = False,
156+
head_dim: int = None,
157+
) -> "Register":
158+
...
159+
160+
152161
def tanh_approx(src: "Register") -> "Register":
153162
...
154163

@@ -900,6 +909,23 @@ def infer_type(self):
900909
self.type = src_type
901910

902911

912+
@define_interface_op("softsign")
913+
@dataclass
914+
class SoftsignOp(CustomOp, ABC):
915+
arg: fx.Node
916+
logit_cap: float = 30.0
917+
apply_scaling: bool = False
918+
head_dim: int = None
919+
920+
@property
921+
def indexing_dims(self) -> list[IndexSymbol]:
922+
return get_custom(self.arg).indexing_dims
923+
924+
def infer_type(self):
925+
src_type = get_custom(self.arg).type
926+
self.type = src_type
927+
928+
903929
@define_op("select")
904930
@dataclass
905931
class SelectOp(CustomOp):

iree/turbine/kernel/wave/codegen/handlers.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
tanh,
8686
tanh_approx,
8787
workgroup_barrier,
88+
softsign,
8889
)
8990
from ...compiler.base import CodegenError, ValidationError, NDEBUG
9091
from ...compiler.builder import IRProxyValue
@@ -842,6 +843,61 @@ def handle_tanh_approx(source: Value, options: WaveCompileOptions) -> OpResult:
842843
return result
843844

844845

846+
@handle_op(softsign)
847+
def handle_softsign(emitter: WaveEmitter, node: fx.Node) -> None:
848+
"""
849+
Implements softsign-like logit cap using reciprocal:
850+
logit = logit / (1 + abs(logit / cap))
851+
= logit * (1 / (1 + abs(logit * (1 / cap))))
852+
= logit * reciprocal(1 + abs(logit * reciprocal(cap)))
853+
"""
854+
try:
855+
src_arg, logit_cap, apply_scaling, head_dim = node.args
856+
except ValueError as e:
857+
raise ValidationError("Malformed arguments for softsign") from e
858+
859+
source = cast_py_value(emitter, src_arg).ir_value
860+
element_type = get_type_or_element_type(source.type)
861+
opts = emitter.options
862+
863+
# Compute effective cap
864+
if apply_scaling:
865+
if head_dim is None:
866+
raise ValidationError("`head_dim` must be provided if `apply_scaling=True`")
867+
eff_cap = logit_cap * (head_dim**-0.5)
868+
else:
869+
eff_cap = logit_cap
870+
871+
# # Constants
872+
873+
reci_cap = 1.0 / eff_cap
874+
reci_cap_const = get_constant_attr(reci_cap, element_type)
875+
one = arith_d.ConstantOp(
876+
source.type,
877+
DenseElementsAttr.get_splat(source.type, get_constant_attr(1.0, element_type)),
878+
)
879+
reciprocal_cap = arith_d.ConstantOp(
880+
source.type, DenseElementsAttr.get_splat(source.type, reci_cap_const)
881+
)
882+
883+
# scaled = logit * (1 / cap)
884+
scaled = arith_d.mulf(source, reciprocal_cap, fastmath=get_fast_math_flags(opts))
885+
886+
# abs_scaled = abs(logit * (1 / cap))
887+
abs_scaled = math_d.absf(scaled)
888+
889+
# denom = 1 + abs(...)
890+
denom = arith_d.addf(one, abs_scaled, fastmath=get_fast_math_flags(opts))
891+
892+
# reciprocal_denom = 1 / denom
893+
reciprocal_denom = arith_d.divf(one, denom, fastmath=get_fast_math_flags(opts))
894+
895+
# result = logit * (1 / denom)
896+
result = arith_d.mulf(source, reciprocal_denom, fastmath=get_fast_math_flags(opts))
897+
898+
emitter.bind_node_proxy(node, IRProxyValue(result))
899+
900+
845901
@handle_unary_op(tanh)
846902
def handle_tanh(source: Value, options: WaveCompileOptions) -> OpResult:
847903
element_type = get_type_or_element_type(source.type)

iree/turbine/kernel/wave/templates/extend_attention.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,13 @@ def first_loop(
280280
if logit_cap > 0:
281281
logit_cap_reg_inv = tkw.reciprocal(logit_cap_reg)
282282
x_j = logit_cap_reg * tkw.tanh_approx(x_j * logit_cap_reg_inv)
283+
# We could use tkw.softsign to provide ~10% performance improvement, but this will compromise accuracy.
284+
# x_j = logit_cap_reg * tkw.softsign(
285+
# x_j * logit_cap_reg_inv,
286+
# logit_cap=30.0,
287+
# apply_scaling=True,
288+
# head_dim=128,
289+
# )
283290
n_kv_index = tkw.self_index(N_KV, tkl.i32)
284291
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
285292
mask = tkw.broadcast(mask, target_shape=[N_Q, N_KV])
@@ -343,6 +350,13 @@ def second_loop(
343350
if logit_cap > 0:
344351
logit_cap_reg_inv = tkw.reciprocal(logit_cap_reg)
345352
x_j = logit_cap_reg * tkw.tanh_approx(x_j * logit_cap_reg_inv)
353+
# We could use tkw.softsign to provide ~10% performance improvement, but this will compromise accuracy.
354+
# x_j = logit_cap_reg * tkw.softsign(
355+
# x_j * logit_cap_reg_inv,
356+
# logit_cap=30.0,
357+
# apply_scaling=True,
358+
# head_dim=128,
359+
# )
346360
n_kv_index = tkw.self_index(N_KV, tkl.i32)
347361
mask = tkw.apply_expr(n_kv_index, lambda x: x < N_KV)
348362
mask = tkw.broadcast(mask, target_shape=[N_Q, N_KV])

lit_tests/kernel/wave/codegen.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,6 +893,7 @@ def test(
893893
res_b = tkw.abs(b_reg)
894894
res = tkw.tanh(res)
895895
res = tkw.tanh_approx(res)
896+
res = tkw.softsign(res, logit_cap=30.0, apply_scaling=True, head_dim=128)
896897
res = tkw.roundeven(res)
897898
tkw.write(res, a, elements_per_thread=4)
898899
tkw.write(res_b, b, elements_per_thread=4)
@@ -931,8 +932,17 @@ def test(
931932
# CHECK: %[[R:.+]] = arith.addf %[[TEMP]], %[[RECIP]] : vector<4xf16>
932933
# CHECK: %[[TANH_APPROX:.+]] = math.copysign %[[R]], %[[TANH]] : vector<4xf16>
933934

935+
# Tests softsign
936+
# CHECK: %[[ONE:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16>
937+
# CHECK: %[[CAP:.+]] = arith.constant dense<3.771970e-01> : vector<4xf16>
938+
# CHECK: %[[SCALED:.+]] = arith.mulf %[[TANH_APPROX]], %[[CAP]] : vector<4xf16>
939+
# CHECK: %[[ABS2:.+]] = math.absf %[[SCALED]] : vector<4xf16>
940+
# CHECK: %[[ADD:.+]] = arith.addf %[[ONE]], %[[ABS2]] : vector<4xf16>
941+
# CHECK: %[[RECIP_DENOM:.+]] = arith.divf %[[ONE]], %[[ADD]] : vector<4xf16>
942+
# CHECK: %[[SOFTSIGN:.+]] = arith.mulf %[[TANH_APPROX]], %[[RECIP_DENOM]] : vector<4xf16>
943+
934944
# Tests roundeven
935-
# CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[TANH_APPROX]]
945+
# CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[SOFTSIGN]]
936946

937947

938948
# Important to check lowering of scheduling/barrier ops.

0 commit comments

Comments
 (0)