1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
.pants.d/
dist/
migration_scripts/
**data_wiki

# IDEs
.idea
440 changes: 440 additions & 0 deletions docs/docs/examples/node_postprocessor/REBELRerank.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions llama-index-core/llama_index/core/postprocessor/__init__.py
@@ -2,6 +2,7 @@


from llama_index.core.postprocessor.llm_rerank import LLMRerank
from llama_index.core.postprocessor.rebel_rerank import REBELRerank
from llama_index.core.postprocessor.structured_llm_rerank import (
StructuredLLMRerank,
DocumentWithRelevance,
@@ -39,6 +40,7 @@
"PIINodePostprocessor",
"NERPIINodePostprocessor",
"LLMRerank",
"REBELRerank",
"StructuredLLMRerank",
"DocumentWithRelevance",
"SentenceEmbeddingOptimizer",
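With REBELRerank now exported from the package root, it is importable next to the existing rerankers. A minimal sketch of the new import path (names taken from this diff):

from llama_index.core.postprocessor import LLMRerank, REBELRerank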
160 changes: 160 additions & 0 deletions llama-index-core/llama_index/core/postprocessor/rebel_rerank.py
@@ -0,0 +1,160 @@
import logging
from typing import List, Optional, Callable

from llama_index.core.bridge.pydantic import Field, PrivateAttr, SerializeAsAny
from llama_index.core.llms import LLM
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.prompts import BasePromptTemplate
from llama_index.core.prompts.base import PromptTemplate, PromptType
from llama_index.core.prompts.default_prompts import (
DEFAULT_REBEL_META_PROMPT,
DEFAULT_REBEL_CHOICE_SELECT_PROMPT,
)
from llama_index.core.prompts.mixin import PromptDictType
from llama_index.core.schema import NodeWithScore, QueryBundle
from llama_index.core.settings import Settings
from llama_index.core.indices.utils import (
default_format_node_batch_fn,
default_parse_choice_select_answer_fn,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


def get_default_llm() -> LLM:
from llama_index.llms.openai import OpenAI # pants: no-infer-dep

return OpenAI(model="gpt-3.5-turbo-16k")


class REBELRerank(BaseNodePostprocessor):
"""REBEL (Rerank Beyond Relevance) reranker."""

top_n: int = Field(description="Top N nodes to return.")
    one_turn: bool = Field(description="Whether to use a one-turn reranking prompt.")
meta_prompt: SerializeAsAny[BasePromptTemplate] = Field(
description="REBEL prompt that generates the choice selection prompt."
)
choice_batch_size: int = Field(description="Batch size for choice select.")
llm: LLM = Field(
default_factory=get_default_llm, description="The LLM to rerank with."
)
verbose: bool = Field(
default=False, description="Whether to print intermediate steps."
)
choice_select_prompt: Optional[SerializeAsAny[BasePromptTemplate]] = Field(
default=None, description="Generated prompt for choice selection."
)

_format_node_batch_fn: Callable = PrivateAttr()
_parse_choice_select_answer_fn: Callable = PrivateAttr()

def __init__(
self,
llm: Optional[LLM] = None,
meta_prompt: Optional[BasePromptTemplate] = None,
choice_select_prompt: Optional[BasePromptTemplate] = None,
choice_batch_size: int = 10,
format_node_batch_fn: Optional[Callable] = None,
parse_choice_select_answer_fn: Optional[Callable] = None,
top_n: int = 10,
one_turn: bool = False,
) -> None:
"""Initialize params."""
meta_prompt = meta_prompt or DEFAULT_REBEL_META_PROMPT
choice_select_prompt = (
choice_select_prompt or DEFAULT_REBEL_CHOICE_SELECT_PROMPT
)
llm = llm or Settings.llm

super().__init__(
llm=llm,
meta_prompt=meta_prompt,
choice_select_prompt=choice_select_prompt,
choice_batch_size=choice_batch_size,
top_n=top_n,
one_turn=one_turn,
)

self._format_node_batch_fn = (
format_node_batch_fn or default_format_node_batch_fn
)
self._parse_choice_select_answer_fn = (
parse_choice_select_answer_fn or default_parse_choice_select_answer_fn
)

@classmethod
def class_name(cls) -> str:
return "REBELRerank"

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "meta_prompt": self.meta_prompt,
            "choice_select_prompt": self.choice_select_prompt,
        }

def _update_prompts(self, prompts: PromptDictType) -> None:
"""Update prompts."""
if "meta_prompt" in prompts:
self.meta_prompt = prompts["meta_prompt"]
if "choice_select_prompt" in prompts:
self.choice_select_prompt = prompts["choice_select_prompt"]

def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
"""Postprocess nodes."""
if query_bundle is None:
raise ValueError("Query bundle must be provided.")
if len(nodes) == 0:
return []

query_str = query_bundle.query_str

        # In two-turn REBEL (one_turn=False), the choice_select_prompt is
        # generated per query from the meta prompt.
if not self.one_turn:
self.choice_select_prompt = PromptTemplate(
self.llm.predict(
self.meta_prompt,
user_query=query_str,
),
prompt_type=PromptType.CHOICE_SELECT,
)

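        # Score nodes in batches: format each batch, ask the LLM to select
        # and rate the most relevant choices, then parse the answer.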
initial_results: List[NodeWithScore] = []
for idx in range(0, len(nodes), self.choice_batch_size):
nodes_batch = [
node.node for node in nodes[idx : idx + self.choice_batch_size]
]
fmt_batch_str = self._format_node_batch_fn(nodes_batch)

response = self.llm.predict(
self.choice_select_prompt,
context_str=fmt_batch_str,
query_str=query_str,
)

raw_choices, relevances = self._parse_choice_select_answer_fn(
response, len(nodes_batch)
)

            # Choices are 1-based in the LLM response; convert to 0-based indices.
            choice_idxs = [int(choice) - 1 for choice in raw_choices]
            choice_nodes = [nodes_batch[i] for i in choice_idxs]

initial_results.extend(
[
NodeWithScore(node=node, score=relevance)
for node, relevance in zip(choice_nodes, relevances)
]
)

final_results = sorted(
initial_results, key=lambda x: x.score or 0.0, reverse=True
)[: self.top_n]

return final_results
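For context, a minimal usage sketch of the class added in this file. The data path and query are illustrative, and it assumes an OpenAI API key is configured for the default LLM:

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.postprocessor import REBELRerank

# Build a small index (illustrative corpus path).
documents = SimpleDirectoryReader("./data").load_data()
index = VectorStoreIndex.from_documents(documents)

# Default two-turn mode: a choice-select prompt is generated per query.
reranker = REBELRerank(top_n=5, choice_batch_size=10)

query_engine = index.as_query_engine(
    similarity_top_k=20,
    node_postprocessors=[reranker],
)
print(query_engine.query("How does the system handle user privacy?"))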