
Commit e013f92

Author: Yousef El-Kurdi (committed)

backend interface using _raw_generate

1 parent 1df37d5 · commit e013f92

File tree

5 files changed: +49 -67 lines changed


mellea/backends/openai.py

Lines changed: 11 additions & 32 deletions
@@ -482,6 +482,7 @@ def generate_with_budget_forcing(
         answer_suffix: str = "The final answer is:",
         answer_regex: str = "boxed",
         model_options: dict | None = None,
+        generate_logs: list[GenerateLog] | None = None,
     ) -> tuple[str, int]:
         """Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structured in the following form: '<think> ... </think> summary answer'
         The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
@@ -537,23 +538,13 @@ def generate_with_budget_forcing(
                 break

             backend_opts["max_tokens"] = rem_toks
-            try:
-                completion_response = self._client.completions.create(
-                    model=self._hf_model_id, prompt=curr_prompt, **backend_opts
-                )  # type: ignore
-            except openai.BadRequestError as e:
-                if openai_ollama_batching_error in e.message:
-                    FancyLogger.get_logger().error(
-                        "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
-                        "your requests will fail since ollama doesn't support batching requests."
-                    )
-                raise e
-
-            # Necessary for type checker.
-            assert isinstance(completion_response.usage, CompletionUsage)
-            gen_tok_count += completion_response.usage.completion_tokens
+            # TODO workaround to obtain generated token counts
+            # The token count should be relayed by openai's CompletionUsage
+            backend_opts["logprobs"] = 1  # To get number of generated tokens
+            result = self._generate_from_raw([prompt], model_options=backend_opts, generate_logs=generate_logs)
+            gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
             rem_toks = think_max_tokens - gen_tok_count
-            response = completion_response.choices[0].text
+            response = result[0].value

             if think_wait_suffix == "":
                 # non-strict budget form
@@ -611,22 +602,10 @@ def generate_with_budget_forcing(
         else:
             backend_opts.pop("max_tokens", None)  # generate unconditionally

-        try:
-            completion_response = self._client.completions.create(
-                model=self._hf_model_id, prompt=prompt, **backend_opts
-            )  # type: ignore
-        except openai.BadRequestError as e:
-            if openai_ollama_batching_error in e.message:
-                FancyLogger.get_logger().error(
-                    "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
-                    "your requests will fail since ollama doesn't support batching requests."
-                )
-            raise e
-
-        # Necessary for type checker.
-        assert isinstance(completion_response.usage, CompletionUsage)
-        response += completion_response.choices[0].text
-        gen_tok_count += completion_response.usage.completion_tokens
+        backend_opts["logprobs"] = 1  # To get number of generated tokens
+        result = self._generate_from_raw([prompt], model_options=backend_opts, generate_logs=generate_logs)
+        response += result[0].value
+        gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
         return response, gen_tok_count

     def _generate_from_raw(
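The key change is that both call sites now route through _generate_from_raw and recover the generated-token count from per-token logprobs rather than from CompletionUsage. A minimal, standalone sketch of that counting trick against an OpenAI-compatible completions endpoint; the base URL, API key, prompt, and model name below are placeholders, not values taken from this commit:

# Sketch only: count generated tokens from logprobs when usage info is not
# relayed. Endpoint, key, prompt, and model name are assumed placeholders.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.completions.create(
    model="ibm-granite/granite-4.0-tiny-preview",
    prompt="<think>",
    max_tokens=64,
    logprobs=1,  # ask the server to return per-token logprobs with the text
)

choice = resp.choices[0]
# One logprob entry per generated token, so the list length is the token count.
gen_tok_count = len(choice.logprobs.token_logprobs)
print(choice.text, gen_tok_count)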

test/backends/test_think_budget_forcing/install.sh

Lines changed: 3 additions & 0 deletions
@@ -16,3 +16,6 @@ in-conda uv pip install pre-commit
 in-conda uv pip install pytest
 in-conda uv pip install vllm==0.10.0
 in-conda uv pip install outlines
+# in-conda uv pip install unsloth
+in-conda uv pip install ipdb
+

test/backends/test_think_budget_forcing/run_test.sh

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,9 @@
 #!/bin/bash

+export PYTHONBREAKPOINT="ipdb.set_trace"
+export LOCAL_TEST_MODEL="ibm-granite/granite-4.0-tiny-preview"
+# export LOCAL_TEST_MODEL="unsloth/Llama-3.2-1B"
+
 ENV_NAME=mellea_tbf
 eval "$(conda shell.bash hook)"
 conda activate $ENV_NAME
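A note on the new environment variable (not part of the diff): PYTHONBREAKPOINT redirects Python's built-in breakpoint() hook (PEP 553) to ipdb.set_trace, which is why install.sh now pulls in ipdb. A minimal sketch of the effect, with a hypothetical helper used only for illustration:

# With PYTHONBREAKPOINT="ipdb.set_trace" exported, breakpoint() opens an ipdb
# prompt instead of the default pdb one. check_budget is illustrative only.
def check_budget(remaining_tokens: int) -> bool:
    breakpoint()  # drops into ipdb here while debugging the tests
    return remaining_tokens > 0

if __name__ == "__main__":
    check_budget(128)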

test/backends/test_think_budget_forcing/serve.sh

Lines changed: 1 addition & 25 deletions
@@ -1,35 +1,11 @@
 #!/bin/bash

-# @Masa note:
-# the following code is a bash snippet Kristian gave me
-# for how to run vllm with lora adapter loaded.
-
-# HF_GRANITE_ALORA_SNAPSHOT=${HF_HOME:-$HOME/.cache/huggingface}
-# HF_GRANITE_ALORA_SNAPSHOT+=/hub/
-# HF_GRANITE_ALORA_SNAPSHOT+=models--ibm-granite--granite-3.2-8b-alora-requirement-check/
-# HF_GRANITE_ALORA_SNAPSHOT+=snapshots/d55a7a7f5796609bc938c5c151a864cfcc6ab54e
-
-# vllm serve ibm-granite/granite-3.2-8b-instruct \
-#     --enable-lora \
-#     --lora-modules "{\"name\": \"ibm-granite/granite-3.2-8b-alora-requirement-check\", \"path\": \"${HF_GRANITE_ALORA_SNAPSHOT}\", \"base_model_name\": \"ibm-granite/granite-3.2-8b-instruct\"}" \
-#     --dtype bfloat16 \
-#     --max-lora-rank 64 \
-#     --enable-prefix-caching
-
-# However, in our test, we do not load the alora when we serve.
-# In this test, we use the dynamic loading interface from
-# https://docs.vllm.ai/en/stable/features/lora.html#dynamically-serving-lora-adapters
-
-# Using this feature requires the following environment variable.
-# If you use conda/miniforge,
-# this variable must have been set already when you set up the environment.
-# see environment.yml.
 export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True

 echo "launching a vllm server. Logs are found in $(readlink -ef $(dirname $0))/vllm.log"
 # At the time of writing this code, Granite 4.4 vLLM serving did not support prefix-caching
 # --enable-prefix-caching \
-vllm serve ibm-granite/granite-4.0-tiny-preview \
+vllm serve $LOCAL_TEST_MODEL \
     --dtype bfloat16 \
     > $(readlink -ef $(dirname $0))/vllm.log \
     2> $(readlink -ef $(dirname $0))/vllm.err
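The serve script now picks up LOCAL_TEST_MODEL exported by run_test.sh and exposes an OpenAI-compatible endpoint on vLLM's default port. For orientation only, a sketch of a quick smoke test against that endpoint, mirroring the host and API-key conventions used in the test file; the prompt and fallback model name are assumptions and nothing here is part of the commit:

# Sketch: verify the locally served vLLM endpoint responds before running the
# tests. Host, key, prompt, and fallback model name are assumptions.
import os
import openai

client = openai.OpenAI(
    base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:8000')}/v1",
    api_key="ollama",  # a local vLLM server accepts any non-empty key
)

resp = client.completions.create(
    model=os.environ.get("LOCAL_TEST_MODEL", "ibm-granite/granite-4.0-tiny-preview"),
    prompt="2 + 2 =",
    max_tokens=8,
)
print(resp.choices[0].text)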

test/backends/test_think_budget_forcing/test_think_budget_forcing.py

Lines changed: 30 additions & 10 deletions
@@ -1,19 +1,35 @@
 from mellea import MelleaSession
+from mellea.backends.model_ids import OPENAI_GPT_OSS_20B, META_LLAMA_3_2_1B, IBM_GRANITE_4_TINY_PREVIEW_7B
 from mellea.stdlib.base import CBlock, SimpleContext
 from mellea.backends.openai import OpenAIBackend
+from mellea.backends.formatter import TemplateFormatter
 from transformers import AutoTokenizer
 import pytest
 import os

+
 class TestOpenAIBackend:
+    MODEL_ID = os.environ.get("LOCAL_TEST_MODEL", META_LLAMA_3_2_1B)
+    # Local testing mode
+    if MODEL_ID == "ibm-granite/granite-4.0-tiny-preview":
+        MODEL_ID = IBM_GRANITE_4_TINY_PREVIEW_7B
+
+    elif MODEL_ID == "unsloth/Llama-3.2-1B":
+        MODEL_ID = META_LLAMA_3_2_1B
+
+    else:
+        raise RuntimeError(f"Unsupported model-id:{MODEL_ID}")
+
     model_id = "ibm-granite/granite-4.0-tiny-preview"
     backend = OpenAIBackend(
-        model_id=model_id,
-        base_url="http://0.0.0.0:8000/v1",
-        api_key="EMPTY",
+        model_id=MODEL_ID,
+        formatter=TemplateFormatter(model_id=MODEL_ID),
+        base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:8000')}/v1",
+        api_key="ollama",
     )
+
     m = MelleaSession(backend, ctx=SimpleContext())
-    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID.hf_model_name, trust_remote_code=True)

     def prepare_prmpt_for_math(self, query):
         # Preparing prompt for math reasoning tasks
@@ -28,12 +44,16 @@ def prepare_prmpt_for_math(self, query):
         msg.append({"role": "system", "content": system_prompt})

         msg.append({"role": "user", "content": query})
-        prompt = self.tokenizer.apply_chat_template(
-            msg,
-            tokenize=False,
-            thinking=True,
-            add_generation_prompt=True,
-        )
+        if self.tokenizer.chat_template is None:
+            raise RuntimeError(f"No explicit chat template is defined for model-id: ")
+
+        else:
+            prompt = self.tokenizer.apply_chat_template(
+                msg,
+                tokenize=False,
+                thinking=True,
+                add_generation_prompt=True,
+            )

         return prompt
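The new guard simply skips the template path for checkpoints that ship without a chat template (plain base models). As a point of reference only, a standalone sketch of the same prompt-preparation path; the model name and messages are placeholders, and thinking=True is the same template flag the test passes:

# Standalone sketch of the prompt preparation used in the test. Model name
# and messages are assumptions; thinking=True is forwarded to the template.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "ibm-granite/granite-4.0-tiny-preview", trust_remote_code=True
)

msg = [
    {"role": "system", "content": "Solve the problem and give the final answer."},
    {"role": "user", "content": "What is 12 * 13?"},
]

if tokenizer.chat_template is None:
    # Base models without a chat template cannot be prompted this way.
    raise RuntimeError("No explicit chat template is defined for this model.")

prompt = tokenizer.apply_chat_template(
    msg,
    tokenize=False,
    thinking=True,
    add_generation_prompt=True,
)
print(prompt)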
