[Bug] DataflowUseInplaceCalls corrupts concat input via in-place buffer reuse #19577

@wuyii8941

Summary

relax.transform.DataflowUseInplaceCalls() produces incorrect results when a value is used both as an input to a binary op (multiply) and as a later input to concat. The pass incorrectly lets multiply write its result in place into that value's buffer, corrupting the value before concat reads it.

Minimal Reproducer

import numpy as np
import tvm
from tvm import relax
import tvm.relax.op as R
from tvm.relax.transform import LegalizeOps

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((4,), "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        a = bb.emit(R.expand_dims(x, axis=1))      # (4,1)
        b = bb.emit(R.expand_dims(x, axis=1))      # (4,1)
        prod = bb.emit(R.multiply(a, b))             # (4,1) = x^2
        out = bb.emit(R.concat([prod, b], axis=1))   # (4,2), expected [x^2, x]
        gv = bb.emit_output(out)
    bb.emit_func_output(gv)
mod = bb.finalize()

x_np = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# Correct result (without pass):
# [[ 1,  1], [ 4,  2], [ 9,  3], [16,  4]]

# Incorrect result (with pass):
# [[ 1,  1], [ 4,  4], [ 9,  9], [16, 16]]
# Column 1 should be x=[1,2,3,4] but gets x^2=[1,4,9,16]

Expected Behavior

concat([prod, b], axis=1) should produce one row [x^2, x] per element of x: column 0 holds x^2 and column 1 holds x.

Actual Behavior

After DataflowUseInplaceCalls, column 1 also contains x^2. The multiply operation writes in-place to b's buffer, so when concat reads b, it sees the overwritten x^2 values instead of the original x values.
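The corruption described above can be imitated outside TVM. The following is a minimal NumPy sketch, not TVM code: the function names `reference` and `unsafe_inplace` are hypothetical, and `np.multiply(..., out=b)` only stands in for the in-place call that the pass is assumed to generate. It shows that writing the product into b's buffer yields exactly the incorrect matrix reported in the reproducer.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)


def reference(x):
    # Every intermediate value gets its own buffer.
    a = np.expand_dims(x, axis=1).copy()  # (4, 1)
    b = np.expand_dims(x, axis=1).copy()  # (4, 1)
    prod = np.multiply(a, b)              # fresh output buffer
    return np.concatenate([prod, b], axis=1)


def unsafe_inplace(x):
    # Same computation, but the product is written into b's buffer,
    # imitating an in-place rewrite that ignores the later read of b.
    a = np.expand_dims(x, axis=1).copy()
    b = np.expand_dims(x, axis=1).copy()
    prod = np.multiply(a, b, out=b)       # prod and b now share storage
    return np.concatenate([prod, b], axis=1)


expected = reference(x)
corrupted = unsafe_inplace(x)
print(expected)
print(corrupted)
```

In the second function, column 1 ends up holding x^2 because b no longer contains x by the time the concatenation reads it.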

Root Cause

The liveness analysis in DataflowUseInplaceCalls fails to detect that b is still live after multiply(a, b) executes, even though concat reads b later. As a result, the pass incorrectly determines that b's buffer can be reused for the multiply output.

The pattern that triggers this is:

  1. Two values a and b derived from the same input (both are expand_dims(x))
  2. prod = multiply(a, b) — the pass allows in-place write to b's storage
  3. concat([prod, b]) — reads b after it was overwritten
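The liveness condition that the pattern above violates can be sketched with a toy last-use check. This is an illustration only, not the analysis implemented in DataflowUseInplaceCalls: the binding representation and the helper names `last_uses` and `inplace_candidates` are hypothetical, and the sketch deliberately ignores aliasing between values.

```python
from typing import Dict, List, Tuple

# Each binding is (result_name, op_name, operand_names), in program order.
Binding = Tuple[str, str, List[str]]


def last_uses(bindings: List[Binding], outputs: List[str]) -> Dict[str, int]:
    """Map each value to the index of the last binding that reads it."""
    last: Dict[str, int] = {}
    for idx, (_, _, operands) in enumerate(bindings):
        for name in operands:
            last[name] = idx
    # Values returned from the block stay live past the final binding.
    for name in outputs:
        last[name] = len(bindings)
    return last


def inplace_candidates(
    bindings: List[Binding], outputs: List[str], idx: int
) -> List[str]:
    """Operands of binding `idx` that no later binding reads."""
    last = last_uses(bindings, outputs)
    _, _, operands = bindings[idx]
    return [name for name in operands if last[name] == idx]


# The pattern from this issue.
program: List[Binding] = [
    ("a", "expand_dims", ["x"]),
    ("b", "expand_dims", ["x"]),
    ("prod", "multiply", ["a", "b"]),
    ("out", "concat", ["prod", "b"]),
]

candidates = inplace_candidates(program, ["out"], 2)
print(candidates)  # ['a']
```

Under this simplified rule only a would be eligible for reuse at the multiply, since its last read is the multiply itself. b is read again by concat, so overwriting it there is unsafe.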

Impact

This is a silent correctness bug — no crash, no warning, just wrong numerical results. It affects any model where:

  • A value is used as input to both a binary op and a later concat/use
  • DataflowUseInplaceCalls is in the optimization pipeline (it's included in standard pipelines)

Environment

  • TVM version: 0.24.dev0
  • Target: llvm (CPU)


Labels: needs-triage, type: bug
