Commit 7aa4058
Fix teacher-forced decode loop: avoid scalar-constant specialization and cache_position layout drift
Teacher forcing was feeding a per-step scalar token (ground_truth_tokens[step].to(device)).
On XLA-style backends this commonly takes the scalar-constant path, which can specialize the
compiled program on the token value. In decode this produces many unique programs (one per
token) and can blow out the instruction/L1 caches.
Fix by slicing on CPU into a stable-shaped [1, 1] tensor each step and transferring it as runtime
data. Expand it to [batch, 1] and materialize a contiguous buffer to avoid broadcast/stride issues.
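A minimal sketch of the fixed step, assuming a PyTorch-style decode loop; `device`, `batch`, and the initial `ground_truth_tokens` value below are hypothetical stand-ins for the real decode state:

```python
import torch

device = "cpu"   # stand-in for the real accelerator device
batch = 1        # stand-in for the real batch size
ground_truth_tokens = torch.tensor([5, 17, 42, 9])  # gold token ids, kept on CPU

for step in range(ground_truth_tokens.shape[0]):
    # Slice on CPU so every step yields the same shape [1, 1]; the token then
    # reaches the device as runtime data instead of a baked-in scalar constant,
    # so one compiled program serves all decode steps.
    next_token = ground_truth_tokens[step : step + 1].reshape(1, 1)
    # Expand to [batch, 1] and materialize a contiguous buffer so no stride-0
    # broadcast view leaks into the device transfer.
    input_ids = next_token.expand(batch, 1).contiguous().to(device)
    # ... feed input_ids to the compiled model's decode step ...
```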
cache_position updates done on-device produced an si32 buffer with a different (non-tiled)
layout than the compiled model expects (tiled si32), leading to a TTIR-to-TTNN compilation
failure on Gemma. Fix by round-tripping cache_position through CPU: normalize to shape [1] via
reshape(-1)[-1:], increment on host, then re-upload so the device import path restores the
expected layout.
1 file changed: +11 −14 lines changed