Commit 3b34bd3

guicho271828 and tuliocoppola authored and committed
fix: rename format argument in internal methods for better mypiability (generative-computing#172)
* refactor: renamed 'format' variable to '_format' in internal methods so that mypy detects it
* fix: use format = None
1 parent 5fb014d commit 3b34bd3
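
The same pattern is applied in each backend below: the public generate methods keep `format` as the user-facing keyword argument, the internal helpers take `_format`, and each module declares `format: None = None` to shadow the builtin format() function. A minimal sketch of the idea, with hypothetical names (Answer, generate, and _generate_standard are stand-ins, not the actual mellea classes or methods):

from pydantic import BaseModel

format: None = None  # shadow the builtin format() so a stray bare `format` reference is typed as None


class Answer(BaseModel):
    text: str


def generate(*, format: type[BaseModel] | None = None) -> str | None:
    # Public entry point: `format` stays the user-facing keyword argument.
    return _generate_standard(_format=format)


def _generate_standard(*, _format: type[BaseModel] | None = None) -> str | None:
    # Internal helper: only the renamed `_format` is referenced in the body.
    if _format is not None:
        return str(_format.model_json_schema())
    return None

If a helper body still refers to bare `format` after the rename, the name resolves to the None-typed module variable rather than to the builtin, so mypy can flag uses that expect a Pydantic schema class.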

5 files changed, +66 -53 lines changed


mellea/backends/huggingface.py

Lines changed: 14 additions & 12 deletions

@@ -67,6 +67,8 @@
 """
 TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]

+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+

 @dataclasses.dataclass
 class HFAloraCacheInfo:

@@ -209,11 +211,11 @@ def generate_from_context(
             reroute_to_alora = True
         if reroute_to_alora:
             mot = self._generate_from_context_alora(
-                action, ctx, format=format, model_options=model_opts
+                action, ctx, _format=format, model_options=model_opts
             )
             return mot, ctx.add(mot)
         mot = self._generate_from_context_standard(
-            action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls
+            action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
         )
         return mot, ctx.add(action).add(mot)

@@ -222,7 +224,7 @@ def _generate_from_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
     ) -> ModelOutputThunk:
         match action:

@@ -245,7 +247,7 @@ def _generate_from_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."

         alora_output = alora_for_this_request.generate_using_strings(
             input=user_message,

@@ -269,7 +271,7 @@ def _generate_from_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict[str, Any],
         tool_calls: bool = False,
     ) -> ModelOutputThunk:

@@ -310,7 +312,7 @@ def _generate_from_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )

@@ -338,10 +340,10 @@ def _generate_from_context_standard(
         ).to(self._device)  # type: ignore

         format_kwargs = {}
-        if format:
+        if _format:
             # outlines.generate.json always parses the resulting json into a python dict.
             # We however want to keep it as a json string for later storing it in ModelOutputThunk
-            schema: dict[str, Any] = format.model_json_schema()
+            schema: dict[str, Any] = _format.model_json_schema()
             schema_json: str = json.dumps(schema)
             regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema(  # type: ignore
                 schema_json

@@ -406,7 +408,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
-            format=format,
+            _format=_format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,

@@ -463,7 +465,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
-        format: type[BaseModelSubclass] | None,
+        _format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,

@@ -494,7 +496,7 @@ async def post_processing(
             self.cache_put(mot.value, cache_info)

         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
-        if format is None and tool_calls:
+        if _format is None and tool_calls:
             mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)

         assert mot._action is not None, (

@@ -514,7 +516,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot.value
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": seed,

mellea/backends/litellm.py

Lines changed: 13 additions & 11 deletions

@@ -40,6 +40,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement

+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+

 class LiteLLMBackend(FormatterBackend):
     """A generic LiteLLM compatible backend."""

@@ -123,7 +125,7 @@ def generate_from_context(
         mot = self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )

@@ -215,7 +217,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,

@@ -249,12 +251,12 @@ def _generate_from_chat_context_standard(
             [OpenAIBackend.message_to_openai_message(m) for m in messages]
         )

-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }

@@ -267,7 +269,7 @@ def _generate_from_chat_context_standard(
             thinking = "medium"

         # Append tool call information if applicable.
-        tools = self._extract_tools(action, format, model_opts, tool_calls, ctx)
+        tools = self._extract_tools(action, _format, model_opts, tool_calls, ctx)
         formatted_tools = convert_tools_to_json(tools) if len(tools) > 0 else None

         model_specific_options = self._make_backend_specific_and_remove(model_opts)

@@ -302,7 +304,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
-            format=format,
+            _format=_format,
         )

         try:

@@ -380,7 +382,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

@@ -425,7 +427,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["litellm_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
             "seed": thinking,

@@ -436,11 +438,11 @@

     @staticmethod
     def _extract_tools(
-        action, format, model_opts, tool_calls, ctx
+        action, _format, model_opts, tool_calls, ctx
     ) -> dict[str, Callable]:
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )

mellea/backends/ollama.py

Lines changed: 12 additions & 7 deletions

@@ -36,6 +36,8 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement

+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+

 class OllamaModelBackend(FormatterBackend):
     """A model that uses the Ollama Python SDK for local inference."""

@@ -265,7 +267,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )

@@ -277,7 +279,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass] | None = None,
+        _format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
         tool_calls: bool = False,
     ) -> ModelOutputThunk:

@@ -325,7 +327,7 @@ def generate_from_chat_context(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )

@@ -348,7 +350,7 @@ def generate_from_chat_context(
             think=model_opts.get(ModelOption.THINKING, None),
             stream=model_opts.get(ModelOption.STREAM, False),
             options=self._make_backend_specific_and_remove(model_opts),
-            format=format.model_json_schema() if format is not None else None,
+            format=_format.model_json_schema() if _format is not None else None,
         )  # type: ignore

         output = ModelOutputThunk(None)

@@ -360,7 +362,10 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools, format=format
+            self.post_processing,
+            conversation=conversation,
+            tools=tools,
+            _format=_format,
         )

         try:

@@ -523,7 +528,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
-        format,
+        _format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (

@@ -542,7 +547,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": mot._model_options.get(ModelOption.THINKING, None),
             "tools_available": tools,
             "tools_called": mot.tool_calls,

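One detail from the ollama.py hunks above: the keyword passed to the Ollama SDK's chat call is still named `format`, because that is the SDK's own parameter name; only the local variable feeding it was renamed. A small sketch of that shape (chat_kwargs is a hypothetical helper, not part of mellea):

from pydantic import BaseModel

def chat_kwargs(_format: type[BaseModel] | None) -> dict:
    # The external keyword keeps the name `format`; internal references use `_format`.
    return {"format": _format.model_json_schema() if _format is not None else None}
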
mellea/backends/openai.py

Lines changed: 16 additions & 14 deletions

@@ -55,6 +55,8 @@

 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"

+format: None = None  # typing this variable in order to shadow the global format function and ensure mypy checks for errors
+

 class _ServerType(Enum):
     LOCALHOST = 1

@@ -303,7 +305,7 @@ def generate_from_context(
         mot = self.generate_from_chat_context(
             action,
             ctx,
-            format=format,
+            _format=format,
             model_options=model_options,
             tool_calls=tool_calls,
         )

@@ -314,7 +316,7 @@ def generate_from_chat_context(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,

@@ -332,13 +334,13 @@ def generate_from_chat_context(
             reroute_to_alora = True
         if reroute_to_alora:
             return self._generate_from_chat_context_alora(
-                action, ctx, format=format, model_options=model_options
+                action, ctx, _format=_format, model_options=model_options
             )

         return self._generate_from_chat_context_standard(
             action,
             ctx,
-            format=format,
+            _format=_format,
             model_options=model_options,
             tool_calls=tool_calls,
         )

@@ -348,7 +350,7 @@ def _generate_from_chat_context_alora(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
     ) -> ModelOutputThunk:

@@ -373,7 +375,7 @@ def _generate_from_chat_context_alora(
         assert alora_for_this_request is not None
         assert type(user_message) is str
         assert type(assistant_message) is str
-        assert format is None, "Structured outputs are not supported by ALoRAs."
+        assert _format is None, "Structured outputs are not supported by ALoRAs."

         model_opts = self._simplify_and_merge(model_options, is_chat_context=True)

@@ -434,7 +436,7 @@ def _generate_from_chat_context_standard(
         action: Component | CBlock,
         ctx: Context,
         *,
-        format: type[BaseModelSubclass]
+        _format: type[BaseModelSubclass]
         | None = None,  # Type[BaseModelSubclass] is a class object of a subclass of BaseModel
         model_options: dict | None = None,
         tool_calls: bool = False,

@@ -463,12 +465,12 @@ def _generate_from_chat_context_standard(
             conversation.append({"role": "system", "content": system_prompt})
         conversation.extend([self.message_to_openai_message(m) for m in messages])

-        if format is not None:
+        if _format is not None:
             response_format = {
                 "type": "json_schema",
                 "json_schema": {
-                    "name": format.__name__,
-                    "schema": format.model_json_schema(),
+                    "name": _format.__name__,
+                    "schema": _format.model_json_schema(),
                     "strict": True,
                 },
             }

@@ -478,7 +480,7 @@ def _generate_from_chat_context_standard(
         # Append tool call information if applicable.
         tools: dict[str, Callable] = dict()
         if tool_calls:
-            if format:
+            if _format:
                 FancyLogger.get_logger().warning(
                     f"Tool calling typically uses constrained generation, but you have specified a `format` in your generate call. NB: tool calling is superseded by format; we will NOT call tools for your request: {action}"
                 )

@@ -527,7 +529,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
-            format=format,
+            _format=_format,
         )

         try:

@@ -596,7 +598,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
-        format,
+        _format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

@@ -634,7 +636,7 @@ async def post_processing(
         generate_log.date = datetime.datetime.now()
         generate_log.model_output = mot._meta["oai_chat_response"]
         generate_log.extra = {
-            "format": format,
+            "format": _format,
             "thinking": thinking,
             "tools_available": tools,
             "tools_called": mot.tool_calls,
