Skip to content

Commit 560a5d9

Browse files
committed
refactor: improve handling of endpoints metadata to lazy-load it
1 parent 9435c64 commit 560a5d9

File tree

13 files changed

+120
-113
lines changed

13 files changed

+120
-113
lines changed

compose.override.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ services:
88
api:
99
ports:
1010
- 8000:8000
11-
# environment:
12-
# - DEFAULT_LLM_MODEL=openrouter/openai/gpt-5.1
11+
environment:
12+
- DEFAULT_LLM_MODEL=openrouter/openai/gpt-5.1
1313
# - USE_TOOLS=true
1414
# - FORCE_REINDEX=true
1515
# - DEFAULT_LLM_MODEL=openrouter/openai/gpt-5.2

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ dependencies = [
3838
"qdrant-client >=1.16.2",
3939
"fastembed >=0.7.4",
4040
"langchain-core >=1.2.6",
41-
"langgraph >=1.0.5",
4241
"markdownify >=1.1.0",
4342
"pandas >=2.2.3",
4443
]
4544

4645
[project.optional-dependencies]
4746
agent = [
4847
# LangGraph dependencies
48+
"langgraph >=1.0.5",
4949
"langchain >=1.2.0",
5050
"langchain-openai >=1.1.6",
5151
# "langchain-azure-ai >=0.1.0",

src/sparql_llm/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
__version__ = "0.1.3"
44

5-
from .utils import SparqlEndpointLinks, query_sparql
5+
from .config import SparqlEndpointLinks
6+
from .utils import query_sparql
67
from .validate_sparql import validate_sparql, validate_sparql_in_msg, validate_sparql_with_void
78
from .loaders.sparql_examples_loader import SparqlExamplesLoader
89
from .loaders.sparql_void_shapes_loader import SparqlVoidShapesLoader, get_shex_dict_from_void, get_shex_from_void

src/sparql_llm/agent/nodes/validation.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@
1010
from sparql_llm.agent.prompts import FIX_QUERY_PROMPT
1111
from sparql_llm.agent.state import State, StepOutput
1212
from sparql_llm.config import Configuration, settings
13-
from sparql_llm.utils import get_prefixes_and_schema_for_endpoints, query_sparql
13+
from sparql_llm.utils import endpoints_metadata, query_sparql
1414
from sparql_llm.validate_sparql import validate_sparql_in_msg
1515

16-
prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(settings.endpoints)
17-
1816

1917
async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any]:
2018
"""LangGraph node to validate the output of a LLM call, e.g. SPARQL queries generated.
@@ -34,7 +32,7 @@ async def validate_output(state: State, config: RunnableConfig) -> dict[str, Any
3432
validation_steps: list[StepOutput] = []
3533
recall_messages: list[HumanMessage] = []
3634

37-
validation_outputs = validate_sparql_in_msg(last_msg, prefixes_map, endpoints_void_dict)
35+
validation_outputs = validate_sparql_in_msg(last_msg, endpoints_metadata.prefixes_map, endpoints_metadata.void_dict)
3836
for validation_output in validation_outputs:
3937
if validation_output["fixed_query"]:
4038
# Pass the fixed msg to the client

src/sparql_llm/config.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,27 @@
66
import os
77
from dataclasses import dataclass, field, fields
88
from pathlib import Path
9-
from typing import Annotated, Any, TypeVar
9+
from typing import Annotated, Any, Required, TypedDict, TypeVar
1010

1111
from fastembed import TextEmbedding
1212
from langchain_core.runnables import RunnableConfig, ensure_config
1313
from pydantic_settings import BaseSettings, SettingsConfigDict
1414
from qdrant_client import QdrantClient
1515

1616
from sparql_llm.agent import prompts
17-
from sparql_llm.utils import SparqlEndpointLinks
17+
18+
19+
# Total=False to make all fields optional except those marked as Required
20+
class SparqlEndpointLinks(TypedDict, total=False):
21+
"""A dictionary to store links and filepaths about a SPARQL endpoint."""
22+
23+
endpoint_url: Required[str]
24+
void_file: str | None
25+
examples_file: str | None
26+
homepage_url: str | None
27+
label: str | None
28+
description: str | None
29+
# ontology_url: Optional[str]
1830

1931

2032
class Settings(BaseSettings):

src/sparql_llm/indexing/index_resources.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from rdflib import RDF, Dataset, Namespace
1111

1212
from sparql_llm import SparqlExamplesLoader, SparqlInfoLoader, SparqlVoidShapesLoader
13-
from sparql_llm.config import embedding_model, qdrant_client, settings
13+
from sparql_llm.config import SparqlEndpointLinks, embedding_model, qdrant_client, settings
1414
from sparql_llm.loaders.sparql_info_loader import GENERAL_INFO_DOC_TYPE
15-
from sparql_llm.utils import SparqlEndpointLinks, get_prefixes_and_schema_for_endpoints
15+
from sparql_llm.utils import endpoints_metadata
1616

1717
SCHEMA = Namespace("http://schema.org/")
1818

@@ -160,9 +160,9 @@ def load_expasy_resources_infos(file: str = "expasy_resources_metadata.csv") ->
160160

161161

162162
def init_vectordb() -> None:
163-
"""Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints"""
163+
"""Initialize the vectordb with example queries and ontology descriptions from the SPARQL endpoints."""
164164
docs: list[Document] = []
165-
prefix_map, _void_schema = get_prefixes_and_schema_for_endpoints(settings.endpoints)
165+
endpoints_metadata._ensure_loaded()
166166

167167
# Gets documents from the SPARQL endpoints
168168
for endpoint in settings.endpoints:
@@ -174,7 +174,7 @@ def init_vectordb() -> None:
174174

175175
docs += SparqlVoidShapesLoader(
176176
endpoint["endpoint_url"],
177-
prefix_map=prefix_map,
177+
prefix_map=endpoints_metadata.prefixes_map,
178178
void_file=endpoint.get("void_file"),
179179
examples_file=endpoint.get("examples_file"),
180180
).load()

src/sparql_llm/loaders/sparql_info_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from langchain_core.document_loaders.base import BaseLoader
22
from langchain_core.documents import Document
33

4-
from sparql_llm.utils import SparqlEndpointLinks, logger
4+
from sparql_llm.config import SparqlEndpointLinks
5+
from sparql_llm.utils import logger
56

67
GENERAL_INFO_DOC_TYPE = "General information"
78

src/sparql_llm/loaders/sparql_void_shapes_loader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ def get_shex_dict_from_void(
3939
shex_dict = {}
4040

4141
for subject_cls, predicates in void_dict.items():
42-
if ignore_namespaces(namespaces_to_ignore, subject_cls):
42+
if ignore_namespaces(namespaces_to_ignore, subject_cls) and subject_cls not in [
43+
"http://www.w3.org/2002/07/owl#Class",
44+
"http://www.w3.org/2000/01/rdf-schema#Class",
45+
]:
4346
continue
4447
try:
4548
subj = prefix_converter.compress(subject_cls, passthrough=True)

src/sparql_llm/mcp_server.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from mcp.server.fastmcp import FastMCP
55
from qdrant_client.models import FieldCondition, Filter, MatchValue, ScoredPoint
66

7-
from sparql_llm.agent.nodes.validation import endpoints_void_dict, prefixes_map
87
from sparql_llm.config import embedding_model, qdrant_client, settings
98
from sparql_llm.indexing.index_resources import init_vectordb
10-
from sparql_llm.utils import logger, query_sparql
9+
from sparql_llm.utils import endpoints_metadata, logger, query_sparql
1110
from sparql_llm.validate_sparql import validate_sparql
1211

1312
# What are the rat orthologs of the human TP53?
@@ -222,7 +221,9 @@ def execute_sparql_query(sparql_query: str, endpoint_url: str) -> str:
222221
"""
223222
resp_msg = ""
224223
# First check if query valid based on classes schema and known prefixes
225-
validation_output = validate_sparql(sparql_query, endpoint_url, prefixes_map, endpoints_void_dict)
224+
validation_output = validate_sparql(
225+
sparql_query, endpoint_url, endpoints_metadata.prefixes_map, endpoints_metadata.void_dict
226+
)
226227
if validation_output["fixed_query"]:
227228
# Pass the fixed query to the client
228229
resp_msg += f"Fixed the prefixes of the generated SPARQL query automatically:\n```sparql\n{validation_output['fixed_query']}\n```\n"
@@ -256,8 +257,6 @@ def execute_sparql_query(sparql_query: str, endpoint_url: str) -> str:
256257
return resp_msg
257258

258259

259-
# prefixes_map, endpoints_void_dict = get_prefixes_and_schema_for_endpoints(settings.endpoints)
260-
261260
FIX_QUERY_PROMPT = """Please fix the query, and try again.
262261
We suggest you to make the query less restricted, e.g. use a broader regex for string matching instead of exact match,
263262
ignore case, make sure you are not overriding an existing variable with BIND, or break down your query in smaller parts

src/sparql_llm/utils.py

Lines changed: 71 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
import json
22
import logging
33
from pathlib import Path
4-
from typing import Any, Required, TypedDict
4+
from typing import Any
55

66
import curies
77
import httpx
88
import rdflib
99

10+
from sparql_llm.config import SparqlEndpointLinks, settings
11+
1012
# Disable logger in your code with logging.getLogger("sparql_llm").setLevel(logging.WARNING)
1113
logger = logging.getLogger("sparql_llm")
1214
logger.setLevel(logging.INFO)
@@ -19,19 +21,6 @@
1921
logging.getLogger("httpx").setLevel(logging.WARNING)
2022

2123

22-
# Total=False to make all fields optional except those marked as Required
23-
class SparqlEndpointLinks(TypedDict, total=False):
24-
"""A dictionary to store links and filepaths about a SPARQL endpoint."""
25-
26-
endpoint_url: Required[str]
27-
void_file: str | None
28-
examples_file: str | None
29-
homepage_url: str | None
30-
label: str | None
31-
description: str | None
32-
# ontology_url: Optional[str]
33-
34-
3524
# Prefixes utilities
3625

3726
GET_PREFIXES_QUERY = """PREFIX sh: <http://www.w3.org/ns/shacl#>
@@ -45,40 +34,6 @@ class SparqlEndpointLinks(TypedDict, total=False):
4534
ENDPOINTS_METADATA_FILE = Path("data") / "endpoints_metadata.json"
4635

4736

48-
def load_endpoints_metadata_file() -> tuple[dict[str, str], "EndpointsSchemaDict"]:
49-
"""Load prefixes and schema from the cached metadata file."""
50-
try:
51-
with open(ENDPOINTS_METADATA_FILE) as f:
52-
data = json.load(f)
53-
logger.info(
54-
f"💾 Loaded endpoints metadata from {ENDPOINTS_METADATA_FILE.resolve()} for {len(data.get('classes_schema', {}))} endpoints"
55-
)
56-
return data.get("prefixes_map", {}), data.get("classes_schema", {})
57-
except Exception as e:
58-
logger.warning(f"Could not load metadata from {ENDPOINTS_METADATA_FILE}: {e}")
59-
return {}, {}
60-
61-
62-
def get_prefixes_and_schema_for_endpoints(
63-
endpoints: list[SparqlEndpointLinks],
64-
) -> tuple[dict[str, str], "EndpointsSchemaDict"]:
65-
"""Return a dictionary of prefixes and a dictionary of VoID classes schema for the given endpoints."""
66-
prefixes_map, endpoints_void_dict = load_endpoints_metadata_file()
67-
if prefixes_map and endpoints_void_dict:
68-
return prefixes_map, endpoints_void_dict
69-
logger.info(f"Fetching metadata for {len(endpoints)} endpoints...")
70-
for endpoint in endpoints:
71-
endpoints_void_dict[endpoint["endpoint_url"]] = get_schema_for_endpoint(
72-
endpoint["endpoint_url"], endpoint.get("void_file")
73-
)
74-
logger.info(f"Fetching {endpoint['endpoint_url']} metadata...")
75-
prefixes_map = get_prefixes_for_endpoint(endpoint["endpoint_url"], endpoint.get("examples_file"), prefixes_map)
76-
# Cache the metadata in a JSON file
77-
with open(ENDPOINTS_METADATA_FILE, "w") as f:
78-
json.dump({"prefixes_map": prefixes_map, "classes_schema": endpoints_void_dict}, f, indent=2)
79-
return prefixes_map, endpoints_void_dict
80-
81-
8237
def get_prefixes_for_endpoint(
8338
endpoint_url: str, examples_file: str | None = None, prefixes_map: dict[str, str] | None = None
8439
) -> dict[str, str]:
@@ -143,33 +98,6 @@ def get_schema_for_endpoint(endpoint_url: str, void_file: str | None = None) ->
14398
Formatted as: dict[subject_cls][predicate] = list[object_cls/datatype]"""
14499
void_dict: SchemaDict = {}
145100
try:
146-
# if void_file:
147-
# g = rdflib.Graph()
148-
# if void_file.startswith(("http://", "https://")):
149-
# # Handle URL case
150-
# with httpx.Client() as client:
151-
# for attempt in range(10):
152-
# # Retry a few times in case of HTTP errors, e.g. https://sparql.uniprot.org/.well-known/void/
153-
# try:
154-
# resp = client.get(void_file, headers={"Accept": "text/turtle"}, follow_redirects=True)
155-
# resp.raise_for_status()
156-
# if resp.text.strip() == "":
157-
# raise ValueError(f"Empty response for VoID description from {void_file}")
158-
# g.parse(data=resp.text, format="turtle")
159-
# break
160-
# except Exception as e:
161-
# if attempt == 3:
162-
# raise e
163-
# time.sleep(1)
164-
# continue
165-
# else:
166-
# # Handle local file case
167-
# g.parse(void_file, format="turtle")
168-
# results = g.query(GET_VOID_DESC)
169-
# bindings = [{str(k): {"value": str(v)} for k, v in row.asdict().items()} for row in results]
170-
# else:
171-
# bindings = query_sparql(GET_VOID_DESC, endpoint_url)["results"]["bindings"]
172-
173101
for void_triple in query_sparql(GET_VOID_DESC, endpoint_url, use_file=void_file, check_service_desc=True)[
174102
"results"
175103
]["bindings"]:
@@ -192,12 +120,7 @@ def get_schema_for_endpoint(endpoint_url: str, void_file: str | None = None) ->
192120
return void_dict
193121

194122

195-
# TODO: use SPARQLWrapper
196-
# sparqlw = SPARQLWrapper(endpoint)
197-
# sparqlw.setReturnFormat(JSON)
198-
# sparqlw.setOnlyConneg(True)
199-
# sparqlw.setQuery(query)
200-
# res = sparqlw.query().convert()
123+
# Use https://github.com/lu-pl/sparqlx ?
201124
def query_sparql(
202125
query: str,
203126
endpoint_url: str,
@@ -267,3 +190,70 @@ def query_sparql(
267190
if should_close:
268191
client.close()
269192
return query_resp
193+
194+
195+
class EndpointsMetadataManager:
196+
"""Lazy-loading manager for endpoints metadata."""
197+
198+
def __init__(self, endpoints: list[SparqlEndpointLinks], auto_init: bool = True) -> None:
199+
self._endpoints = endpoints
200+
self._prefixes_map: dict[str, str] = {}
201+
self._void_dict: EndpointsSchemaDict = {}
202+
self._initialized = False
203+
if auto_init:
204+
self._ensure_loaded()
205+
206+
def _ensure_loaded(self) -> None:
207+
"""Load metadata if not already loaded."""
208+
if self._initialized:
209+
return
210+
# Try loading from file first
211+
try:
212+
with open(ENDPOINTS_METADATA_FILE) as f:
213+
data = json.load(f)
214+
self._prefixes_map = data.get("prefixes_map", {})
215+
self._void_dict = data.get("classes_schema", {})
216+
if self._prefixes_map and self._void_dict:
217+
logger.info(
218+
f"💾 Loaded endpoints metadata from {ENDPOINTS_METADATA_FILE.resolve()} "
219+
f"for {len(self._void_dict)} endpoints"
220+
)
221+
return
222+
except Exception as e:
223+
logger.debug(f"Could not load metadata from {ENDPOINTS_METADATA_FILE}: {e}")
224+
225+
logger.info(f"Fetching metadata for {len(self._endpoints)} endpoints...")
226+
for endpoint in self._endpoints:
227+
self._void_dict[endpoint["endpoint_url"]] = get_schema_for_endpoint(
228+
endpoint["endpoint_url"], endpoint.get("void_file")
229+
)
230+
logger.info(f"Fetching {endpoint['endpoint_url']} metadata...")
231+
self._prefixes_map = get_prefixes_for_endpoint(
232+
endpoint["endpoint_url"], endpoint.get("examples_file"), self._prefixes_map
233+
)
234+
# Cache to JSON file
235+
with open(ENDPOINTS_METADATA_FILE, "w") as f:
236+
json.dump({"prefixes_map": self._prefixes_map, "classes_schema": self._void_dict}, f, indent=2)
237+
self._initialized = True
238+
logger.info(f"💾 Cached endpoints metadata to {ENDPOINTS_METADATA_FILE.resolve()}")
239+
240+
@property
241+
def prefixes_map(self) -> dict[str, str]:
242+
"""Get prefixes map, loading lazily if needed."""
243+
self._ensure_loaded()
244+
return self._prefixes_map or {}
245+
246+
@property
247+
def void_dict(self) -> "EndpointsSchemaDict":
248+
"""Get endpoints VoID schema dict, loading lazily if needed."""
249+
self._ensure_loaded()
250+
return self._void_dict or {}
251+
252+
# def reset(self) -> None:
253+
# """Reset cached metadata (useful for re-initialization after init_vectordb)."""
254+
# self._prefixes_map = {}
255+
# self._void_dict = {}
256+
257+
258+
# Global instance, metadata loads lazily on first property access
259+
endpoints_metadata = EndpointsMetadataManager(settings.endpoints, settings.auto_init)

0 commit comments

Comments (0)