src/chatterbox/models/t3/inference/alignment_stream_analyzer.py (46 changes: 34 additions & 12 deletions)
@@ -66,25 +66,47 @@ def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_id

     def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
         """
-        Adds a forward hook to a specific attention layer to collect outputs.
+        Adds a forward wrapper to a specific attention layer to collect outputs.
+
+        This approach only forces eager attention and output_attentions=True for the specific
+        layers that need attention weights, allowing other layers to use optimized SDPA.
         """
-        def attention_forward_hook(module, input, output):
+        target_layer = tfmr.layers[layer_idx].self_attn
+        original_forward = target_layer.forward
+
+        def wrapped_forward(*args, **kwargs):
             """
-            See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
-            NOTE:
-            - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
-            - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th.
+            Wraps the attention forward to:
+            1. Temporarily switch to eager attention for this layer
+            2. Force output_attentions=True
+            3. Capture attention weights
+            4. Restore original settings
             """
+            # Save original attn_implementation
+            config = target_layer.config
+            original_attn_impl = getattr(config, '_attn_implementation', None)
+
+            # Force eager attention for this layer (SDPA doesn't support output_attentions)
+            config._attn_implementation = 'eager'
+            kwargs['output_attentions'] = True
+
+            try:
+                output = original_forward(*args, **kwargs)
+            finally:
+                # Restore original attn_implementation
+                if original_attn_impl is not None:
+                    config._attn_implementation = original_attn_impl
+
+            # Capture attention weights
+            # output is a tuple: (attn_output, attn_weights, past_key_value)
             if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
                 step_attention = output[1].cpu() # (B, n_heads, T0, Ti)
                 self.last_aligned_attns[buffer_idx] = step_attention[0, head_idx] # (T0, Ti)
 
-        target_layer = tfmr.layers[layer_idx].self_attn
-        # Register hook and store the handle
-        target_layer.register_forward_hook(attention_forward_hook)
-        if hasattr(tfmr, 'config') and hasattr(tfmr.config, 'output_attentions'):
-            self.original_output_attentions = tfmr.config.output_attentions
-            tfmr.config.output_attentions = True
+            return output
+
+        # Replace the forward method
+        target_layer.forward = wrapped_forward
 
     def step(self, logits, next_token=None):
         """
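The heart of the change is a per-layer monkey-patch: only the spied layer is forced into eager attention (SDPA cannot return attention weights), while every other layer keeps the optimized kernel. Below is a minimal, self-contained sketch of the same pattern against a tiny, randomly initialised LlamaModel; the `spy_on_layer` helper, the `captured` dict, the layer/head indices, and the toy config values are illustrative assumptions rather than code from this PR, and the exact `LlamaAttention.forward` signature and return tuple depend on the installed transformers version.

# Hypothetical standalone sketch (not from the PR): per-layer attention spying.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel

captured = {}  # layer_idx -> (T_q, T_k) attention map for one head

def spy_on_layer(model, layer_idx, head_idx):
    """Wrap one layer's self_attn.forward so only that layer runs eager attention."""
    target = model.layers[layer_idx].self_attn
    original_forward = target.forward

    def wrapped_forward(*args, **kwargs):
        config = target.config
        saved_impl = getattr(config, "_attn_implementation", None)
        config._attn_implementation = "eager"  # SDPA cannot return attention weights
        kwargs["output_attentions"] = True
        try:
            output = original_forward(*args, **kwargs)
        finally:
            if saved_impl is not None:
                config._attn_implementation = saved_impl  # shared config object: always restore
        # Depending on the transformers version the tuple is (attn_output, attn_weights[, past_kv])
        if isinstance(output, tuple) and len(output) > 1 and output[1] is not None:
            attn = output[1].detach().cpu()  # (B, n_heads, T_q, T_k)
            captured[layer_idx] = attn[0, head_idx]
        return output

    target.forward = wrapped_forward

if __name__ == "__main__":
    cfg = LlamaConfig(vocab_size=100, hidden_size=64, intermediate_size=128,
                      num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
    model = LlamaModel(cfg).eval()
    spy_on_layer(model, layer_idx=1, head_idx=0)
    with torch.no_grad():
        model(input_ids=torch.randint(0, 100, (1, 8)))
    print(captured[1].shape)  # e.g. torch.Size([8, 8]) -> (query_len, key_len)

Restoring `_attn_implementation` in a `finally` block matters because the config object is shared by every layer, so leaving it set to 'eager' after an exception would silently slow the whole model down on subsequent steps.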