Commit e157ebb

Retain return type from @dispatcher.span (#17817)
1 parent 85a0046 · commit e157ebb

18 files changed: +145 −102 lines changed

llama-index-core/llama_index/core/chat_engine/condense_plus_context.py

Lines changed: 2 additions & 0 deletions
@@ -347,6 +347,7 @@ def stream_chat(
         )
 
         response = synthesizer.synthesize(message, context_nodes)
+        assert isinstance(response, StreamingResponse)
 
         def wrapped_gen(response: StreamingResponse) -> ChatResponseGen:
             full_response = ""
@@ -405,6 +406,7 @@ async def astream_chat(
         )
 
         response = await synthesizer.asynthesize(message, context_nodes)
+        assert isinstance(response, AsyncStreamingResponse)
 
         async def wrapped_gen(response: AsyncStreamingResponse) -> ChatResponseAsyncGen:
             full_response = ""
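The synthesizer calls are typed to return a union of response types, so once `@dispatcher.span` stops erasing return types these call sites need narrowing to type-check. The added asserts do exactly that. A minimal sketch of the pattern, with stand-in classes:

```python
from typing import Union


class Response:
    pass


class StreamingResponse:
    response_gen = iter(())


def synthesize() -> Union[Response, StreamingResponse]:
    # Stand-in for synthesizer.synthesize, which may or may not stream.
    return StreamingResponse()


response = synthesize()
# Checkers see Union[Response, StreamingResponse] here; the assert narrows
# the type so streaming-only attributes check (and fails fast otherwise).
assert isinstance(response, StreamingResponse)
for _ in response.response_gen:
    pass
```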

llama-index-core/llama_index/core/evaluation/retrieval/evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@ async def _aget_retrieved_ids_and_texts(
 
         return (
             [node.node.node_id for node in retrieved_nodes],
-            [node.node.text for node in retrieved_nodes],
+            [node.text for node in retrieved_nodes],
         )
 
 
@@ -84,7 +84,7 @@ async def _aget_retrieved_ids_and_texts(
             node = scored_node.node
             if isinstance(node, ImageNode):
                 image_nodes.append(node)
-            if node.text:
+            if isinstance(node, TextNode):
                 text_nodes.append(node)
 
         if mode == "text":

llama-index-core/llama_index/core/extractors/metadata_extractors.py

Lines changed: 7 additions & 5 deletions
@@ -20,7 +20,7 @@
     (similar with contrastive learning)
 """
 
-from typing import Any, Callable, Dict, List, Optional, Sequence, cast
+from typing import Any, Callable, Dict, Generic, List, Optional, Sequence, cast
 
 from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
 from llama_index.core.bridge.pydantic import (
@@ -33,7 +33,7 @@
 from llama_index.core.prompts import PromptTemplate
 from llama_index.core.schema import BaseNode, TextNode
 from llama_index.core.settings import Settings
-from llama_index.core.types import BasePydanticProgram
+from llama_index.core.types import BasePydanticProgram, Model
 
 DEFAULT_TITLE_NODE_TEMPLATE = """\
 Context: {context_str}. Give a title that summarizes all of \
@@ -462,15 +462,15 @@ async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
     """
 
 
-class PydanticProgramExtractor(BaseExtractor):
+class PydanticProgramExtractor(BaseExtractor, Generic[Model]):
     """Pydantic program extractor.
 
     Uses an LLM to extract out a Pydantic object. Return attributes of that object
     in a dictionary.
 
     """
 
-    program: SerializeAsAny[BasePydanticProgram] = Field(
+    program: SerializeAsAny[BasePydanticProgram[Model]] = Field(
         ..., description="Pydantic program to extract."
     )
     input_key: str = Field(
@@ -500,7 +500,9 @@ async def _acall_program(self, node: BaseNode) -> Dict[str, Any]:
         )
 
         ret_object = await self.program.acall(**{self.input_key: extract_str})
-        return ret_object.dict()
+        assert not isinstance(ret_object, list)
+
+        return ret_object.model_dump()
 
     async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
         """Extract pydantic program."""

llama-index-core/llama_index/core/indices/common_tree/base.py

Lines changed: 1 addition & 2 deletions
@@ -219,8 +219,7 @@ async def abuild_index_from_nodes(
             self._llm.apredict(self.summary_prompt, context_str=text_chunk)
             for text_chunk in text_chunks_progress
         ]
-        outputs: List[Tuple[str, str]] = await asyncio.gather(*tasks)
-        summaries = [output[0] for output in outputs]
+        summaries = await asyncio.gather(*tasks)
 
         event.on_end(payload={"summaries": summaries, "level": level})
 
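With `apredict` now seen as returning `str` through the span decorator, `asyncio.gather` over those coroutines is already inferred as `List[str]`, so the stale `Tuple[str, str]` annotation and the `output[0]` indexing (apparently a leftover from an older predictor API that returned `(text, formatted_prompt)` pairs) can be dropped. A minimal sketch of the typing, with a stand-in coroutine:

```python
import asyncio
from typing import List


async def apredict(prompt: str) -> str:
    # Stand-in for an LLM call that is typed to return str.
    return prompt.upper()


async def main() -> None:
    tasks = [apredict(chunk) for chunk in ("one", "two")]
    # gather over Awaitable[str] values is inferred as List[str]; no
    # tuple annotation or per-element unpacking is needed.
    summaries: List[str] = await asyncio.gather(*tasks)
    print(summaries)


asyncio.run(main())
```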
llama-index-core/llama_index/core/instrumentation/dispatcher.py

Lines changed: 3 additions & 2 deletions
@@ -2,7 +2,7 @@
 from functools import partial
 from contextlib import contextmanager
 from contextvars import Context, ContextVar, Token, copy_context
-from typing import Any, Callable, Generator, List, Optional, Dict, Protocol
+from typing import Any, Callable, Generator, List, Optional, Dict, Protocol, TypeVar
 import inspect
 import logging
 import uuid
@@ -26,6 +26,7 @@
 active_instrument_tags: ContextVar[Dict[str, Any]] = ContextVar(
     "instrument_tags", default={}
 )
+_R = TypeVar("_R")
 
 
 @contextmanager
@@ -239,7 +240,7 @@ def span_exit(
         else:
             c = c.parent
 
-    def span(self, func: Callable) -> Any:
+    def span(self, func: Callable[..., _R]) -> Callable[..., _R]:
         # The `span` decorator should be idempotent.
         try:
             if hasattr(func, DISPATCHER_SPAN_DECORATED_ATTR):
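This hunk is the heart of the commit: typing `span` with a return-type TypeVar means static checkers see the decorated function with the same return type as the original, instead of `Any`. A minimal sketch of the pattern, independent of llama-index:

```python
import functools
from typing import Any, Callable, TypeVar

_R = TypeVar("_R")


def span(func: Callable[..., _R]) -> Callable[..., _R]:
    """Toy tracing decorator; the wrapped return type survives."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> _R:
        # A real implementation would enter/exit a span around the call.
        return func(*args, **kwargs)

    return wrapper


@span
def add(a: int, b: int) -> int:
    return a + b


total: int = add(1, 2)  # checkers now infer int, not Any
```

Note that `Callable[..., _R]` still erases parameter types; `ParamSpec` would preserve those too, at the cost of requiring newer Python or a `typing_extensions` dependency.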

llama-index-core/llama_index/core/llms/llm.py

Lines changed: 19 additions & 10 deletions
@@ -73,6 +73,7 @@
 
 if TYPE_CHECKING:
     from llama_index.core.chat_engine.types import AgentChatResponse
+    from llama_index.core.program.utils import FlexibleModel
     from llama_index.core.tools.types import BaseTool
     from llama_index.core.llms.structured_llm import StructuredLLM
 
@@ -322,11 +323,11 @@ def _as_query_component(self, **kwargs: Any) -> QueryComponent:
     @dispatcher.span
     def structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Structured predict.
 
         Args:
@@ -372,17 +373,19 @@ class Test(BaseModel):
         )
 
         result = program(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result
 
     @dispatcher.span
     async def astructured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> BaseModel:
+    ) -> Model:
         r"""Async Structured predict.
 
         Args:
@@ -429,17 +432,19 @@ class Test(BaseModel):
         )
 
         result = await program.acall(llm_kwargs=llm_kwargs, **prompt_args)
+        assert not isinstance(result, list)
+
         dispatcher.event(LLMStructuredPredictEndEvent(output=result))
         return result
 
     @dispatcher.span
     def stream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> Generator[Union[Model, List[Model]], None, None]:
+    ) -> Generator[Union[Model, "FlexibleModel"], None, None]:
         r"""Stream Structured predict.
 
         Args:
@@ -489,18 +494,19 @@ class Test(BaseModel):
         result = program.stream_call(llm_kwargs=llm_kwargs, **prompt_args)
         for r in result:
             dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+            assert not isinstance(r, list)
             yield r
 
         dispatcher.event(LLMStructuredPredictEndEvent(output=r))
 
     @dispatcher.span
     async def astream_structured_predict(
         self,
-        output_cls: Type[BaseModel],
+        output_cls: Type[Model],
         prompt: PromptTemplate,
         llm_kwargs: Optional[Dict[str, Any]] = None,
         **prompt_args: Any,
-    ) -> AsyncGenerator[Union[Model, List[Model]], None]:
+    ) -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
         r"""Async Stream Structured predict.
 
         Args:
@@ -534,8 +540,10 @@ class Test(BaseModel):
         ```
         """
 
-        async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
-            from llama_index.core.program.utils import get_program_for_llm
+        async def gen() -> AsyncGenerator[Union[Model, "FlexibleModel"], None]:
+            from llama_index.core.program.utils import (
+                get_program_for_llm,
+            )
 
             dispatcher.event(
                 LLMStructuredPredictStartEvent(
@@ -552,6 +560,7 @@ async def gen() -> AsyncGenerator[Union[Model, List[Model]], None]:
             result = await program.astream_call(llm_kwargs=llm_kwargs, **prompt_args)
             async for r in result:
                 dispatcher.event(LLMStructuredPredictInProgressEvent(output=r))
+                assert not isinstance(r, list)
                 yield r
 
             dispatcher.event(LLMStructuredPredictEndEvent(output=r))
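The net effect for callers: `structured_predict` and friends are now generic in `output_cls`, so the concrete Pydantic type comes back out, and the streaming variants advertise `FlexibleModel` (the partial-object type used while streaming) instead of `List[Model]`. A usage sketch (the `Song` model and `pick_song` helper are hypothetical):

```python
from pydantic import BaseModel

from llama_index.core.llms.llm import LLM
from llama_index.core.prompts import PromptTemplate


class Song(BaseModel):
    name: str


def pick_song(llm: LLM) -> str:
    prompt = PromptTemplate("Name one song by {artist}.")
    # Previously annotated as returning BaseModel; now checkers infer
    # Song from the output_cls argument, so .name resolves cleanly.
    song = llm.structured_predict(Song, prompt, artist="The Beatles")
    return song.name
```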

llama-index-core/llama_index/core/output_parsers/pydantic.py

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 """Pydantic output parser."""
 
 import json
-from typing import Any, List, Optional, Type
+from typing import Any, Generic, List, Optional, Type
 
 from llama_index.core.output_parsers.base import ChainableOutputParser
 from llama_index.core.output_parsers.utils import extract_json_str
@@ -15,7 +15,7 @@
 """
 
 
-class PydanticOutputParser(ChainableOutputParser):
+class PydanticOutputParser(ChainableOutputParser, Generic[Model]):
     """Pydantic Output Parser.
 
     Args:
@@ -36,7 +36,7 @@ def __init__(
 
     @property
     def output_cls(self) -> Type[Model]:
-        return self._output_cls  # type: ignore
+        return self._output_cls
 
     @property
     def format_string(self) -> str:
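Making the parser generic means `self._output_cls` already carries type `Type[Model]`, so the `# type: ignore` on the property can go. A usage sketch (the `Album` model is hypothetical):

```python
from pydantic import BaseModel

from llama_index.core.output_parsers import PydanticOutputParser


class Album(BaseModel):
    title: str


# Inferred as PydanticOutputParser[Album]; output_cls is Type[Album].
parser = PydanticOutputParser(output_cls=Album)
album = parser.parse('{"title": "Abbey Road"}')
print(album)
```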
