Commit 1f799ad

Author: RageLtMan (committed)
Handle Multiple EOS Tokens in Grammar
Handle runaway model output in the "normal grammar" modality, where the grammar masked the possible EOS tokens and the model produced nonsensical output after completing its normal tool-calls and chat stream:

- guidance-ai/llguidance#304
- guidance-ai/llguidance#305
1 parent 31e65a8 commit 1f799ad

File tree: 5 files changed, +267 −91 lines

ReadMe-CN.md

Lines changed: 4 additions & 4 deletions
@@ -421,19 +421,19 @@ pip install maturin[patchelf] # Linux/Windows platforms
 2. **Build the Python package**
 
    ```bash
-   # Naive CUDA (single-GPU inference only)
+   # Naive CUDA (single-GPU inference only)
    maturin build --release --features cuda,python
 
    # Naive CUDA (+CUDA Graph, experimental)
    maturin build --release --features cuda,graph,python
 
-   # CUDA (supports context-cache and FP8 KV cache; no Flash Attention)
+   # CUDA (supports context-cache and FP8 KV cache; no Flash Attention)
    ./build.sh --release --features cuda,nccl,python
 
-   # CUDA (+Flash Attention, enabled for prefill only)
+   # CUDA (+Flash Attention, enabled for prefill only)
    ./build.sh --release --features cuda,nccl,flash-attn,python
 
-   # CUDA (+Flash Attention for both prefill and decoding; longest build time)
+   # CUDA (+Flash Attention for both prefill and decoding; longest build time)
    ./build.sh --release --features cuda,nccl,flash-context,python
 
    # macOS (Metal; supports context-cache and FP8 KV cache, but no multi-GPU inference)

docs/llguidance-integration.md

Lines changed: 167 additions & 64 deletions
@@ -19,70 +19,173 @@ This document provides comprehensive documentation for the llguidance integratio
 
 ### Component Overview
 
-```
-┌─────────────────────────────────────────────────────────────────┐
-│                          User Request                           │
-│          (OpenAI / Claude / MCP / Structured Outputs)           │
-└───────────────────────┬─────────────────────────────────────────┘
-
-
-┌─────────────────────────────────────────────────────────────────┐
-│                          Server Layer                           │
-│ ┌─────────────────────────────────────────────────────────────┐ │
-│ │ chat_completion() [src/server/server.rs:250]                │ │
-│ │ - Parse request fields                                      │ │
-│ │ - Resolve tools (MCP + request)                             │ │
-│ │ - Build constraints from structured_outputs/response_format │ │
-│ │ - Build grammar from tools (if enable_tool_grammar=true)    │ │
-│ │ - Compose grammars via compose_grammars()                   │ │
-│ └─────────────────────────────────────────────────────────────┘ │
-└───────────────────────┬─────────────────────────────────────────┘
-
-
-┌─────────────────────────────────────────────────────────────────┐
-│                          Engine Layer                           │
-│ ┌─────────────────────────────────────────────────────────────┐ │
-│ │ LLMEngine::generate_stream() / generate_sync()              │ │
-│ │ - Create Sequence                                           │ │
-│ │ - Allocate blocks                                           │ │
-│ │ - Initialize GuidanceState if grammar exists                │ │
-│ └─────────────────────────────────────────────────────────────┘ │
-└───────────────────────┬─────────────────────────────────────────┘
-
-
-┌─────────────────────────────────────────────────────────────────┐
-│                         Scheduler Layer                         │
-│ ┌─────────────────────────────────────────────────────────────┐ │
-│ │ postprocess() [src/core/scheduler.rs:616]                   │ │
-│ │ - Rollback validation via ModelRunner                       │ │
-│ │ - Sequence state management                                 │ │
-│ └─────────────────────────────────────────────────────────────┘ │
-└───────────────────────┬─────────────────────────────────────────┘
-
-
-┌─────────────────────────────────────────────────────────────────┐
-│                        ModelRunner Layer                        │
-│ ┌─────────────────────────────────────────────────────────────┐ │
-│ │ GuidanceState::compute_mask()                               │ │
-│ │ GuidanceState::validate_tokens()                            │ │
-│ │ GuidanceState::rollback_to()                                │ │
-│ │ GuidanceState::consume_ff_tokens()                          │ │
-│ │ - Compute valid token mask from llguidance Matcher          │ │
-│ │ - Validate generated tokens against grammar                 │ │
-│ │ - Rollback to previous state on failure                     │ │
-│ └─────────────────────────────────────────────────────────────┘ │
-└───────────────────────┬─────────────────────────────────────────┘
-
-
-┌─────────────────────────────────────────────────────────────────┐
-│                       BlockManager Layer                        │
-│ ┌─────────────────────────────────────────────────────────────┐ │
-│ │ rollback_to_seq_tokens() [src/core/block_manager.rs:946]    │ │
-│ │ - Release KV cache blocks                                   │ │
-│ │ - Clean up prefix cache                                     │ │
-│ │ - Invalidate Mamba state                                    │ │
-│ └─────────────────────────────────────────────────────────────┘ │
-└─────────────────────────────────────────────────────────────────┘
+```mermaid
+sequenceDiagram
+    participant User
+    participant API
+    participant Pipeline
+    participant LLGFactory
+    participant Matcher
+    participant TokenParser
+    participant EarleyParser
+    participant Lexer
+    participant TokTrie
+    participant Sampler
+    participant LogitsProcessor
+    participant Model
+
+    User->>API: Request with constraint (regex/json_schema/lark/llguidance)
+
+    Note over User,API: Phase 1: Request Setup and Grammar Building
+
+    API->>Pipeline: build_llg_factory(tokenizer)
+    Pipeline->>LLGFactory: toktrie_hf_tokenizers::ByteTokenizer::from_tokenizer(tokenizer)
+    LLGFactory->>TokTrie: Create token trie from tokenizer vocabulary
+    TokTrie-->>LLGFactory: Return TokEnv with trie
+    LLGFactory->>LLGFactory: ParserFactory::new_simple(&env)
+    LLGFactory-->>Pipeline: Return Arc<ParserFactory>
+
+    Pipeline->>Pipeline: llg_grammar_from_constraint(&request.constraint)
+    Pipeline->>Matcher: constraint_from_llg_grammar(&factory, grm)
+    Matcher->>Matcher: factory.create_parser(grm)
+    Matcher->>TokenParser: Create with grammar_init
+    TokenParser->>EarleyParser: Build CGrammar from grammar
+    TokenParser->>Lexer: Build LexerSpec from grammar
+    Lexer->>TokTrie: Precompute large lexemes if needed
+    TokTrie-->>Lexer: Return optimized lexeme sets
+
+    Note over User,Matcher: Phase 2: Prompt Processing (if needed)
+
+    User->>API: Optional: process_prompt(prompt_tokens)
+    API->>TokenParser: process_prompt(prompt_tokens)
+    TokenParser->>TokenParser: tokenize_bytes_marker(&prompt_bytes)
+    TokenParser->>TokenParser: process_prompt() returns new prompt
+
+    Note over User,Matcher: Phase 3: Inference Loop
+
+    loop for each token generation
+
+        Model->>Model: Forward pass on input tokens
+        Model-->>Pipeline: Return logits tensor
+
+        Pipeline->>Sampler: sample_sequence(logits, seq, ...)
+
+        Note over Sampler: Two-stage sampling with llguidance
+
+        Sampler->>LogitsProcessor: Apply llguidance constraint
+
+        LogitsProcessor->>TokenParser: compute_mask()
+        TokenParser->>TokenParser: compute_mask_inner()
+        TokenParser->>EarleyParser: run_speculative("compute_mask")
+        EarleyParser->>EarleyParser: trie_started("compute_mask")
+        EarleyParser->>EarleyParser: compute_bias()
+        EarleyParser->>Lexer: compute_bias() with token_prefix
+
+        Note over Lexer,TokTrie: Lexical Scope Analysis
+
+        Lexer->>TokTrie: Walk token trie for allowed lexemes
+        TokTrie-->>Lexer: Return SimpleVob bit mask
+
+        Lexer->>EarleyParser: Return mask to TokenParser
+        TokenParser->>TokenParser: cache mask for fast-forward
+
+        TokenParser-->>LogitsProcessor: Return SimpleVob mask
+
+        LogitsProcessor->>LogitsProcessor: Check if sampled token is allowed
+        LogitsProcessor->>Sampler: Apply logit biasing
+
+        alt Token is allowed
+            Sampler->>Sampler: No biasing needed
+        else Token is not allowed
+            Sampler->>Sampler: Set invalid tokens to -f32::INFINITY
+            Sampler->>Sampler: Re-sample with biased logits
+        end
+
+        Sampler->>TokenParser: consume_token(sampled_token)
+        TokenParser->>TokenParser: apply_token(sampled_token)
+        TokenParser->>TokenParser: llm_tokens.push(sampled_token)
+        TokenParser->>TokenParser: llm_bytes.extend(token_bytes)
+        TokenParser->>EarleyParser: parser.apply_token(token_bytes, token_id)
+        EarleyParser->>Lexer: advance lexer state
+        Lexer->>Lexer: Update lexer_stack with new state
+        Lexer->>EarleyParser: Return backtrack count
+
+        alt Backtrack needed
+            EarleyParser->>EarleyParser: rollback(backtrack_bytes)
+            EarleyParser->>EarleyParser: Update llm_tokens and llm_bytes
+        end
+
+        TokenParser->>TokenParser: check_stop()
+        TokenParser-->>Sampler: Return CommitResult
+
+        Note over Sampler: Phase 4: Fast-Forward (if enabled)
+
+        Sampler->>TokenParser: compute_ff_tokens()
+        TokenParser->>TokenParser: ff_tokens()
+        TokenParser->>TokTrie: Tokenize forced bytes
+        TokTrie-->>TokenParser: Return fast-forward tokens
+
+        alt Fast-forward tokens available
+            TokenParser->>TokenParser: consume_ff_tokens()
+            loop for each ff_token
+                TokenParser->>TokenParser: consume_token(ff_token)
+                TokenParser->>TokenParser: llm_tokens.push(ff_token)
+                TokenParser->>TokenParser: llm_bytes.extend(ff_token_bytes)
+            end
+        end
+
+        Note over Sampler: Phase 5: Speculative Decoding (if enabled)
+
+        Model->>Model: Draft model forward pass
+        Model-->>Pipeline: Return draft logits
+
+        Pipeline->>Sampler: sample_target_sequence_speculative()
+        Sampler->>TokenParser: rollback(n_toks)
+        TokenParser->>EarleyParser: parser.rollback(bytes_to_drop)
+        EarleyParser->>Lexer: pop lexer states
+        Lexer-->>TokenParser: Return rollback result
+
+        Sampler->>Sampler: Sample draft tokens
+        Sampler->>TokenParser: validate_tokens(draft_tokens)
+        TokenParser->>TokenParser: consume_token(draft_token)
+
+        alt Draft token accepted
+            TokenParser->>TokenParser: Continue with next draft
+        else Draft token rejected
+            TokenParser->>TokenParser: Accept partial draft
+            TokenParser->>TokenParser: Rollback to last valid state
+        end
+
+    end
+
+    Note over User,Matcher: Phase 6: Token Geometry and Binary Data State
+
+    TokTrie->>TokTrie: Token encoding (8:24 bit split)
+    TokTrie->>TokTrie: node.bits = (token_id << 8) | byte
+    TokTrie->>TokTrie: node.bits2 = (subtree_size << 10) | num_parents
+
+    TokTrie->>SimpleVob: Bit mask storage
+    SimpleVob->>SimpleVob: data: Vec<u32> (32 tokens per word)
+    SimpleVob->>SimpleVob: allow_token(tok): data[tok>>5] |= 1 << (tok&31)
+
+    Note over User,Matcher: Phase 7: Rollback and Verification
+
+    TokenParser->>TokenParser: validate_tokens(tokens)
+    TokenParser->>EarleyParser: validate_tokens_raw(tokens)
+    EarleyParser->>Lexer: Check if tokens match current lexer state
+    Lexer-->>TokenParser: Return number of valid tokens
+
+    TokenParser->>TokenParser: rollback(n_tokens)
+    TokenParser->>EarleyParser: parser.rollback(bytes_to_drop)
+    EarleyParser->>Lexer: pop lexer states
+    TokenParser->>TokenParser: llm_tokens.truncate(new_len)
+    TokenParser->>TokenParser: llm_bytes.truncate(new_len)
+
+    Note over User,Matcher: Phase 8: Response Generation
+
+    Pipeline->>API: Return completion with tokens
+    API->>User: Stream or return final response
 ```
 
 ### Key Data Structures
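The Phase 6 lines in the diagram above describe how llguidance packs its token trie nodes and token masks into flat `u32` words. The bit arithmetic can be shown concretely; the names below (`pack_node_bits`, `TokenMask`) are simplified illustrations of the documented layout, not llguidance's actual `TokTrie`/`SimpleVob` implementations:

```rust
// Illustrative sketch of the bit layouts named in Phase 6 of the diagram.
// These helper names are hypothetical; llguidance's real TokTrie/SimpleVob
// types implement the same layouts with more machinery around them.

/// TokTrie node encoding (8:24 split): low 8 bits hold the byte,
/// upper 24 bits hold the token id.
fn pack_node_bits(token_id: u32, byte: u8) -> u32 {
    (token_id << 8) | byte as u32
}

/// SimpleVob-style token mask: one bit per token, 32 tokens per u32 word.
struct TokenMask {
    data: Vec<u32>,
}

impl TokenMask {
    fn new(vocab_size: usize) -> Self {
        TokenMask { data: vec![0; (vocab_size + 31) / 32] }
    }

    /// data[tok >> 5] |= 1 << (tok & 31), exactly as in the diagram.
    fn allow_token(&mut self, tok: u32) {
        self.data[(tok >> 5) as usize] |= 1 << (tok & 31);
    }

    fn is_allowed(&self, tok: u32) -> bool {
        self.data[(tok >> 5) as usize] & (1 << (tok & 31)) != 0
    }
}

fn main() {
    let bits = pack_node_bits(1000, b'a');
    assert_eq!(bits & 0xff, b'a' as u32); // low byte of the node
    assert_eq!(bits >> 8, 1000);          // token id in the upper bits

    let mut mask = TokenMask::new(128_000);
    mask.allow_token(2); // e.g. allow an EOS token id
    assert!(mask.is_allowed(2));
    assert!(!mask.is_allowed(3));
}
```

One consequence of this layout: a 128k-token vocabulary needs only 4,000 words (16 KB) per mask, which is why a fresh mask per decoding step is cheap.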

src/core/scheduler.rs

Lines changed: 5 additions & 0 deletions
@@ -1283,6 +1283,11 @@ impl Scheduler {
         false
     }
 
+    /// Get the EOS token IDs from the scheduler
+    pub fn eos_token_ids(&self) -> &[u32] {
+        &self.eos_token_id
+    }
+
     fn stop_sequence_match_index(&self, token: u32, seq: &Sequence) -> Option<usize> {
         let Some(stop_sequences) = &seq.sampling_params.stop_token_ids else {
             return None;
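The getter added above returns the scheduler's EOS ids as a slice rather than a single token, because modern chat models often define several EOS-like tokens (an end-of-turn marker in addition to end-of-text). A minimal sketch of how a caller might use such a slice; `MiniScheduler` and its field are illustrative stand-ins, not the crate's actual `Scheduler` type:

```rust
// Illustrative sketch (not the actual Scheduler): a scheduler-like struct
// holding several EOS token ids, with the same slice-returning getter shape
// as the one added in this commit.
struct MiniScheduler {
    eos_token_id: Vec<u32>,
}

impl MiniScheduler {
    /// Same shape as the getter added in src/core/scheduler.rs.
    fn eos_token_ids(&self) -> &[u32] {
        &self.eos_token_id
    }

    /// A sequence should stop when the sampled token is *any* of the EOS ids.
    fn is_eos(&self, token: u32) -> bool {
        self.eos_token_id.contains(&token)
    }
}

fn main() {
    // e.g. Llama-3-style models define both <|end_of_text|> and <|eot_id|>.
    let sched = MiniScheduler { eos_token_id: vec![128_001, 128_009] };
    assert!(sched.is_eos(128_009));
    assert!(!sched.is_eos(42));
    println!("eos ids: {:?}", sched.eos_token_ids());
}
```

Checking against the whole slice is what prevents the runaway-output failure from the commit message: a grammar that only whitelists one EOS id would mask the others out of the sampling mask.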

src/server/server.rs

Lines changed: 5 additions & 0 deletions
@@ -486,14 +486,19 @@ pub async fn chat_completion(
     if constraint_grammars.is_empty() && !engine_config.enable_tool_grammar {
         crate::log_debug!("[llg] No constraint or tool grammar - not setting guidance");
     } else {
+        // Get EOS token IDs from engine scheduler for building TEXT pattern with EOS bounding
+        let engine = data.engine.read();
+        let eos_token_ids = engine.scheduler.eos_token_ids();
         let llg_grammar = compose_grammars(
             constraint_grammars,
             tool_grammar,
             has_tools,
             tool_choice_required,
             forced_tool_name,
             Some(max_tokens.clone()),
+            eos_token_ids,
         );
+        drop(engine); // Explicitly drop the lock guard
         let lark_string = get_lark_from_top_level_grammar(&llg_grammar);
         crate::log_debug!("[llg] TopLevelGrammar for SamplingParams: {:?}", &llg_grammar);
         crate::log_debug!("[llg] Lark grammar string:\n{}", lark_string);
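The explicit `drop(engine)` above releases the read lock as soon as `compose_grammars` has finished borrowing the EOS slice, keeping the critical section short. The pattern can be sketched with `std::sync::RwLock` standing in for the engine's own lock type (the `Engine` struct and `compose` function here are hypothetical):

```rust
use std::sync::RwLock;

// Hypothetical stand-ins for the engine and compose_grammars(): the point is
// the lock-scoping pattern, not the real types.
struct Engine {
    eos_token_id: Vec<u32>,
}

fn compose(eos_token_ids: &[u32]) -> Vec<u32> {
    // Pretend this builds a grammar; here it just copies the borrowed slice.
    eos_token_ids.to_vec()
}

fn main() {
    let engine = RwLock::new(Engine { eos_token_id: vec![1, 2] });

    // Hold the read lock only while the borrowed slice is in use.
    let guard = engine.read().unwrap();
    let grammar = compose(&guard.eos_token_id);
    drop(guard); // release the lock before any further (possibly slow) work

    assert_eq!(grammar, vec![1, 2]);
    println!("composed with eos ids: {:?}", grammar);
}
```

Dropping the guard before the subsequent logging and request handling means other tasks can take the engine lock while this request finishes its setup.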
