diff --git a/apps/sotto/internal/riva/client.go b/apps/sotto/internal/riva/client.go index dee7031..bcc272b 100644 --- a/apps/sotto/internal/riva/client.go +++ b/apps/sotto/internal/riva/client.go @@ -41,12 +41,13 @@ type Stream struct { recvDone chan struct{} - mu sync.Mutex - segments []string // committed transcript segments (final and pause-committed interim) - lastInterim string - recvErr error - closedSend bool - debugSinkJSON io.Writer + mu sync.Mutex + segments []string // committed transcript segments (final and high-confidence boundary-committed interim) + lastInterim string + lastInterimStability float32 + recvErr error + closedSend bool + debugSinkJSON io.Writer } // DialStream establishes a stream, sends config, and starts the receive loop. diff --git a/apps/sotto/internal/riva/client_test.go b/apps/sotto/internal/riva/client_test.go index 679103f..c7afc9e 100644 --- a/apps/sotto/internal/riva/client_test.go +++ b/apps/sotto/internal/riva/client_test.go @@ -50,7 +50,7 @@ func TestRecordResponseTracksInterimThenFinal(t *testing.T) { require.Equal(t, []string{"hello world"}, s.segments) } -func TestRecordResponseCommitsInterimAcrossPauseLikeReset(t *testing.T) { +func TestRecordResponseReplacesDivergentInterimWithoutPrecommit(t *testing.T) { s := &Stream{} s.recordResponse(&asrpb.StreamingRecognizeResponse{ @@ -67,10 +67,89 @@ func TestRecordResponseCommitsInterimAcrossPauseLikeReset(t *testing.T) { }}, }) + require.Empty(t, s.segments) + segments := collectSegments(s.segments, s.lastInterim) + require.Equal(t, []string{"second phrase"}, segments) +} + +func TestRecordResponseCommitsStableDivergentInterimForPartialRecovery(t *testing.T) { + s := &Stream{} + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.95, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "first phrase"}}, + }}, + }) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: 
[]*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.20, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "second phrase"}}, + }}, + }) + + require.Equal(t, []string{"first phrase"}, s.segments) segments := collectSegments(s.segments, s.lastInterim) require.Equal(t, []string{"first phrase", "second phrase"}, segments) } +func TestRecordResponseDoesNotPrependStaleInterimBeforeFinal(t *testing.T) { + s := &Stream{} + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.05, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "stale words"}}, + }}, + }) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.30, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello world"}}, + }}, + }) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: true, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello world"}}, + }}, + }) + + segments := collectSegments(s.segments, s.lastInterim) + require.Equal(t, []string{"hello world"}, segments) +} + +func TestRecordResponseTreatsSuffixCorrectionAsContinuation(t *testing.T) { + s := &Stream{} + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.95, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "replace reply replied on the review thread with details"}}, + }}, + }) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Stability: 0.95, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "replied on the review thread with details"}}, + }}, + }) + + require.Empty(t, s.segments) + segments := collectSegments(s.segments, s.lastInterim) + 
require.Equal(t, []string{"replied on the review thread with details"}, segments) +} + func TestAppendSegmentDedupAndPrefixMerge(t *testing.T) { segments := appendSegment(nil, "hello") require.Equal(t, []string{"hello"}, segments) @@ -94,7 +173,13 @@ func TestCleanSegmentAndInterimContinuation(t *testing.T) { require.True(t, isInterimContinuation("hello", "hello world")) require.True(t, isInterimContinuation("hello world", "hello")) + require.True(t, isInterimContinuation("replace reply replied on thread", "replied on thread")) require.False(t, isInterimContinuation("first phrase", "second phrase")) + + require.False(t, shouldCommitPriorInterimOnDivergence("first phrase", 0.2, "second phrase")) + require.True(t, shouldCommitPriorInterimOnDivergence("first phrase", 0.9, "second phrase")) + require.True(t, shouldCommitPriorInterimOnDivergence("Done.", 0.1, "new sentence")) + require.False(t, shouldCommitPriorInterimOnDivergence("replace reply replied on thread", 0.95, "replied on thread")) } func TestDialStreamEndToEndWithDebugSinkAndSpeechContexts(t *testing.T) { diff --git a/apps/sotto/internal/riva/stream_receive.go b/apps/sotto/internal/riva/stream_receive.go index ef0e226..769e642 100644 --- a/apps/sotto/internal/riva/stream_receive.go +++ b/apps/sotto/internal/riva/stream_receive.go @@ -53,12 +53,14 @@ func (s *Stream) recordResponse(resp *asrpb.StreamingRecognizeResponse) { if result.GetIsFinal() { s.segments = appendSegment(s.segments, transcript) s.lastInterim = "" + s.lastInterimStability = 0 continue } - if s.lastInterim != "" && !isInterimContinuation(s.lastInterim, transcript) { + if shouldCommitPriorInterimOnDivergence(s.lastInterim, s.lastInterimStability, transcript) { s.segments = appendSegment(s.segments, s.lastInterim) } s.lastInterim = transcript + s.lastInterimStability = result.GetStability() } } diff --git a/apps/sotto/internal/riva/transcript_segments.go b/apps/sotto/internal/riva/transcript_segments.go index 0e887f6..ea7efb2 100644 --- 
a/apps/sotto/internal/riva/transcript_segments.go +++ b/apps/sotto/internal/riva/transcript_segments.go @@ -2,6 +2,8 @@ package riva import "strings" +const stableInterimBoundaryThreshold = 0.85 + // collectSegments appends a valid trailing interim segment when needed. func collectSegments(committedSegments []string, lastInterim string) []string { segments := append([]string(nil), committedSegments...) @@ -48,10 +50,12 @@ func isInterimContinuation(previous string, current string) bool { if strings.HasPrefix(current, previous) || strings.HasPrefix(previous, current) { return true } + if strings.HasSuffix(current, previous) || strings.HasSuffix(previous, current) { + return true + } prevWords := strings.Fields(previous) currWords := strings.Fields(current) - common := commonPrefixWords(prevWords, currWords) shorter := len(prevWords) if len(currWords) < shorter { shorter = len(currWords) @@ -59,7 +63,49 @@ func isInterimContinuation(previous string, current string) bool { if shorter == 0 { return true } - return common*2 >= shorter + + commonPrefix := commonPrefixWords(prevWords, currWords) + if commonPrefix*2 >= shorter { + return true + } + + commonSuffix := commonSuffixWords(prevWords, currWords) + if shorter >= 3 && commonSuffix*2 >= shorter { + return true + } + + return false +} + +// shouldCommitPriorInterimOnDivergence decides whether to preserve prior interim +// text when a new interim hypothesis diverges. +func shouldCommitPriorInterimOnDivergence(previous string, previousStability float32, current string) bool { + previous = cleanSegment(previous) + current = cleanSegment(current) + if previous == "" || current == "" { + return false + } + if isInterimContinuation(previous, current) { + return false + } + if previousStability >= stableInterimBoundaryThreshold { + return true + } + return endsWithSentencePunctuation(previous) +} + +// endsWithSentencePunctuation reports whether transcript looks sentence-complete. 
// endsWithSentencePunctuation reports whether transcript looks sentence-complete.
//
// Surrounding whitespace is ignored, and closing quotes or brackets that trail
// the final mark are skipped, so `He said "stop."` counts as complete. The
// terminators recognized are ASCII '.', '!' and '?', the ideographic and
// full-width/half-width forms (U+3002, U+FF01, U+FF1F, U+FF61), the Devanagari
// danda (U+0964) and the Arabic question mark (U+061F).
//
// It returns false for empty or whitespace-only input and for input made up
// solely of closing quotes/brackets.
func endsWithSentencePunctuation(transcript string) bool {
	const (
		// Characters that may follow a terminator without changing whether
		// the sentence is complete: " ' ) ] } and their typographic/CJK forms.
		trailingClosers = "\"')]}\u201d\u2019\u00bb\u300d\u300f"
		// Sentence-final marks; see the doc comment for the code points.
		terminators = ".!?\u3002\uff01\uff1f\uff61\u0964\u061f"
	)

	text := strings.TrimRight(strings.TrimSpace(transcript), trailingClosers)
	if text == "" {
		return false
	}

	// Compare whole runes rather than the last byte so that multi-byte
	// terminators are matched as well as the ASCII ones.
	for _, mark := range terminators {
		if strings.HasSuffix(text, string(mark)) {
			return true
		}
	}
	return false
}

// commonSuffixWords counts shared trailing words across two slices.
//
// Elements are compared pairwise starting from the last element of each slice;
// counting stops at the first mismatch or as soon as either slice runs out.
// Two nil or empty slices share zero words.
func commonSuffixWords(left []string, right []string) int {
	shared := 0
	for shared < len(left) && shared < len(right) {
		// Walk both slices backwards by the same offset from their ends.
		if left[len(left)-1-shared] != right[len(right)-1-shared] {
			break
		}
		shared++
	}
	return shared
}