@@ -482,6 +482,7 @@ def generate_with_budget_forcing(
         answer_suffix: str = "The final answer is:",
         answer_regex: str = "boxed",
         model_options: dict | None = None,
+        generate_logs: list[GenerateLog] | None = None,
     ) -> tuple[str, int]:
486487 """Generate with budget forcing using the completions APIs. This relies on raw autocompletion and assumes the model's output is structured in the following form: '<think> ... </think> summary answer'
487488 The budget forcing method is proposed in the paper: https://arxiv.org/abs/2501.19393
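For orientation, a hypothetical usage sketch of this method follows. Only answer_suffix, answer_regex, model_options, and generate_logs appear in this hunk; the constructor call, the positional prompt argument, and the think_max_tokens / think_wait_suffix names are assumptions inferred from the body of the diff, not the confirmed signature.

# Hypothetical sketch only: names not shown in this diff are assumed.
backend = OpenAIBackend()  # assumed constructor; real initialization may differ
response, gen_tok_count = backend.generate_with_budget_forcing(
    "Prove that the sum of two even integers is even.",
    think_max_tokens=512,      # assumed: token budget for the '<think> ... </think>' span
    think_wait_suffix="Wait",  # assumed: suffix appended to force further thinking
    answer_suffix="The final answer is:",
    answer_regex="boxed",
    model_options={"temperature": 0.0},
    generate_logs=None,
)
print(gen_tok_count, response)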
@@ -537,23 +538,13 @@ def generate_with_budget_forcing(
                 break
 
             backend_opts["max_tokens"] = rem_toks
-            try:
-                completion_response = self._client.completions.create(
-                    model=self._hf_model_id, prompt=curr_prompt, **backend_opts
-                )  # type: ignore
-            except openai.BadRequestError as e:
-                if openai_ollama_batching_error in e.message:
-                    FancyLogger.get_logger().error(
-                        "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
-                        "your requests will fail since ollama doesn't support batching requests."
-                    )
-                raise e
-
-            # Necessary for type checker.
-            assert isinstance(completion_response.usage, CompletionUsage)
-            gen_tok_count += completion_response.usage.completion_tokens
+            # TODO workaround to obtain generated token counts
+            # The token count should be relayed by openai's CompletionUsage
+            backend_opts["logprobs"] = 1  # To get number of generated tokens
+            result = self._generate_from_raw([prompt], model_options=backend_opts, generate_logs=generate_logs)
+            gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
             rem_toks = think_max_tokens - gen_tok_count
-            response = completion_response.choices[0].text
+            response = result[0].value
 
             if think_wait_suffix == "":
                 # non-strict budget form
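The added lines above count generated tokens by requesting per-token logprobs and taking the length of token_logprobs, since the CompletionUsage object is not relayed through _generate_from_raw. A minimal standalone sketch of that counting trick against an OpenAI-compatible completions endpoint follows; the base_url and model id are placeholders, not values from this repository.

# Standalone sketch of the token-counting workaround: when the usage field is
# not surfaced, len(logprobs.token_logprobs) equals the number of generated tokens.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # placeholder endpoint
completion = client.completions.create(
    model="my-model",  # placeholder model id
    prompt="<think> Let's reason step by step.",
    max_tokens=64,
    logprobs=1,        # ask for per-token logprobs
)
choice = completion.choices[0]
gen_tok_count = len(choice.logprobs.token_logprobs)  # one entry per generated token
print(choice.text, gen_tok_count)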
@@ -611,22 +602,10 @@ def generate_with_budget_forcing(
         else:
             backend_opts.pop("max_tokens", None)  # generate unconditionally
 
-        try:
-            completion_response = self._client.completions.create(
-                model=self._hf_model_id, prompt=prompt, **backend_opts
-            )  # type: ignore
-        except openai.BadRequestError as e:
-            if openai_ollama_batching_error in e.message:
-                FancyLogger.get_logger().error(
-                    "If you are trying to call `OpenAIBackend.generate_with_budget_forcing while targeting an ollama server, "
-                    "your requests will fail since ollama doesn't support batching requests."
-                )
-            raise e
-
-        # Necessary for type checker.
-        assert isinstance(completion_response.usage, CompletionUsage)
-        response += completion_response.choices[0].text
-        gen_tok_count += completion_response.usage.completion_tokens
+        backend_opts["logprobs"] = 1  # To get number of generated tokens
+        result = self._generate_from_raw([prompt], model_options=backend_opts, generate_logs=generate_logs)
+        response += result[0].value
+        gen_tok_count += len(result[0]._meta['oai_completion_response']['logprobs']['token_logprobs'])
         return response, gen_tok_count
 
     def _generate_from_raw(