Skip to content

Commit 63afbe9

Browse files
authored
[CI] Expand OpenAI test_chat.py guided decoding tests (#11048)
Signed-off-by: mgoin <[email protected]>
1 parent 8cef6e0 commit 63afbe9

File tree

1 file changed

+12
-17
lines changed

1 file changed

+12
-17
lines changed

tests/entrypoints/openai/test_chat.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# any model with a chat template should work here
1818
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
1919

20+
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
21+
2022

2123
@pytest.fixture(scope="module")
2224
def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811
@@ -464,8 +466,7 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
464466
# will fail on the second `guided_decoding_backend` even when I swap their order
465467
# (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256)
466468
@pytest.mark.asyncio
467-
@pytest.mark.parametrize("guided_decoding_backend",
468-
["outlines", "lm-format-enforcer"])
469+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
469470
async def test_guided_choice_chat(client: openai.AsyncOpenAI,
470471
guided_decoding_backend: str,
471472
sample_guided_choice):
@@ -506,8 +507,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
506507

507508

508509
@pytest.mark.asyncio
509-
@pytest.mark.parametrize("guided_decoding_backend",
510-
["outlines", "lm-format-enforcer"])
510+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
511511
async def test_guided_json_chat(client: openai.AsyncOpenAI,
512512
guided_decoding_backend: str,
513513
sample_json_schema):
@@ -554,8 +554,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
554554

555555

556556
@pytest.mark.asyncio
557-
@pytest.mark.parametrize("guided_decoding_backend",
558-
["outlines", "lm-format-enforcer"])
557+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
559558
async def test_guided_regex_chat(client: openai.AsyncOpenAI,
560559
guided_decoding_backend: str, sample_regex):
561560
messages = [{
@@ -613,8 +612,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
613612

614613

615614
@pytest.mark.asyncio
616-
@pytest.mark.parametrize("guided_decoding_backend",
617-
["outlines", "lm-format-enforcer"])
615+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
618616
async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
619617
guided_decoding_backend: str,
620618
sample_guided_choice):
@@ -646,8 +644,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
646644

647645

648646
@pytest.mark.asyncio
649-
@pytest.mark.parametrize("guided_decoding_backend",
650-
["outlines", "lm-format-enforcer"])
647+
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
651648
async def test_named_tool_use(client: openai.AsyncOpenAI,
652649
guided_decoding_backend: str,
653650
sample_json_schema):
@@ -681,7 +678,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
681678
"function": {
682679
"name": "dummy_function_name"
683680
}
684-
})
681+
},
682+
extra_body=dict(guided_decoding_backend=guided_decoding_backend))
685683
message = chat_completion.choices[0].message
686684
assert len(message.content) == 0
687685
json_string = message.tool_calls[0].function.arguments
@@ -716,6 +714,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
716714
"name": "dummy_function_name"
717715
}
718716
},
717+
extra_body=dict(guided_decoding_backend=guided_decoding_backend),
719718
stream=True)
720719

721720
output = []
@@ -738,10 +737,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
738737

739738

740739
@pytest.mark.asyncio
741-
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
742-
async def test_required_tool_use_not_yet_supported(
743-
client: openai.AsyncOpenAI, guided_decoding_backend: str,
744-
sample_json_schema):
740+
async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
741+
sample_json_schema):
745742
messages = [{
746743
"role": "system",
747744
"content": "you are a helpful assistant"
@@ -785,9 +782,7 @@ async def test_required_tool_use_not_yet_supported(
785782

786783

787784
@pytest.mark.asyncio
788-
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
789785
async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
790-
guided_decoding_backend: str,
791786
sample_json_schema):
792787
messages = [{
793788
"role": "system",

0 commit comments

Comments
 (0)