
[bug] Implicit sequence state mapping swaps states when output_name lexicographic order differs from input_name order #8586

@Dan1aR

Description


When using sequence batching + implicit state with multiple state tensors, Triton can swap cached states across requests if the lexicographic order of output_names differs from the lexicographic order of input_names. This manifests as Triton injecting the wrong state tensor into the wrong model input on the second request of a sequence.

A concrete failing pair is:

  • cache_last_channel (FP16, 4D)
  • cache_last_channel_len (INT64 or INT32, 2D)

With outputs:

  • cache_last_channel_next
  • cache_last_channel_len_next

On the second request, Triton injects the len state into cache_last_channel and the channel state into cache_last_channel_len, producing a TensorRT binding-dimension error on the server and a client-side “invalid shape for input” failure.

Workaround: Renaming the “_len” state to avoid being a prefix-extension of the other state name (e.g., cache_last_chan_len) fixes the issue in both Triton 25.04 and 25.12.

Strong suspected root cause (core): SequenceStates::OutputState() pairs output state to input state by index across two different std::maps (via std::distance/std::advance), rather than by the configured (input_name, output_name) mapping. That index-based pairing is incorrect whenever the two maps’ key ordering differs, which is exactly what happens with names like cache_last_channel_next vs cache_last_channel_len_next. See the detailed analysis and sketch further below.


Environment

Host / HW

  • GPU: NVIDIA L40S
  • Driver: 535.274.02
  • CUDA: 12.2

Triton containers tested

  • nvcr.io/nvidia/tritonserver:25.04-py3
    • Triton Server: 2.57.0
    • TensorRT: 10.9.0
  • nvcr.io/nvidia/tritonserver:25.12-py3
    • Triton Server: 2.64.0
    • TensorRT: 10.14.1

Client

  • Python tritonclient gRPC
  • Uses sequence_id with sequence_start / sequence_end

Model

  • TensorRT plan built from ONNX export
  • Implicit state via sequence_batching { state [...] }

Model state interface

State inputs (implicit)

  • cache_last_time: FP16, (B, 18, 384, 30)
  • cache_last_channel: FP16, (B, 32, 8, 50)
  • cache_last_channel_len: INT32 or INT64, (B, 1)

State outputs (implicit)

  • cache_last_time_next
  • cache_last_channel_next
  • cache_last_channel_len_next

Triton model config snippet (relevant part)

sequence_batching {
  max_sequence_idle_microseconds: 15000000
  oldest { max_queue_delay_microseconds: 10000 max_candidate_sequences: 4096 }

  state: [
    {
      input_name: "cache_last_channel_len"
      output_name: "cache_last_channel_len_next"
      data_type: TYPE_INT64  # also reproduced with INT32
      dims: [ 1 ]
      initial_state: { name: "cache_last_channel_len_initial" data_type: TYPE_INT64 dims: [ 1 ] zero_data: true }
    },
    {
      input_name: "cache_last_channel"
      output_name: "cache_last_channel_next"
      data_type: TYPE_FP16
      dims: [ 32, 8, 50 ]
      initial_state: { name: "cache_last_channel_initial" data_type: TYPE_FP16 dims: [ 32, 8, 50 ] zero_data: true }
    },
    {
      input_name: "cache_last_time"
      output_name: "cache_last_time_next"
      data_type: TYPE_FP16
      dims: [ 18, 384, 30 ]
      initial_state: { name: "cache_last_time_initial" data_type: TYPE_FP16 dims: [ 18, 384, 30 ] zero_data: true }
    }
  ]
}

Steps to reproduce

  1. Export ONNX that includes the three states (time/channel/channel_len).
  2. Build a TensorRT engine from the ONNX (trtexec). (I can provide these resources if needed.)
  3. Start Triton with verbose logs enabled, e.g.:

    tritonserver --model-repository=/models --log-verbose=4

  4. Run a minimal client (see the sketch after this list) that:

    • sends only the non-state inputs (audio_signal, length)
    • uses sequence_id
    • sets sequence_start=True on step 0
    • requests the output state tensors (cache_last_*_next) to confirm correctness
  5. Observe:

    • step 0 succeeds and returns correct shapes/dtypes for all outputs
    • step 1 fails with an invalid shape error for cache_last_channel (4D expected, 2D provided)
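
For reference, a minimal sketch of such a client using the Python tritonclient gRPC API is below. The model name ("asr_streaming") and the audio_signal/length shapes and dtypes are placeholders I am using for illustration; adjust them to the actual engine. The requested state output names come from the config above.

# Minimal repro sketch. "asr_streaming" and the audio_signal/length
# shapes/dtypes are placeholders; adjust to the actual engine.
import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

def step(seq_id, start, end):
    audio = np.zeros((1, 80, 51), dtype=np.float16)  # placeholder chunk
    length = np.array([[51]], dtype=np.int32)        # placeholder length

    inputs = [
        grpcclient.InferInput("audio_signal", list(audio.shape), "FP16"),
        grpcclient.InferInput("length", list(length.shape), "INT32"),
    ]
    inputs[0].set_data_from_numpy(audio)
    inputs[1].set_data_from_numpy(length)

    # Request the output states to confirm their shapes/dtypes per step.
    outputs = [grpcclient.InferRequestedOutput(name) for name in (
        "cache_last_time_next",
        "cache_last_channel_next",
        "cache_last_channel_len_next",
    )]
    return client.infer("asr_streaming", inputs, outputs=outputs,
                        sequence_id=seq_id, sequence_start=start,
                        sequence_end=end)

step(seq_id=1, start=True, end=False)   # step 0: succeeds
step(seq_id=1, start=False, end=False)  # step 1: fails with the swap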

Expected behavior

Across requests in the same sequence, Triton should:

  • feed cache_last_channel_next from request i into cache_last_channel for request i+1
  • feed cache_last_channel_len_next from request i into cache_last_channel_len for request i+1
  • never swap or alias state buffers between different state entries

Actual behavior

  • Request 0 (step 0) succeeds, state outputs look correct.

  • Request 1 (step 1) fails.

  • Client error:

    [StatusCode.INTERNAL] request specifies invalid shape for input 'cache_last_channel' ...
    model expected 4 dimensions but received 2 dimensions
    
  • Server error (25.12):

    instance_state.cc:767] "error setting the binding dimension"
    
  • High-verbosity logs show that on the failing request, the injected override inputs are swapped:

    • cache_last_channel becomes INT64/INT32 with shape [1,1]
    • cache_last_channel_len becomes FP16 with shape [1,32,8,50]

Evidence (log excerpt showing the swap on the second request)

On the second request (step 1), Triton logs show:

added input override for cache_last_channel:
  input: cache_last_channel, type: INT64, original shape: [1,1], ...

added input override for cache_last_channel_len:
  input: cache_last_channel_len, type: FP16, original shape: [1,32,8,50], ...
...
E... instance_state.cc:767] "error setting the binding dimension"

This is consistent with the model expecting cache_last_channel to be 4D FP16, but receiving the 2D integer length state instead.


Key observation / workaround

Renaming the length state so it is no longer a prefix extension of another state name resolves the issue:

  • Fails: cache_last_channel + cache_last_channel_len
  • Works: cache_last_channel + cache_last_chan_len (or any non-prefix-colliding name)

This fixes the issue in both:

  • 25.04-py3 (Triton 2.57.0)
  • 25.12-py3 (Triton 2.64.0)

Notably, changing the length dtype (INT32 ↔ INT64) does not fix the issue. The bug appears related to name-based ordering/mapping, not dtype.


Suspected root cause (core r25.12): index-based pairing across std::map

In triton-inference-server/core (branch r25.12), the suspected problem is in sequence_state.cc:
https://github.com/triton-inference-server/core/blob/r25.12/src/sequence_state.cc#L309

  • SequenceStates::OutputState() finds an output state by name in output_states_ (std::map keyed by output_name).
  • It then computes an index via std::distance(output_states_.begin(), it) and uses that index to select the “corresponding” input state from input_states_ by advancing an iterator by the same amount.
  • But input_states_ is a different std::map, keyed by input_name.

This implicitly assumes that input_states_ and output_states_ have identical key ordering. That is not generally true and fails deterministically when the sorted key order differs.
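
A few lines of Python make the misalignment concrete, simulating std::map's sorted key order with sorted() and reusing the output's index to pick the input, as the suspected code does; the walkthrough below traces the same thing on the concrete names:

# Simulates the suspected pairing in SequenceStates::OutputState():
# both containers are keyed (and therefore sorted) by their own names,
# and the output's index is reused to pick the "corresponding" input.
input_states = sorted([
    "cache_last_channel", "cache_last_channel_len", "cache_last_time"])
output_states = sorted([
    "cache_last_channel_next", "cache_last_channel_len_next",
    "cache_last_time_next"])

for idx, out_name in enumerate(output_states):
    print(f"{out_name} (index {idx}) -> {input_states[idx]}")

# Prints (annotations added):
#   cache_last_channel_len_next (index 0) -> cache_last_channel      wrong
#   cache_last_channel_next (index 1) -> cache_last_channel_len      wrong
#   cache_last_time_next (index 2) -> cache_last_time                ok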

Why the prefix-collision example deterministically swaps

Given:

Input keys (sorted by std::map):

  1. cache_last_channel
  2. cache_last_channel_len
  3. cache_last_time

Output keys (sorted by std::map):

  1. cache_last_channel_len_next
  2. cache_last_channel_next
  3. cache_last_time_next

Because "cache_last_channel_len_next" < "cache_last_channel_next" lexicographically, indices don’t align.

So:

  • OutputState("cache_last_channel_next") is index 1 in output_states_
    → picks index 1 in input_states_
    maps to cache_last_channel_len (wrong)
  • OutputState("cache_last_channel_len_next") is index 0 in output_states_
    → picks index 0 in input_states_
    maps to cache_last_channel (wrong)

This exactly matches the observed swap in the logs.
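
For contrast, pairing by the configured (input_name, output_name) mapping instead of by shared index is order independent. An illustrative Python sketch (not core's actual API):

# Order-independent pairing built from the configured state entries
# (names taken from the config above; illustrative only).
state_config = [
    ("cache_last_channel_len", "cache_last_channel_len_next"),
    ("cache_last_channel", "cache_last_channel_next"),
    ("cache_last_time", "cache_last_time_next"),
]
output_to_input = {out: inp for inp, out in state_config}

assert output_to_input["cache_last_channel_next"] == "cache_last_channel"
assert output_to_input["cache_last_channel_len_next"] == "cache_last_channel_len"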


Additional notes

  • Repro is independent of integer dtype (INT32 and INT64 both reproduce).
  • Repro is present in both 25.04 and 25.12 containers.
  • Renaming the state to avoid prefix/lexicographic collision is an effective workaround but not acceptable as a general requirement.
