We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4cc9705 · commit afa82d9 · Copy full SHA for afa82d9
thunder/executors/nvfuserex_impl.py
@@ -3268,11 +3268,12 @@ def cumsum_transform(
3268
mask = fd.ops.triu(mask)
3269
3270
out = fd.ops.matmul(nv_a, mask)
3271
- out = fd.ops.cast(out, out_dtype)
3272
else:
3273
out = fd.ops.cast(nv_a, out_dtype)
3274
if a.ndim >= 1:
3275
out = fd.ops.cumsum(out, dim)
+ # restore output dtype in case nvfuser cumsum does implicit type promotion
3276
+ out = fd.ops.cast(out, out_dtype)
3277
return out
3278
3279
0 commit comments