
[Warp Specialization] Inefficient codegen for loop carried dependencies #7628

@Mogball


When a loop-carried argument and its next value are both used in another partition, the generated code is inefficient:

scf.for ... (%iv = %init) {
  %next = compute_next %iv {partition = 0}
  use_value(%iv, %next) {partition = 1}
  scf.yield %next
}
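For reference, the values that `use_value` must observe can be written down sequentially. This is a minimal Python sketch of the loop's single-partition semantics; `compute_next` is a hypothetical stand-in callable, and the returned `(iv, next)` pairs stand in for the arguments of `use_value`. Note that the `next` of iteration i is the `iv` of iteration i+1, which is the redundancy discussed below.

```python
def reference(init, compute_next, n):
    """Sequential semantics: the (iv, next) pairs that use_value receives."""
    pairs = []
    iv = init
    for _ in range(n):
        nxt = compute_next(iv)   # %next = compute_next %iv
        pairs.append((iv, nxt))  # use_value(%iv, %next)
        iv = nxt                 # scf.yield %next
    return pairs
```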

Previously, with the multiplicity optimization, RewritePartitionDependencies would generate the following:

%iv_buf = local_alloc <2xshapexdtype>
local_store %init, %iv_buf[0]
scf.for ... (%iv = %init) {
  %next = compute_next %iv {partition = 0}
  wait_barrier {partition = 0}
  local_store %next, %iv_buf[(%i+1)%2] {partition = 0}
  arrive_barrier {partition = 0}

  wait_barrier {partition = 1}
  %iv_0 = local_load %iv_buf[%i%2] {partition = 1}
  arrive_barrier {partition = 1}
  wait_barrier {partition = 1}
  %iv_next = local_load %iv_buf[(%i+1)%2] {partition = 1}
  arrive_barrier {partition = 1}
  use_value(%iv_0, %iv_next) {partition = 1}

  scf.yield %next
}

Note that this is already suboptimal: we could reverse the order of the local_loads in partition 1 and make the second wait_barrier a no-op. Partition 0 stores the values in order, so if the next value is ready, the current value (stored one iteration earlier) must be ready too.
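The reordering argument can be illustrated with a rough Python model of the double-buffered scheme. This is a sketch under stated assumptions, not Triton's lowering: `DoubleBuffer`, `producer`, `consumer` and `run` are hypothetical names, a single `threading.Condition` guards all shared state, and the monotonic `published`/`consumed` counters stand in for barrier phases. Partition 1 waits only for the next value and then reads the current value without a second wait.

```python
import threading


class DoubleBuffer:
    """Hypothetical model of %iv_buf: 2 slots plus counters standing in for barrier phases."""

    def __init__(self, init):
        self.slots = [init, None]  # local_store %init, %iv_buf[0]
        self.published = 1         # values 0 .. published-1 have been stored
        self.consumed = 0          # consumer iterations fully finished
        self.waits = 0             # blocking waits issued by the consumer
        self.cv = threading.Condition()


def producer(buf, init, compute_next, n):
    """Partition 0: compute value k = i+1 and store it into slot k % 2."""
    iv = init
    for i in range(n):
        nxt = compute_next(iv)
        k = i + 1
        with buf.cv:
            # Slot k % 2 still holds value k-2, last read by consumer iteration k-2.
            buf.cv.wait_for(lambda: buf.consumed >= k - 1)
            buf.slots[k % 2] = nxt
            buf.published = k + 1
            buf.cv.notify_all()
        iv = nxt


def consumer(buf, n, out):
    """Partition 1 with the loads reordered: one wait per iteration instead of two."""
    for i in range(n):
        with buf.cv:
            # Wait for the *next* value (index i+1) first ...
            buf.cv.wait_for(lambda: buf.published >= i + 2)
            buf.waits += 1
            nxt = buf.slots[(i + 1) % 2]
            # ... values are published in order, so value i is already there.
            cur = buf.slots[i % 2]
            buf.consumed = i + 1
            buf.cv.notify_all()
        out.append((cur, nxt))


def run(init, compute_next, n):
    buf = DoubleBuffer(init)
    out = []
    threads = [
        threading.Thread(target=producer, args=(buf, init, compute_next, n)),
        threading.Thread(target=consumer, args=(buf, n, out)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out, buf.waits
```

Because the producer only overwrites slot k % 2 after the consumer has finished iteration k-2, the current value cannot be clobbered between the two loads, so the single wait is sufficient in this model.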

After removing the multiplicity optimization, we generate two arefs, one for each value that crosses from partition 0 to 1, without recognizing that it is a loop-carried value:

%iv_aref = aref_create
%next_aref = aref_create
scf.for ... (%iv = %init) {
  aref_put %iv, %iv_aref {partition = 0}

  %next = compute_next %iv {partition = 0}
  aref_put %next, %next_aref {partition = 0}
  
  %iv_0 = aref_get %iv_aref {partition = 1}
  %next_0 = aref_get %next_aref {partition = 1}
  use_value(%iv_0, %next_0) {partition = 1}

  scf.yield %next
}

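This results in 2 aref writes and 2 aref reads per iteration, even though the value sent through %next_aref in one iteration is sent again through %iv_aref in the next.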
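To make the traffic concrete, here is a small Python model of the two-aref lowering. It assumes an aref behaves like a bounded single-producer/single-consumer channel; the `Aref` class (a wrapper around `queue.Queue` with put/get counters), `partition0`, `partition1` and `run_two_arefs` are hypothetical names for illustration, not Triton APIs. Each partition runs in its own thread.

```python
import queue
import threading


class Aref:
    """Hypothetical aref stand-in: a bounded single-producer/single-consumer channel."""

    def __init__(self, depth=1):
        self.q = queue.Queue(maxsize=depth)
        self.puts = 0  # only touched by the producer thread
        self.gets = 0  # only touched by the consumer thread

    def put(self, value):
        self.puts += 1
        self.q.put(value)

    def get(self):
        self.gets += 1
        return self.q.get()


def partition0(iv_aref, next_aref, init, compute_next, n):
    iv = init
    for _ in range(n):
        iv_aref.put(iv)             # aref_put %iv, %iv_aref
        nxt = compute_next(iv)
        next_aref.put(nxt)          # aref_put %next, %next_aref
        iv = nxt                    # scf.yield %next


def partition1(iv_aref, next_aref, n, out):
    for _ in range(n):
        iv_0 = iv_aref.get()        # %iv_0 = aref_get %iv_aref
        next_0 = next_aref.get()    # %next_0 = aref_get %next_aref
        out.append((iv_0, next_0))  # use_value(%iv_0, %next_0)


def run_two_arefs(init, compute_next, n):
    iv_aref, next_aref = Aref(), Aref()
    out = []
    threads = [
        threading.Thread(target=partition0, args=(iv_aref, next_aref, init, compute_next, n)),
        threading.Thread(target=partition1, args=(iv_aref, next_aref, n, out)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    puts = iv_aref.puts + next_aref.puts
    gets = iv_aref.gets + next_aref.gets
    return out, puts, gets
```

For n iterations this performs 2n puts and 2n gets, and the `iv_0` received in iteration i+1 is always equal to the `next_0` received in iteration i.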

Ideally, this can be fully optimized down to a single aref:

%iv_aref = aref_create
aref_put %init, %iv_aref {partition = 0}

%init_0 = aref_get %iv_aref {partition = 1}
scf.for ... (%iv = %init, %iv_0 = %init_0) {
  %next = compute_next %iv {partition = 0}
  aref_put %next, %iv_aref {partition = 0}

  %next_0 = aref_get %iv_aref {partition = 1}
  use_value(%iv_0, %next_0) {partition = 1}

  scf.yield %next, %next_0
}

This results in 1 aref write and 1 aref read per iteration of the inner loop, plus one of each in the prologue for %init. We could also multibuffer %iv_aref with 2 buffers to improve overlap between the partitions. This would require cloning the whole loop-carried value across both partitions.
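The single-aref version can be modeled the same way. Again this is a sketch under assumptions: the hypothetical `Aref` class treats an aref as a bounded channel whose `depth` plays the role of the number of buffers, and `partition0`, `partition1` and `run_single_aref` are illustrative names. Partition 1 keeps its own loop-carried copy (`iv_0`), so each value crosses between the partitions exactly once.

```python
import queue
import threading


class Aref:
    """Hypothetical aref stand-in: a bounded SPSC channel; depth = number of buffers."""

    def __init__(self, depth=1):
        self.q = queue.Queue(maxsize=depth)
        self.puts = 0  # only touched by the producer thread
        self.gets = 0  # only touched by the consumer thread

    def put(self, value):
        self.puts += 1
        self.q.put(value)  # blocks while all `depth` buffers are full

    def get(self):
        self.gets += 1
        return self.q.get()  # blocks until a value is available


def partition0(aref, init, compute_next, n):
    aref.put(init)          # prologue: aref_put %init, %iv_aref
    iv = init
    for _ in range(n):
        nxt = compute_next(iv)
        aref.put(nxt)       # aref_put %next, %iv_aref
        iv = nxt            # scf.yield %next


def partition1(aref, n, out):
    iv_0 = aref.get()                # prologue: %init_0 = aref_get %iv_aref
    for _ in range(n):
        next_0 = aref.get()          # %next_0 = aref_get %iv_aref
        out.append((iv_0, next_0))   # use_value(%iv_0, %next_0)
        iv_0 = next_0                # scf.yield ..., %next_0


def run_single_aref(init, compute_next, n, depth=1):
    aref = Aref(depth)
    out = []
    threads = [
        threading.Thread(target=partition0, args=(aref, init, compute_next, n)),
        threading.Thread(target=partition1, args=(aref, n, out)),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out, aref.puts, aref.gets
```

For n iterations this performs n + 1 puts and n + 1 gets instead of 2n of each. Passing `depth=2` models the multibuffered variant: partition 0 can run one value ahead of partition 1 without changing the pairs that `use_value` sees.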
