diff --git a/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
index 255d50f03..68ce951b0 100644
--- a/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
+++ b/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py
@@ -64,6 +64,26 @@ def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_id
             self.last_aligned_attns += [None]
             self._add_attention_spy(tfmr, i, layer_idx, head_idx)
 
+    def reset(self, text_tokens_slice):
+        """
+        Resets the internal state for a new utterance.
+        """
+        self.text_tokens_slice = (i, j) = text_tokens_slice
+        self.alignment = torch.zeros(0, j-i)
+        self.curr_frame_pos = 0
+        self.text_position = 0
+
+        self.started = False
+        self.started_at = None
+
+        self.complete = False
+        self.completed_at = None
+
+        # Track generated tokens for repetition detection
+        self.generated_tokens = []
+
+        self.last_aligned_attns = [None] * len(LLAMA_ALIGNED_HEADS)
+
     def _add_attention_spy(self, tfmr, buffer_idx, layer_idx, head_idx):
         """
         Adds a forward hook to a specific attention layer to collect outputs.
diff --git a/src/chatterbox/models/t3/t3.py b/src/chatterbox/models/t3/t3.py
index 625302876..507178ff0 100644
--- a/src/chatterbox/models/t3/t3.py
+++ b/src/chatterbox/models/t3/t3.py
@@ -270,8 +270,6 @@ def inference(
 
         # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic
         # Note the llama-specific logic. Other tfmr types can be added later.
-        self.compiled = False
-
         # TODO? synchronize the expensive compile function
        # with self.compile_lock:
         if not self.compiled:
@@ -296,6 +294,10 @@
             )
             self.patched_model = patched_model
             self.compiled = True
+        else:
+            self.patched_model.alignment_stream_analyzer.reset(
+                text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)),
+            )
 
         # # Run normal generate method, which calls our custom extended methods
         # return self.patched_model.generate(