[Discussion] How to efficiently handle branches in the axes dependency tree for sparse buffers. #69

@yzh119

Problem

Our current format annotations do not efficiently handle branches in the axes dependency tree. For example, the irregular batched-GEMM operator has the following dependency tree:

"""
     B
   / | \
  I  J  K

X: (B, I, K)
Y: (B, K, J)
Z: (B, I, J)
"""

B = T.dense_fixed(batch_size, "int32")
I = T.dense_variable(B, (m, nnz_I), indptr_I, "int32")
J = T.dense_variable(B, (n, nnz_J), indptr_J, "int32")
K = T.dense_variable(B, (k, nnz_K), indptr_K, "int32")

X = T.match_sparse_buffer(x, (B, I, K), "float32")
Y = T.match_sparse_buffer(y, (B, K, J), "float32")
Z = T.match_sparse_buffer(z, (B, I, J), "float32")

with T.iter([B, I, J, K], "SSSR", "irregular-batched-gemm") as [b, i, j, k]:
    with T.init():
        Z[b, i, j] = T.float32(0)
    Z[b, i, j] = Z[b, i, j] + X[b, i, k] * Y[b, k, j]

Efficiently indexing X, Y, and Z requires auxiliary buffers such as indptr_IK, indptr_KJ, and indptr_IJ, but SparseTIR currently does not provide an interface for them.
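
As a rough illustration of what such an auxiliary buffer would hold, the sketch below builds indptr_IK on the host in plain Python. It assumes each batch b stores its m_b x k_b block of X contiguously in row-major order. The helper build_indptr and the per-batch sizes are hypothetical and are not part of SparseTIR.

```python
from itertools import accumulate


def build_indptr(lengths):
    """Exclusive prefix sum: indptr[b] is the start offset of segment b."""
    return [0] + list(accumulate(lengths))


# Hypothetical per-batch extents of the variable axes I and K.
m = [2, 3, 1]
k = [4, 2, 5]

indptr_I = build_indptr(m)  # row offsets per batch
indptr_K = build_indptr(k)  # column offsets per batch
# Element offsets of X per batch: batch b holds m[b] * k[b] values.
indptr_IK = build_indptr([mb * kb for mb, kb in zip(m, k)])

print(indptr_I, indptr_K, indptr_IK)
```

indptr_KJ and indptr_IJ would follow the same pattern with the (k, n) and (m, n) extents.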

Proposals

Let's take X: (B, I, K) as an example:

Alternative 1: Create a new axis IK that follows I to replace K

IK = T.dense_variable(I, ...)

# before lowering
X[b, i, k]

# after lowering
x[indptr_ik[indptr_i[b] + i] + k]
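
To make the lowered expression concrete, here is a minimal host-side sketch of the Alternative 1 offset computation in plain Python. It assumes indptr_ik has one entry per (b, i) row plus a final sentinel, and that every row of batch b has k[b] elements. The names build_indptr and offset_alt1 and the sizes are hypothetical, chosen only for illustration.

```python
from itertools import accumulate


def build_indptr(lengths):
    """Exclusive prefix sum: indptr[s] is the start offset of segment s."""
    return [0] + list(accumulate(lengths))


# Hypothetical per-batch extents of I and K.
m = [2, 3, 1]
k = [4, 2, 5]

indptr_i = build_indptr(m)
# One segment per (b, i) row; every row of batch b has k[b] elements.
row_lengths = [k[b] for b in range(len(m)) for _ in range(m[b])]
indptr_ik = build_indptr(row_lengths)


def offset_alt1(b, i, kk):
    """Flat position of X[b, i, kk]: two dependent indptr loads."""
    return indptr_ik[indptr_i[b] + i] + kk


# Enumerating (b, i, kk) in row-major order should visit 0, 1, 2, ...
offsets = [offset_alt1(b, i, kk)
           for b in range(len(m))
           for i in range(m[b])
           for kk in range(k[b])]
assert offsets == list(range(indptr_ik[-1]))
```

Under these assumptions indptr_ik holds nnz_I + 1 entries, and each access chains two indirect loads.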

Alternative 2: Insert a bridge axis IK that flattens I and K

IK = T.flatten([I, K], ...)

# before lowering
X[b, i, k]

# after lowering
x[indptr_ik[b] + i * (indptr_k[b + 1] - indptr_k[b]) + k]
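
The flattened variant can be sketched the same way. The code below is a host-side illustration in plain Python, assuming indptr_ik has one entry per batch (plus a sentinel) and that each batch is stored row-major with row width indptr_k[b + 1] - indptr_k[b]. The names build_indptr and offset_alt2 and the sizes are hypothetical.

```python
from itertools import accumulate


def build_indptr(lengths):
    """Exclusive prefix sum: indptr[b] is the start offset of segment b."""
    return [0] + list(accumulate(lengths))


# Hypothetical per-batch extents of I and K.
m = [2, 3, 1]
k = [4, 2, 5]

indptr_k = build_indptr(k)
# One segment per batch; batch b holds m[b] * k[b] elements.
indptr_ik = build_indptr([m[b] * k[b] for b in range(len(m))])


def offset_alt2(b, i, kk):
    """Flat position of X[b, i, kk]: batch start plus a row-major offset."""
    width = indptr_k[b + 1] - indptr_k[b]
    return indptr_ik[b] + i * width + kk


# Enumerating (b, i, kk) in row-major order should visit 0, 1, 2, ...
offsets = [offset_alt2(b, i, kk)
           for b in range(len(m))
           for i in range(m[b])
           for kk in range(k[b])]
assert offsets == list(range(indptr_ik[-1]))
```

In this sketch indptr_ik needs only batch_size + 1 entries, at the cost of a multiply per access, and it relies on all rows within a batch sharing the same length.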
