Description
When using sequence batching + implicit state with multiple state tensors, Triton can swap cached states across requests if the lexicographic order of output_names differs from the lexicographic order of input_names. This manifests as Triton injecting the wrong state tensor into the wrong model input on the 2nd request of a sequence.
A concrete failing pair is:
- cache_last_channel (FP16, 4D)
- cache_last_channel_len (INT64 or INT32, 2D)
With outputs:
- cache_last_channel_next
- cache_last_channel_len_next
On request number 2, Triton injects the len state into cache_last_channel and the channel state into cache_last_channel_len, producing a TensorRT binding dimension error and a client-side “invalid shape for input” failure.
Workaround: Renaming the “_len” state to avoid being a prefix-extension of the other state name (e.g., cache_last_chan_len) fixes the issue in both Triton 25.04 and 25.12.
Strongly suspected root cause (core): SequenceStates::OutputState() pairs an output state with an input state by index across two different std::maps (via std::distance/std::advance), rather than by the configured (input_name, output_name) mapping. That index-based pairing is incorrect whenever the two maps' key ordering differs - exactly what happens with names like cache_last_channel_next vs cache_last_channel_len_next.
Environment
Host / HW
- GPU: NVIDIA L40S
- Driver: 535.274.02
- CUDA: 12.2
Triton containers tested
- nvcr.io/nvidia/tritonserver:25.04-py3
  - Triton Server: 2.57.0
  - TensorRT: 10.9.0
- nvcr.io/nvidia/tritonserver:25.12-py3
  - Triton Server: 2.64.0
  - TensorRT: 10.14.1
Client
- Python tritonclient gRPC
- Uses sequence_id with sequence_start / sequence_end
Model
- TensorRT plan built from ONNX export
- Implicit state via sequence_batching { state [...] }
Model state interface
State inputs (implicit)
- cache_last_time: FP16, (B, 18, 384, 30)
- cache_last_channel: FP16, (B, 32, 8, 50)
- cache_last_channel_len: INT32 or INT64, (B, 1)
State outputs (implicit)
- cache_last_time_next
- cache_last_channel_next
- cache_last_channel_len_next
Triton model config snippet (relevant part)
sequence_batching {
max_sequence_idle_microseconds: 15000000
oldest { max_queue_delay_microseconds: 10000 max_candidate_sequences: 4096 }
state: [
{
input_name: "cache_last_channel_len"
output_name: "cache_last_channel_len_next"
data_type: TYPE_INT64 # also reproduced with INT32
dims: [ 1 ]
initial_state: { name: "cache_last_channel_len_initial" data_type: TYPE_INT64 dims: [ 1 ] zero_data: true }
},
{
input_name: "cache_last_channel"
output_name: "cache_last_channel_next"
data_type: TYPE_FP16
dims: [ 32, 8, 50 ]
initial_state: { name: "cache_last_channel_initial" data_type: TYPE_FP16 dims: [ 32, 8, 50 ] zero_data: true }
},
{
input_name: "cache_last_time"
output_name: "cache_last_time_next"
data_type: TYPE_FP16
dims: [ 18, 384, 30 ]
initial_state: { name: "cache_last_time_initial" data_type: TYPE_FP16 dims: [ 18, 384, 30 ] zero_data: true }
}
]
}
Steps to reproduce
- Export ONNX that includes the three states (time/channel/channel_len).
- Build the TensorRT engine from the ONNX export (trtexec). I can provide resources if needed.
- Start Triton with verbose logs enabled, e.g.:
  tritonserver --model-repository=/models --log-verbose=4
- Run a minimal client that (a sketch follows this list):
  - sends only the non-state inputs (audio_signal, length)
  - uses sequence_id
  - sets sequence_start=True on step 0
  - requests output state tensors (cache_last_*_next) to confirm correctness
- Observe:
  - step 0 succeeds and returns correct shapes/dtypes for all outputs
  - step 1 fails with an invalid shape error for cache_last_channel (4D expected, 2D provided)
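For reference, a minimal sketch of such a client using the Python tritonclient gRPC API. The model name, the audio_signal/length shapes, and the length dtype below are placeholders (they are not specified in this report); the sequence flags and the requested *_next outputs are the relevant part.

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Placeholders: model name, shapes, dtypes, and server URL are illustrative only.
MODEL = "streaming_asr"       # hypothetical model name
AUDIO_SHAPE = [1, 80, 160]    # hypothetical audio_signal shape
SEQ_ID = 1001                 # any non-zero sequence id

client = grpcclient.InferenceServerClient("localhost:8001")

def make_inputs():
    # Only the non-state inputs are sent; state inputs are implicit.
    audio = grpcclient.InferInput("audio_signal", AUDIO_SHAPE, "FP16")
    audio.set_data_from_numpy(np.zeros(AUDIO_SHAPE, dtype=np.float16))
    length = grpcclient.InferInput("length", [1, 1], "INT64")
    length.set_data_from_numpy(np.array([[AUDIO_SHAPE[-1]]], dtype=np.int64))
    return [audio, length]

# Request the *_next state outputs to inspect their shapes/dtypes on each step.
outputs = [
    grpcclient.InferRequestedOutput(name)
    for name in ("cache_last_time_next", "cache_last_channel_next",
                 "cache_last_channel_len_next")
]

for step in range(2):
    result = client.infer(
        MODEL,
        inputs=make_inputs(),
        outputs=outputs,
        sequence_id=SEQ_ID,
        sequence_start=(step == 0),  # only the first request starts the sequence
        sequence_end=False,
    )
    for name in ("cache_last_channel_next", "cache_last_channel_len_next"):
        print(step, name, result.as_numpy(name).shape)
    # With the prefix-colliding state names, step 1 fails inside infer() with:
    # [StatusCode.INTERNAL] request specifies invalid shape for input 'cache_last_channel' ...
```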
Expected behavior
Across requests in the same sequence, Triton should:
- feed cache_last_channel_next from request i into cache_last_channel for request i+1
- feed cache_last_channel_len_next from request i into cache_last_channel_len for request i+1
- never swap or alias state buffers between different state entries
Actual behavior
- Request 0 (step 0) succeeds; state outputs look correct.
- Request 1 (step 1) fails.
- Client error:
  [StatusCode.INTERNAL] request specifies invalid shape for input 'cache_last_channel' ... model expected 4 dimensions but received 2 dimensions
- Server error (25.12):
  instance_state.cc:767] "error setting the binding dimension"
- High-verbosity logs show that on the failing request, the injected override inputs are swapped:
  - cache_last_channel becomes INT64/INT32 with shape [1,1]
  - cache_last_channel_len becomes FP16 with shape [1,32,8,50]
Evidence (log excerpt showing the swap on the 2nd request)
On request 2 (step 1), Triton logs show:
added input override for cache_last_channel:
input: cache_last_channel, type: INT64, original shape: [1,1], ...
added input override for cache_last_channel_len:
input: cache_last_channel_len, type: FP16, original shape: [1,32,8,50], ...
...
E... instance_state.cc:767] "error setting the binding dimension"
This is consistent with the model expecting cache_last_channel to be 4D FP16, but receiving the 2D integer length state instead.
Key observation / workaround
Renaming the length state to avoid a prefix extension of another state name resolves the issue:
- Fails: cache_last_channel + cache_last_channel_len
- Works: cache_last_channel + cache_last_chan_len (or any non-prefix-colliding name)
This fixes the issue in both:
- 25.04-py3 (Triton 2.57.0)
- 25.12-py3 (Triton 2.64.0)
Notably, changing the length dtype (INT32 ↔ INT64) does not fix the issue. The bug appears related to name-based ordering/mapping, not dtype.
Suspected root cause (core r25.12): index-based pairing across std::map
https://github.com/triton-inference-server/core/blob/r25.12/src/sequence_state.cc#L309
In triton-inference-server/core (branch r25.12), the suspected problem is in sequence_state.cc:
- SequenceStates::OutputState() finds an output state by name in output_states_ (a std::map keyed by output_name).
- It then computes an index via std::distance(output_states_.begin(), it) and uses that index to select the "corresponding" input state from input_states_ by advancing an iterator by the same amount.
- But input_states_ is a different std::map, keyed by input_name.
This implicitly assumes that input_states_ and output_states_ have identical key ordering. That is not generally true and fails deterministically when the sorted key order differs.
Why the prefix-collision example deterministically swaps
Given:
Input keys (sorted by std::map):
- cache_last_channel
- cache_last_channel_len
- cache_last_time

Output keys (sorted by std::map):
- cache_last_channel_len_next
- cache_last_channel_next
- cache_last_time_next
Because "cache_last_channel_len_next" < "cache_last_channel_next" lexicographically, indices don’t align.
So:
- OutputState("cache_last_channel_next") is index 1 in output_states_ → picks index 1 in input_states_ → maps to cache_last_channel_len (wrong)
- OutputState("cache_last_channel_len_next") is index 0 in output_states_ → picks index 0 in input_states_ → maps to cache_last_channel (wrong)
This exactly matches the observed swap in the logs.
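A short sketch of why this index-based pairing goes wrong for these names. It uses sorted Python lists to mirror std::map's lexicographic key ordering for ASCII strings; the real code is C++, so this only illustrates the indexing mismatch, not the actual implementation:

```python
input_names = sorted([
    "cache_last_channel",
    "cache_last_channel_len",
    "cache_last_time",
])
output_names = sorted([
    "cache_last_channel_next",
    "cache_last_channel_len_next",
    "cache_last_time_next",
])

# The mapping actually configured in sequence_batching.state: output -> input.
configured = {
    "cache_last_channel_next": "cache_last_channel",
    "cache_last_channel_len_next": "cache_last_channel_len",
    "cache_last_time_next": "cache_last_time",
}

# Index-based pairing, i.e. what OutputState() is suspected to do.
for idx, out_name in enumerate(output_names):
    in_name = input_names[idx]  # same index in a *differently* ordered map
    ok = "OK" if configured[out_name] == in_name else "SWAPPED"
    print(f"{out_name} -> {in_name}  [{ok}]")

# Output:
# cache_last_channel_len_next -> cache_last_channel  [SWAPPED]
# cache_last_channel_next -> cache_last_channel_len  [SWAPPED]
# cache_last_time_next -> cache_last_time  [OK]
```

Pairing by the configured (input_name, output_name) entry instead of by position would be insensitive to key ordering.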
Additional notes
- Repro is independent of integer dtype (INT32 and INT64 both reproduce).
- Repro is present in both 25.04 and 25.12 containers.
- Renaming the state to avoid the prefix/lexicographic collision is an effective workaround, but it is not acceptable as a general requirement.
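Until the pairing is fixed, a small unofficial pre-flight check along these lines (a sketch, not part of Triton) can flag state entries whose position in the sorted input-name order differs from their position in the sorted output-name order, i.e. names that the suspected index-based lookup would mis-pair:

```python
def susceptible_states(states):
    """states: list of (input_name, output_name) pairs from sequence_batching.state.
    Returns the entries whose index in the sorted input-name order differs from
    their index in the sorted output-name order (i.e., would be mis-paired)."""
    inputs_sorted = sorted(s[0] for s in states)
    outputs_sorted = sorted(s[1] for s in states)
    return [
        (in_name, out_name)
        for in_name, out_name in states
        if inputs_sorted.index(in_name) != outputs_sorted.index(out_name)
    ]

states = [
    ("cache_last_channel_len", "cache_last_channel_len_next"),
    ("cache_last_channel", "cache_last_channel_next"),
    ("cache_last_time", "cache_last_time_next"),
]
print(susceptible_states(states))
# [('cache_last_channel_len', 'cache_last_channel_len_next'),
#  ('cache_last_channel', 'cache_last_channel_next')]
```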