3030from transformers .generation .utils import GenerateDecoderOnlyOutput
3131
3232from mellea .backends import BaseModelSubclass
33+ from mellea .backends ._utils import to_chat , to_tool_calls , use_alora
3334from mellea .backends .aloras import Alora , AloraBackendMixin
3435from mellea .backends .cache import Cache , SimpleLRUCache
3536from mellea .backends .formatter import Formatter , FormatterBackend , TemplateFormatter
3940 add_tools_from_context_actions ,
4041 add_tools_from_model_options ,
4142 convert_tools_to_json ,
42- parse_tools ,
4343)
4444from mellea .backends .types import ModelOption
4545from mellea .helpers .async_helpers import send_to_queue
@@ -198,26 +198,24 @@ def generate_from_context(
198198 # Upsert model options.
199199 model_opts = self ._simplify_and_merge (model_options )
200200
201- # See `docs/dev/requirement_aLoRA_rerouting.md` for an explanation of the following code block.
202- if issubclass (type (action ), Requirement ):
203- # The general rule is that we reroute to the alora if it exists.
204- reroute_to_alora = self .get_alora ("constraint" ) is not None
205- # However, there are some exceptions:
206- if not self .default_to_constraint_checking_alora :
207- reroute_to_alora = False
208- if issubclass (type (action ), LLMaJRequirement ):
209- reroute_to_alora = False
210- if issubclass (type (action ), ALoraRequirement ):
211- reroute_to_alora = True
212- if reroute_to_alora :
213- mot = self ._generate_from_context_alora (
214- action , ctx , _format = format , model_options = model_opts
215- )
216- return mot , ctx .add (mot )
217- mot = self ._generate_from_context_standard (
218- action , ctx , _format = format , model_options = model_opts , tool_calls = tool_calls
219- )
220- return mot , ctx .add (action ).add (mot )
201+ if use_alora (
202+ action ,
203+ self .get_alora ("constraint" ),
204+ self .default_to_constraint_checking_alora ,
205+ ):
206+ mot = self ._generate_from_context_alora (
207+ action , ctx , _format = format , model_options = model_opts
208+ )
209+ return mot , ctx .add (mot )
210+ else :
211+ mot = self ._generate_from_context_standard (
212+ action ,
213+ ctx ,
214+ _format = format ,
215+ model_options = model_opts ,
216+ tool_calls = tool_calls ,
217+ )
218+ return mot , ctx .add (action ).add (mot )
221219
222220 def _generate_from_context_alora (
223221 self ,
@@ -279,35 +277,8 @@ def _generate_from_context_standard(
279277 # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
280278 # Otherwise, we will linearize the context and treat it as a raw input.
281279 if ctx .is_chat_context :
282- linearized_ctx = ctx .view_for_generation ()
283- assert linearized_ctx is not None , (
284- "If ctx.is_chat_context, then the context should be linearizable."
285- )
286- ctx_as_message_list : list [Message ] = self .formatter .to_chat_messages (
287- linearized_ctx
288- )
289- # add action
290- ctx_as_message_list .extend (self .formatter .to_chat_messages ([action ]))
291- ctx_as_conversation = [
292- {"role" : m .role , "content" : m .content } for m in ctx_as_message_list
293- ]
294-
295- # Check that we ddin't accidentally end up with CBlocks.
296- for msg in ctx_as_conversation :
297- for v in msg .values ():
298- if "CBlock" in v :
299- FancyLogger .get_logger ().error (
300- f"Found the string `CBlock` in what should've been a stringified context: { ctx_as_conversation } "
301- )
302-
303- # handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
304280 system_prompt = model_options .get (ModelOption .SYSTEM_PROMPT , None )
305- if system_prompt is not None :
306- system_msg : dict [str , str ] = {
307- "role" : "system" ,
308- "content" : system_prompt ,
309- }
310- ctx_as_conversation .insert (0 , system_msg )
281+ ctx_as_chat = to_chat (action , ctx , self .formatter , system_prompt )
311282
312283 # Append tool call information if applicable.
313284 tools : dict [str , Callable ] = dict ()
@@ -332,7 +303,7 @@ def _generate_from_context_standard(
332303 set_seed (seed )
333304
334305 input_ids = self ._tokenizer .apply_chat_template ( # type: ignore
335- ctx_as_conversation ,
306+ ctx_as_chat ,
336307 tools = convert_tools_to_json (tools ), # type: ignore
337308 add_generation_prompt = True , # If we change this, must modify huggingface granite guardian.
338309 return_tensors = "pt" ,
@@ -397,7 +368,7 @@ def _generate_from_context_standard(
397368 )
398369
399370 output = ModelOutputThunk (None )
400- output ._context = linearized_ctx
371+ output ._context = ctx . view_for_generation ()
401372 output ._action = action
402373 output ._model_options = model_options
403374
@@ -406,7 +377,7 @@ def _generate_from_context_standard(
406377 output ._process = functools .partial (self .processing , input_ids = input_ids )
407378 output ._post_process = functools .partial (
408379 self .post_processing ,
409- conversation = ctx_as_conversation ,
380+ conversation = ctx_as_chat ,
410381 input_ids = input_ids ,
411382 _format = _format ,
412383 tool_calls = tool_calls ,
@@ -497,7 +468,7 @@ async def post_processing(
497468
498469 # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
499470 if _format is None and tool_calls :
500- mot .tool_calls = self . _extract_model_tool_requests (tools , mot .value )
471+ mot .tool_calls = to_tool_calls (tools , mot .value )
501472
502473 assert mot ._action is not None , (
503474 "ModelOutputThunks should have their action assigned during generation"
@@ -698,30 +669,6 @@ def _filter_chat_template_only_options(
698669 }
699670 return {k : v for k , v in model_options .items () if k not in chat_template_only }
700671
701- def _extract_model_tool_requests (
702- self , tools : dict [str , Callable ], decoded_result : str
703- ) -> dict [str , ModelToolCall ] | None :
704- model_tool_calls : dict [str , ModelToolCall ] = dict ()
705- for tool_name , tool_args in parse_tools (decoded_result ):
706- func = tools .get (tool_name )
707- if func is None :
708- FancyLogger .get_logger ().warning (
709- f"model attempted to call a non-existing function: { tool_name } "
710- )
711- continue
712-
713- # Clean up the function args slightly. Some models seem to
714- # hallucinate parameters when none are required.
715- sig = inspect .signature (func )
716- if len (sig .parameters ) == 0 :
717- tool_args = {}
718-
719- model_tool_calls [tool_name ] = ModelToolCall (tool_name , func , tool_args )
720-
721- if len (model_tool_calls ) > 0 :
722- return model_tool_calls
723- return None
724-
725672 # region ALora loading, unloading, and utility functions.
726673 def add_alora (self , alora : HFAlora ):
727674 """Loads an ALora for this backend.
0 commit comments