
Commit 7aa4058

Fix teacher-forced decode loop: avoid scalar-constant specialization and cache_position layout drift
Teacher forcing was feeding a per-step scalar token (`ground_truth_tokens[step].to(device)`). On XLA-style backends a 0-dim scalar commonly takes the scalar-constant path, which specializes the compiled program on the token value; during decode this produces many unique programs (one per distinct token) and can blow out instruction/L1 caches. Fix: slice on CPU to a stable-shaped `[1, 1]` tensor each step and transfer it as runtime data, then expand to `[batch, 1]` and materialize a contiguous buffer to avoid broadcast/stride issues.

Separately, `cache_position` updates done on-device produced an si32 buffer with a different (non-tiled) layout than the compiled model expects (tiled si32), leading to a TTIR-to-TTNN compilation failure on Gemma. Fix: round-trip `cache_position` through CPU — normalize to shape `[1]` via `reshape(-1)[-1:]`, increment on host, then re-upload so the device import path restores the expected layout.
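The stable-shape pattern can be sketched in isolation (CPU-only; `next_input_ids`, `gt`, and the batch size here are illustrative names, not part of the commit): slicing with `step : step + 1` yields a 1-element tensor rather than a 0-dim scalar, so every decode step presents the compiler with a same-shaped runtime input instead of a baked-in constant.

```python
import torch

def next_input_ids(ground_truth_tokens: torch.Tensor, step: int, batch_size: int) -> torch.Tensor:
    """Return a contiguous [batch_size, 1] token tensor for the next decode step."""
    # step : step + 1 keeps a 1-element tensor (never a 0-dim scalar),
    # so the token enters the step as runtime data with a stable shape.
    tok = ground_truth_tokens[step : step + 1].view(1, 1)
    # Expand to the batch and materialize a contiguous buffer to avoid
    # broadcast/stride surprises on transfer.
    return tok.expand(batch_size, 1).contiguous()

gt = torch.tensor([5, 9, 2, 7], dtype=torch.long)
shapes = {next_input_ids(gt, s, batch_size=3).shape for s in range(len(gt))}
print(shapes)  # every step produces the same shape: {torch.Size([3, 1])}
```

Because the shape (and dtype) never change across steps, a tracing compiler can reuse one compiled program for the entire decode loop.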
1 parent 34392af commit 7aa4058

File tree

1 file changed: +11 −14 lines


benchmark/tt-xla/decode_utils.py

Lines changed: 11 additions & 14 deletions
```diff
@@ -77,6 +77,9 @@ def teacher_forced_generate(
 
     assert_eval_no_dropout(model, verbose=verbose)
 
+    # Capture batch size before the loop replaces input_args["input_ids"].
+    batch_size = input_args["input_ids"].shape[0]
+
     output_tokens: list[list[str]] = []
     output_logits: list[torch.Tensor] = []
     predicted_tokens: list[int] = []
@@ -97,22 +100,16 @@
         output_text = [tokenizer.decode(token_id) for token_id in next_token_ids]
         output_tokens.append(output_text)
 
-        # Teacher forcing update: feed ground truth for next step.
-        if step < ground_truth_tokens.shape[0]:
-            batch_size = input_args["input_ids"].shape[0]
-            gt_token = ground_truth_tokens[step].to(device)
-            input_args["input_ids"] = gt_token.view(1, 1).expand(batch_size, 1).contiguous()
-        else:
-            # If caller asks for more steps than ground truth provides, keep feeding last GT token.
-            batch_size = input_args["input_ids"].shape[0]
-            gt_token = ground_truth_tokens[-1].to(device)
-            input_args["input_ids"] = gt_token.view(1, 1).expand(batch_size, 1).contiguous()
-
-        host_cache_pos = input_args["cache_position"].to("cpu")
-        host_cache_pos = torch.tensor([host_cache_pos[-1:] + 1])
-        input_args["cache_position"] = host_cache_pos.to(device)
+        # Teacher forcing: keep token as runtime data (stable shape) to avoid scalar-constant specialization.
+        next_tok_host = ground_truth_tokens[step : step + 1].view(1, 1)  # CPU [1,1]
+        input_args["input_ids"] = next_tok_host.expand(batch_size, 1).contiguous().to(device)
+
+        # cache_position: host normalize/update to keep a stable [1] shape.
+        host_cache_pos = input_args["cache_position"].to("cpu").reshape(-1)[-1:]  # CPU [1]
+        input_args["cache_position"] = (host_cache_pos + 1).to(device)
 
         iteration_times_ns.append(time.perf_counter_ns() - start)
+
         if verbose:
             print(f"Iteration\t{step}/{max_tokens_to_generate}\ttook {iteration_times_ns[-1] / 1e6:.04} ms")
 
```
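The `cache_position` fix likewise reduces to a small, backend-agnostic sketch (plain CPU tensors here; the re-upload to device via `.to(device)` is elided, and the starting positions are illustrative):

```python
import torch

# Suppose cache_position arrives as the prefill positions [0..7] in si32.
cache_position = torch.arange(8, dtype=torch.int32)

# Normalize to shape [1] (the last position), regardless of incoming shape...
host_cache_pos = cache_position.reshape(-1)[-1:]

# ...then increment on host. Re-uploading this freshly built host tensor lets
# the device import path choose the (tiled) layout the compiled model expects,
# instead of inheriting the layout of an on-device intermediate.
next_cache_position = host_cache_pos + 1
print(next_cache_position)  # tensor([8], dtype=torch.int32)
```

The key property is that the arithmetic happens on host, so the tensor handed back to the device is a clean import rather than the output of an on-device op with a drifted layout.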
