
[Tracking Issue] Support Axis Union/Intersection #72

@yzh119

Description


The Problem

Currently, SparseTIR cannot lower code to co-iteration structures. Whenever we want to add or multiply two sparse tensors/vectors, we need to create an additional axis that represents the union/intersection of their axes.

Here is an example of SpMSpV (sparse matrix times sparse vector):

I = T.dense_fixed(m)
J = T.sparse_variable(I, (n, nnz), indptr_j, indices_j)
IV = T.dense_fixed(1)
JV = T.sparse_variable(IV, (n, nnz), indptr_jv, indices_jv)
J_and = T.sparse_variable(I, (n, nnz), indptr_j_and, indices_j_and)
A = T.match_sparse_buffer(a, (I, J))
B = T.match_sparse_buffer(b, (IV, JV))
C = T.match_sparse_buffer(c, (I,))
with T.iter([I, J_and], "SR", "spmspv") as [i, j]:
    with T.init():
        C[i] = T.float32(0)
    C[i] = C[i] + A[i, j] * B[0, j]

Because SparseTIR does not support co-iteration yet, it generates several binary-search blocks to index A and B: under the for-loop structure, accessing A and B requires the "mid" arrays (the positions of each column index j inside A's and B's index arrays) produced by these binary-search blocks.
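To make the cost concrete, here is a minimal plain-Python sketch of what the current lowering effectively computes. It is not SparseTIR's generated code; the CSR-style argument names (`indptr_a`, `indices_a`, `data_a`, `indptr_and`, `indices_and`, etc.) are hypothetical, and we assume `indices_and` already holds the per-row intersection of A's columns with B's, so every binary search succeeds.

```python
from bisect import bisect_left


def find_pos(indices, lo, hi, col):
    """Binary search for col in the sorted slice indices[lo:hi]; -1 if absent."""
    pos = bisect_left(indices, col, lo, hi)
    if pos < hi and indices[pos] == col:
        return pos
    return -1


def spmspv_binary_search(m, indptr_a, indices_a, data_a,
                         indices_b, data_b, indptr_and, indices_and):
    """C[i] = sum over j in J_and[i] of A[i, j] * B[0, j], via binary search."""
    c = [0.0] * m
    nnz_b = len(indices_b)
    for i in range(m):
        # Loop over the precomputed intersection axis J_and for row i.
        for k in range(indptr_and[i], indptr_and[i + 1]):
            j = indices_and[k]
            # The "mid" positions: where column j lives in A's row i and in B.
            mid_a = find_pos(indices_a, indptr_a[i], indptr_a[i + 1], j)
            mid_b = find_pos(indices_b, 0, nnz_b, j)
            c[i] += data_a[mid_a] * data_b[mid_b]
    return c
```

For A = [[1, 0, 2], [0, 3, 0]] and a sparse B with nonzeros at columns 0 and 1, each access to A and B pays an O(log nnz) search on top of the loop over J_and.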

Once we support axis union/intersection and co-iteration structure generation, we can declare J_and (and its union counterpart J_or) directly in terms of J and JV:

J_and = T.intersection([J, JV], indptr_j_and, indices_j_and)
J_or = T.union([J, JV], indptr_j_or, indices_j_or)

and sparse iterations over union/intersection axes can then be lowered to co-iteration structures by the sparse iteration lowering pass.
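The `indptr_j_and`/`indices_j_and` (and `indptr_j_or`/`indices_j_or`) arrays passed to the proposed `T.intersection`/`T.union` describe, per row of I, which columns survive. A minimal plain-Python sketch of how they could be built for the SpMSpV case, assuming JV has a single row (IV has length 1) and both index arrays are sorted; the helper `build_axis` and its `mode` parameter are hypothetical, not part of SparseTIR:

```python
def build_axis(indptr_j, indices_j, indices_jv, mode):
    """Build (indptr, indices) of the per-row intersection ("and") or
    union ("or") of J's columns with the single row of JV."""
    if mode not in ("and", "or"):
        raise ValueError("mode must be 'and' or 'or'")
    jv = set(indices_jv)
    indptr_out, indices_out = [0], []
    for i in range(len(indptr_j) - 1):
        row = set(indices_j[indptr_j[i]:indptr_j[i + 1]])
        cols = row & jv if mode == "and" else row | jv
        # Keep column indices sorted so the result is a valid CSR-style axis.
        indices_out.extend(sorted(cols))
        indptr_out.append(len(indices_out))
    return indptr_out, indices_out
```

With J = [[0, 2], [1]] and JV = [0, 1], the intersection axis is `([0, 1, 2], [0, 1])` and the union axis is `([0, 3, 5], [0, 1, 2, 0, 1])`. Precomputing these arrays is exactly the extra work a native co-iteration structure would avoid.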

Milestone

  • Support co-iteration structures (either with a While construct, or by creating a new statement in TIR).
  • Support T.intersection/T.union, and possibly more general forms (consider SpGEMM).
  • Modify the sparse iteration lowering pass.
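For the While-construct option, the co-iteration we would like to emit is a two-pointer merge over sorted index arrays. The plain-Python sketch below shows both flavors: intersection (for multiplication, as in SpMSpV) and union (for addition). It only illustrates the loop structure; the function names and CSR-style arguments are hypothetical and this is not SparseTIR output.

```python
def spmspv_coiterate(m, indptr_a, indices_a, data_a, indices_b, data_b):
    """Intersection co-iteration: C[i] = sum_j A[i, j] * B[0, j], no binary search."""
    c = [0.0] * m
    for i in range(m):
        p, p_end = indptr_a[i], indptr_a[i + 1]  # cursor over A's row i
        q, q_end = 0, len(indices_b)              # cursor over B's nonzeros
        while p < p_end and q < q_end:
            ja, jb = indices_a[p], indices_b[q]
            if ja == jb:  # column present in both: contributes to the product
                c[i] += data_a[p] * data_b[q]
                p += 1
                q += 1
            elif ja < jb:
                p += 1
            else:
                q += 1
    return c


def sparse_vec_add(indices_x, data_x, indices_y, data_y):
    """Union co-iteration: z = x + y for two sparse vectors with sorted indices."""
    p, q = 0, 0
    nx, ny = len(indices_x), len(indices_y)
    out_idx, out_val = [], []
    while p < nx or q < ny:
        if q == ny or (p < nx and indices_x[p] < indices_y[q]):
            out_idx.append(indices_x[p])  # index only in x (or y exhausted)
            out_val.append(data_x[p])
            p += 1
        elif p == nx or indices_y[q] < indices_x[p]:
            out_idx.append(indices_y[q])  # index only in y (or x exhausted)
            out_val.append(data_y[q])
            q += 1
        else:  # same index in both vectors
            out_idx.append(indices_x[p])
            out_val.append(data_x[p] + data_y[q])
            p += 1
            q += 1
    return out_idx, out_val
```

The intersection loop stops as soon as either side is exhausted, while the union loop must drain both; this asymmetry is one reason a general lowering (e.g. for SpGEMM) needs more than a single fixed loop template.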

Labels: enhancement (New feature or request), help wanted (Extra attention is needed)