diff --git a/web/src/engine/predictive-text/templates/src/tokenization.ts b/web/src/engine/predictive-text/templates/src/tokenization.ts index fd8ed28d5ca..47ef927fa5b 100644 --- a/web/src/engine/predictive-text/templates/src/tokenization.ts +++ b/web/src/engine/predictive-text/templates/src/tokenization.ts @@ -95,6 +95,10 @@ export function tokenize( currentIndex = nextIndex; } + if(tokenization.left.length == 0) { + tokenization.left.push({text: '', isWhitespace: false}); + } + // New step 2: handle any rejoins needed. // Handle any desired special handling for directly-pre-caret scenarios - where for this diff --git a/web/src/engine/predictive-text/worker-thread/src/main/correction/context-token.ts b/web/src/engine/predictive-text/worker-thread/src/main/correction/context-token.ts index 0c492c76255..bd9e4f893f2 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/correction/context-token.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/correction/context-token.ts @@ -129,6 +129,10 @@ export class ContextToken { return this.exampleInput == ''; } + get codepointLength(): number { + return this.searchModule.codepointLength; + } + /** * Denotes the original keystroke Transforms comprising the range corresponding * to this token. diff --git a/web/src/engine/predictive-text/worker-thread/src/main/correction/context-tokenization.ts b/web/src/engine/predictive-text/worker-thread/src/main/correction/context-tokenization.ts index 95d3125c2a9..6faffd911b7 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/correction/context-tokenization.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/correction/context-tokenization.ts @@ -614,6 +614,7 @@ interface RetokenizedEdgeWindow extends EdgeWindow { export interface ContextTokenLike { exampleInput: string; + codepointLength: number; isPartial?: boolean; sourceRangeKey?: string; } diff --git a/web/src/engine/predictive-text/worker-thread/src/main/model-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/model-helpers.ts index 071cad588c5..7b5115e308f 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/model-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/model-helpers.ts @@ -71,7 +71,8 @@ export function determineModelTokenizer(model: LexicalModel) { if(model.wordbreaker) { return models.tokenize(model.wordbreaker, context); } else { - return null; + // Not ideal for pre-14.0 models, but it'll do for now. + return models.tokenize(wordBreakers.default, context); } } } diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index 44223a50267..e6ba0b64047 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -5,13 +5,14 @@ import { defaultWordbreaker, WordBreakProperty } from '@keymanapp/models-wordbre import TransformUtils from './transformUtils.js'; import { determineModelTokenizer, determineModelWordbreaker, determinePunctuationFromModel } from './model-helpers.js'; -import { ContextTokenization } from './correction/context-tokenization.js'; +import { ContextTokenization, ContextTokenLike, mapWhitespacedTokenization } from './correction/context-tokenization.js'; import { ContextTracker } from './correction/context-tracker.js'; import { ContextState, determineContextSlideTransform } from './correction/context-state.js'; import { ContextTransition } from './correction/context-transition.js'; import { ExecutionTimer } from './correction/execution-timer.js'; import ModelCompositor from './model-compositor.js'; import { EDIT_DISTANCE_COST_SCALE, getBestTokenMatches } from './correction/distance-modeler.js'; +import { TokenResult } from './correction/tokenization-corrector.js'; const searchForProperty = defaultWordbreaker.searchForProperty; @@ -26,7 +27,6 @@ import Reversion = LexicalModelTypes.Reversion; import Suggestion = LexicalModelTypes.Suggestion; import SuggestionTag = LexicalModelTypes.SuggestionTag; import Transform = LexicalModelTypes.Transform; -import { TokenResult } from './correction/tokenization-corrector.js'; /* * The functions in this file exist to provide unit-testable stateless components for the @@ -106,12 +106,6 @@ export type CorrectionPredictionTuple = { preservationTransform?: Transform; }; -export interface ContextTokenLike { - exampleInput: string; - isPartial?: boolean; - sourceRangeKey?: string; -} - /** * An enum to be used when categorizing the level of similarity between * generated Suggestions and the actual text upon which a Suggestion is @@ -159,88 +153,69 @@ export function tupleDisplayOrderSort(a: CorrectionPredictionTuple, b: Correctio return b.totalProb - a.totalProb; } -export async function correctAndEnumerateWithoutTraversals( +export function determineTraversallessCorrectionSequences( lexicalModel: LexicalModel, - transformDistribution: Distribution, + corrections: Distribution, context: Context -): Promise<{ - /** - * For models that support correction-search caching, this provides the - * cached object corresponding to this method's operation. - * - * Otherwise, is `null`. - */ - postContextState?: ContextState; - - /** - * The suggestions generated based on the user's input state. - */ - rawPredictions: CorrectionPredictionTuple[]; +): PredictionParameters[] { + let returnedPredictionData: PredictionParameters[] = []; - /** - * The id of a prior ContextTransition event that triggered a Suggestion found - * at the end of the Context. Will be undefined if no edits have occurred - * since the Suggestion was applied. - */ - revertableTransitionId?: number -}> { - const inputTransform = transformDistribution[0].sample; - let rawPredictions: CorrectionPredictionTuple[] = []; + const tokenizer = determineModelTokenizer(lexicalModel); + const wordbreak = determineModelWordbreaker(lexicalModel); - let predictionRoots: ProbabilityMass[]; + const tokenization = tokenizer(context); // issue at present if no tokens exist! + const tokenMapper = (t: models.Token) => { + return { + exampleInput: t.text + } as ContextTokenLike; + } - // Only allow new-word suggestions if space was the most likely keypress. - const allowSpace = TransformUtils.isWhitespace(inputTransform); - const allowBksp = TransformUtils.isBackspace(inputTransform); + for(let correction of corrections) { + // Step 1: determine tokenization effects. We can't use the + // ContextTokenization pattern due to the model's lack of LexiconTraversal + // support, though. + const transformId = correction.sample.id; + const postContext = models.applyTransform(correction.sample, context); + const postTokenization = tokenizer(postContext); + + const transitionEffects = determineSuggestionRange(tokenization.left.map(tokenMapper), postTokenization.left.map(tokenMapper), (a, b) => a.exampleInput == b.exampleInput); + const match: TokenResult = { + matchString: wordbreak(postContext), + inputSamplingCost: -Math.log(correction.p), + knownCost: 0, + totalCost: -Math.log(correction.p) + }; - // Generates raw prediction distributions for each valid input. Can only 'correct' - // against the final input. - // - // This is the old, 12.0-13.0 'correction' style. - if(allowSpace) { - // Detect start of new word; prevent whitespace loss here. - predictionRoots = [{sample: inputTransform, p: 1.0}]; - } else { - predictionRoots = transformDistribution.map((alt) => { - let transform = alt.sample; - - // Filter out special keys unless they're expected. - if(TransformUtils.isWhitespace(transform) && !allowSpace) { - return null; - } else if(TransformUtils.isBackspace(transform) && !allowBksp) { - return null; + const suggestionParams = buildCorrectionSequence(transitionEffects, context, match, 1); + + // // determineSuggestionRange? + // // - can we abstractify it to not need spaceID ordering? + // // - it should never be the case that the lead token for both is not found in the other (unless whole replacement or mismatch) + // // - then, iterate the section that matches perfectly. + const tokenizationMapping = mapWhitespacedTokenization(tokenization.left.map((t) => { return {exampleInput: t.text, codepointLength: KMWString.length(t.text)} }), lexicalModel, correction.sample); + const tokenizedCorrection = tokenizationMapping.tokenizedTransform; + const tokenizedCorrectionEntries = [...tokenizedCorrection.values()]; + const { tokensToRemove, tokensToPredict } = transitionEffects; + const deleteLeft = tokensToPredict.length > 1 ? 0 : tokensToRemove.reduce((prev, curr) => prev + curr.codepointLength, 0); + + // IF: array has multiple entries, then build the preservation-transform as below, including the deleteLeft. + // If not, don't make one! + const preservationTransform = tokenizedCorrectionEntries.slice(0, -1).reduce((accum, curr) => { + return models.buildMergedTransform(accum, {...curr, deleteLeft: 0}); + }, { insert: '', deleteLeft, id: correction.sample.id}); + + returnedPredictionData.push({ + ...suggestionParams, + applyInPost: (p) => { + p.preservationTransform = preservationTransform; + if(transformId) { + p.prediction.sample.transformId = transformId; + } } - - return alt; - }); - } - - const wordbreak = determineModelWordbreaker(lexicalModel); - // Remove `null` entries, then determine suggestions. - predictionRoots.forEach((pr) => { - const postContext = models.applyTransform(pr.sample, context); - const tailTokenText = wordbreak(postContext); - const rootContext = models.applyTransform({insert: '', deleteLeft: KMWString.length(tailTokenText)}, postContext); - - const results = predictFromCorrectionSequence(lexicalModel, [{ - sample: { - insert: tailTokenText, - deleteLeft: 0, - id: pr.sample.id - }, - p: pr.p - }], rootContext); - results.forEach((r) => rawPredictions.push(r)); - }) - - if(allowSpace) { - rawPredictions.forEach((entry) => entry.preservationTransform = inputTransform); + }) } - return { - postContextState: null, - rawPredictions: rawPredictions - }; + return returnedPredictionData; } /** @@ -433,6 +408,37 @@ export interface PredictionParameters { applyInPost: (entry: CorrectionPredictionTuple) => void } +export function buildCorrectionSequence( + transitionEffects: ReturnType, + context: Context, + match: Readonly, + costFactor: number +) { + const { tokensToPredict, tokensToRemove, extendsRoot } = transitionEffects; + const deleteLeft = (tokensToPredict.length > 1 && !extendsRoot) + ? (tokensToRemove[tokensToRemove.length - 1]?.codepointLength ?? 0) + : tokensToRemove.reduce((prev, curr) => prev + curr.codepointLength, 0); + + const rootContext = models.applyTransform({insert: '', deleteLeft}, context); + + // Replace the existing context with the correction. + const correctionTransform: Transform = { + insert: match.matchString, // insert correction string + deleteLeft: 0, + } + + const rootCost = match.totalCost; + const predictionRoot = { + sample: correctionTransform, + p: Math.exp(-rootCost * costFactor) + }; + + return { + rootContext, + tokenizedCorrection: [predictionRoot] + }; +} + /** * This function takes in metadata about generated corrections (for models that * implement Traversals) and uses that to produce the corresponding parameters @@ -447,37 +453,28 @@ export interface PredictionParameters { * building prediction probabilities. * @returns */ -export function determineTokenizedCorrectionSequence( +export function determineTokenizedCorrectionSequence( // transition: ContextTransition, tokenization: ContextTokenization, match: Readonly, costFactor: number ): PredictionParameters { const applicationTarget = transition.base.displayTokenization; - const { tokensToRemove, tokensToPredict } = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); + const transitionParams = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); - const deleteLeft = tokensToPredict.length > 1 ? 0 : tokensToRemove.reduce((prev, curr) => prev + curr.searchModule.codepointLength, 0); - const rootContext = models.applyTransform({insert: '', deleteLeft}, transition.base.context); - - // Replace the existing context with the correction. - const correctionTransform: Transform = { - insert: match.matchString, // insert correction string - deleteLeft: 0, - } + const suggestionParams = buildCorrectionSequence(transitionParams, transition.base.context, match, costFactor); if(transition.transitionId) { - correctionTransform.id = transition.transitionId // The correction should always be based on the most recent external transform/transcription ID. + suggestionParams.tokenizedCorrection.map((t) => t.sample.id = transition.transitionId); // The correction should always be based on the most recent external transform/transcription ID. } - const rootCost = match.totalCost; - const predictionRoot = { - sample: correctionTransform, - p: Math.exp(-rootCost * costFactor) - }; + const { tokensToPredict, tokensToRemove } = transitionParams; + const deleteLeft = tokensToPredict.length > 1 + ? tokensToRemove[tokensToRemove.length - 1]?.codepointLength ?? 0 + : tokensToRemove.reduce((prev, curr) => prev + curr.codepointLength, 0); return { - rootContext, - tokenizedCorrection: [predictionRoot], + ...suggestionParams, applyInPost: (entry: CorrectionPredictionTuple) => { entry.preservationTransform = tokenization.taillessTrueKeystroke; // // Will need an extra lookup layer if the suggestion is generated from within a cluster. @@ -530,7 +527,14 @@ export async function correctAndEnumerate( // It's mostly here to support models compiled before Keyman 14.0, which was // when the `LexiconTraversal` pattern was established. if(!contextTracker) { - return correctAndEnumerateWithoutTraversals(lexicalModel, transformDistribution, context); + const predictionData = determineTraversallessCorrectionSequences(lexicalModel, transformDistribution, context); + return { + rawPredictions: predictionData.flatMap((entry) => { + const predictions = predictFromCorrectionSequence(lexicalModel, entry.tokenizedCorrection, entry.rootContext); + predictions.forEach((p) => entry.applyInPost(p)); + return predictions; + }) + }; } // 'else': the current, 14.0+ pattern, which is able to leverage diff --git a/web/src/test/auto/headless/engine/predictive-text/templates/tokenization.tests.ts b/web/src/test/auto/headless/engine/predictive-text/templates/tokenization.tests.ts index 3bc636c4128..0aa4f9551ed 100644 --- a/web/src/test/auto/headless/engine/predictive-text/templates/tokenization.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/templates/tokenization.tests.ts @@ -175,7 +175,7 @@ describe('Tokenization functions', function() { }); it('properly handles empty-context cases', function() { - // Wordbreaking on a empty space => no word. + // Wordbreaking on a empty space => no word, but empty initial token. let context = { left: '', startOfBuffer: true, right: '', endOfBuffer: true @@ -184,7 +184,7 @@ describe('Tokenization functions', function() { let tokenization = models.tokenize(wordBreakers.default, context); let expectedResult: models.Tokenization = { - left: [], + left: [{text: '', isWhitespace: false}], right: [], caretSplitsToken: false }; @@ -193,11 +193,11 @@ describe('Tokenization functions', function() { }); it('properly handles null context cases', function() { - // Wordbreaking on a empty space => no word. + // Wordbreaking on a empty space => no word, but empty initial token. let tokenization = models.tokenize(wordBreakers.default, null); let expectedResult: models.Tokenization = { - left: [], + left: [{text: '', isWhitespace: false}], right: [], caretSplitsToken: false }; diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts index a8182d890d2..da1221714b2 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts @@ -235,7 +235,7 @@ describe('determineTokenizedCorrectionSequence', () => { ]); }); - it(`properly analyzes post-split case`, () => { + it(`properly analyzes post-split new-wordbreak case`, () => { const context: Context = { left: 'the quick brown fox can\'', right: '', @@ -266,13 +266,15 @@ describe('determineTokenizedCorrectionSequence', () => { 1 ); - assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { - casingForm: undefined, - left: 'the quick brown fox can\'', - right: '', - startOfBuffer: true, - endOfBuffer: true - }); + // assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { + // casingForm: undefined, + // // Proper logic requires full multi-token awareness; predictions are currently + // // based on just the last token. + // left: 'the quick brown fox can\'', + // right: '', + // startOfBuffer: true, + // endOfBuffer: true + // }); assert.deepEqual(results.tokenizedCorrection, [ { @@ -285,7 +287,7 @@ describe('determineTokenizedCorrectionSequence', () => { ]); }); - it(`properly analyzes conplex transition - multi-token replacement`, () => { + it(`properly analyzes complex transition - multi-token replacement`, () => { const context: Context = { left: 'the quick brown f', right: '', @@ -319,7 +321,7 @@ describe('determineTokenizedCorrectionSequence', () => { // deleted by the `preservationTransform`, not here. assert.deepEqual({...results.rootContext, casingForm: results.rootContext.casingForm}, { casingForm: undefined, - left: 'the quick brown f', + left: 'the quick brown ', right: '', startOfBuffer: true, endOfBuffer: true diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts new file mode 100644 index 00000000000..994c918d1a8 --- /dev/null +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-traversalless-correction-sequences.tests.ts @@ -0,0 +1,90 @@ +/* + * Keyman is copyright (C) SIL Global. MIT License. + * + * Created by jahorton on 2026-05-18 + * + * This file tests the prediction helper-method responsible for preparing + * corrections for multi-token prediction for some custom and all legacy models. + */ + +import { assert } from 'chai'; + +import { LexicalModelTypes } from "@keymanapp/common-types"; +import * as wordBreakers from '@keymanapp/models-wordbreakers'; + +import { determineTraversallessCorrectionSequences, models } from "@keymanapp/lm-worker/test-index"; + +import Context = LexicalModelTypes.Context; +import DummyModel = models.DummyModel; +import DummyOptions = models.DummyOptions; +import ProbabilityMass = LexicalModelTypes.ProbabilityMass; +import Transform = LexicalModelTypes.Transform; + + +/* + * This file's tests use these parts of a lexical model: + * - model.wordbreaker + * - model.toKey + * - model.applyCasing + * - model.punctuation + */ + +const DUMMY_MODEL_CONFIG: DummyOptions = { + punctuation: { + quotesForKeepSuggestion: { + open: '<', + close: '>' + }, + insertAfterWord: '\u00a0' // non-breaking space + }, + wordbreaker: wordBreakers.default +}; + +const testModel = new DummyModel({ + ...DUMMY_MODEL_CONFIG, + // No suggestions needed here, so we don't define any. +}); + +describe('determineTraversallessCorrectionSequences', () => { + it(`creates an 'exact'-match suggestion based on primary input and current context`, () => { + const context: Context = { + left: 'iphon', + right: '', + startOfBuffer: true, + endOfBuffer: true + }; + + const trueInput: ProbabilityMass = { + sample: { + insert: 'e', + deleteLeft: 0 + }, + p: 1 + }; + + const predictionRootEntries = determineTraversallessCorrectionSequences(testModel, [trueInput], context); + + assert.equal(predictionRootEntries.length, 1); + const entry = predictionRootEntries[0]; + + assert.deepEqual( + { + ...entry.rootContext, casingForm: entry.rootContext.casingForm ?? undefined + }, { + casingForm: undefined, + left: '', + right: '', + startOfBuffer: true, + endOfBuffer: true + } + ); + + assert.deepEqual(entry.tokenizedCorrection, [{ + sample: { + insert: 'iphone', + deleteLeft: 0 + }, + p: 1 + }]); + }); +}); \ No newline at end of file diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/worker-custom-punctuation.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/worker-custom-punctuation.tests.ts index cd4dbd106e0..01f1a6ac96a 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/worker-custom-punctuation.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/worker-custom-punctuation.tests.ts @@ -87,17 +87,30 @@ describe('Custom Punctuation', function () { // the tests run smoothly. wordbreaker: (text) => { const textLen = text.length; - if(text.charAt(textLen - 1) == " ") { - return [ - {text: text.substring(0, 1), start: 0, end: 1, length: 1}, - {text: text.substring(1, textLen-2), start: 1, end: textLen-1, length: textLen-2}, - {text: text.substring(textLen-1), start: textLen-1, end: textLen, length: 1} - ]; + if(text.charAt(0) == "᚛") { // ensure the prior token component (the '᚛') wordbreaks. + if(text.charAt(textLen - 1) == " ") { // ensure the insert-after component word-breaks. + return [ + {text: text.substring(0, 1), start: 0, end: 1, length: 1}, + {text: text.substring(1, textLen-2), start: 1, end: textLen-1, length: textLen-2}, + {text: text.substring(textLen-1), start: textLen-1, end: textLen, length: 1} + ]; + } else { + return [ + {text: text.substring(0, 1), start: 0, end: 1, length: 1}, + {text: text.substring(1), start: 1, end: textLen, length: textLen-1} + ]; + } } else { - return [ - {text: text.substring(0, 1), start: 0, end: 1, length: 1}, - {text: text.substring(1), start: 1, end: textLen, length: textLen-1} - ]; + if(text.charAt(textLen - 1) == " ") { + return [ + {text: text.substring(0, textLen-2), start: 0, end: textLen-1, length: textLen-1}, + {text: text.substring(textLen-1), start: textLen-1, end: textLen, length: 1} + ]; + } else { + return [ + {text: text, start: 0, end: textLen, length: textLen} + ]; + } } } });