Skip to content

Commit 4115c9c

Browse files
committed
Never output EOS token's text
There is a global output-special-tokens option, disabled by default, which when enabled causes all special tokens to be output, including the terminating EOS token. For the use cases we've encountered, emitting the EOS token's text is undesirable. I can't think of a case where it would be needed or wanted, since it will only ever appear at the very end of the output, and in those cases the returned stop_reason will already be EOS_TOKEN.
1 parent c106c67 commit 4115c9c

File tree

4 files changed

+39
-2
lines changed

4 files changed

+39
-2
lines changed

integration_tests/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ markers = [
2020
"model",
2121
"extensions",
2222
"shards",
23+
"output_special_tokens",
2324
"test_case_file",
2425
]
2526

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
# EOS token
3+
- name: EOS
4+
request:
5+
params:
6+
stopping:
7+
maxNewTokens: 10
8+
requests:
9+
- {"text": "In one word, the capital of France is"}
10+
response:
11+
responses:
12+
- generatedTokenCount: 2
13+
inputTokenCount: 10
14+
stopReason: EOS_TOKEN
15+
text: France

integration_tests/text_generation_tests/test_server.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def start_server(
3131
timeout=20,
3232
model_path=None,
3333
include_cache_env_vars=True,
34+
output_special_tokens=False,
3435
):
3536
# Download weights to the cache first
3637
print(f"Downloading files for model {model_name}...")
@@ -64,6 +65,9 @@ def start_server(
6465
"--max-batch-weight", "80000",
6566
]
6667

68+
if output_special_tokens:
69+
args.append("--output-special-tokens")
70+
6771
env = os.environ.copy()
6872
env["RUST_BACKTRACE"] = "full"
6973
env["PREFIX_STORE_PATH"] = os.path.join(TESTS_DIR, "prompt_prefixes")
@@ -115,7 +119,11 @@ def server_fixture(request):
115119
model_name = request.node.get_closest_marker("model").args[0]
116120
shards = int(request.node.get_closest_marker("shards").args[0])
117121
extensions = request.node.get_closest_marker("extensions").args[0]
118-
p = start_server(model_name, extensions, shards, 3000, 29502)
122+
ost = request.node.get_closest_marker("output_special_tokens")
123+
ost = ost is not None and ost.args[0]
124+
p = start_server(
125+
model_name, extensions, shards, 3000, 29502, output_special_tokens=ost
126+
)
119127
yield p
120128
p.terminate()
121129
assert p.wait(8.0) == 0
@@ -356,6 +364,16 @@ async def test_bloom(server_fixture, test_cases):
356364
await run_test_cases_async(test_cases, sharded=True)
357365

358366

367+
@pytest.mark.model("bigscience/mt0-small")
368+
@pytest.mark.extensions(".bin,.json")
369+
@pytest.mark.shards(1)
370+
@pytest.mark.output_special_tokens(True)
371+
@pytest.mark.test_case_file("test_cases_mt0_ost.yaml")
372+
@pytest.mark.asyncio
373+
async def test_mt0_output_special_tokens(server_fixture, test_cases):
374+
await run_test_cases_async(test_cases)
375+
376+
359377
# Test loading when an explicit local path is provided
360378
def test_explicit_path():
361379
# Test with and without providing TRANSFORMERS_CACHE env var

router/src/decoder.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@ impl Decoder {
3535
}
3636
}
3737

38-
fn decode_full(&self, ids: &[u32]) -> Result<String, InferError> {
38+
fn decode_full(&self, mut ids: &[u32]) -> Result<String, InferError> {
39+
if !self.skip_special_toks && ids.last() == Some(&self.eos_token_id) {
40+
ids = &ids[..(ids.len()-1)];
41+
}
3942
self.tokenizer.decode(ids, self.skip_special_toks).map_err(Error::into)
4043
}
4144

0 commit comments

Comments
 (0)