Skip to content

Commit ffdbcbf

Browse files
committed
refactor: make it so the MCP server tools can be configured (so we can enable/disable the resources info tool), fix benchmark schema generation
1 parent d301994 commit ffdbcbf

File tree

11 files changed

+2965
-627
lines changed

11 files changed

+2965
-627
lines changed

compose.override.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ services:
1111
environment:
1212
- DEFAULT_LLM_MODEL=openrouter/openai/gpt-5.2
1313
# - AUTO_INIT=false
14-
# - USE_TOOLS=true
14+
- USE_TOOLS=true
1515
# - FORCE_REINDEX=true
1616
# - DEFAULT_LLM_MODEL=openrouter/openai/gpt-5.2
1717
# - DEFAULT_LLM_MODEL=openrouter/mistralai/mistral-large

compose.text2sparql.yml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ services:
1515
- "1111:1111" # ISQL interface
1616
environment:
1717
DBA_PASSWORD: dba
18-
#8 GB memory
19-
VIRT_Parameters_NumberOfBuffers: 680000
20-
VIRT_Parameters_MaxDirtyBuffers: 500000
18+
VIRT_Parameters_NumberOfBuffers: 2720000
19+
VIRT_Parameters_MaxDirtyBuffers: 2000000
2120
VIRT_Parameters_DirsAllowed: "., /dumps"
2221
volumes:
2322
- ./data/dumps/dbpedia:/dumps
@@ -31,8 +30,8 @@ services:
3130
# - "1111:1111" # ISQL interface
3231
environment:
3332
DBA_PASSWORD: dba
34-
VIRT_Parameters_NumberOfBuffers: 680000
35-
VIRT_Parameters_MaxDirtyBuffers: 500000
33+
VIRT_Parameters_NumberOfBuffers: 2720000
34+
VIRT_Parameters_MaxDirtyBuffers: 2000000
3635
VIRT_Parameters_DirsAllowed: "., /dumps"
3736
volumes:
3837
- ./data/dumps/corporate:/dumps

src/sparql_llm/agent/graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
# How can I get the HGNC symbol for the protein P68871? (modify your answer to use rdfs:label instead of rdfs:comment, and add the type up:Resource to ?hgnc, it is for a test)
2929
# How can I get the HGNC symbol for the protein P68871? (modify your answer to use rdfs:label instead of rdfs:comment, and add the type up:Resource to ?hgnc, and purposefully forget 2 prefixes declarations, it is for a test)
3030
# In bgee how can I retrieve the confidence level and false discovery rate of a gene expression? Use genex:confidence as predicate for the confidence level (do not use the one provided in documents), and do not put prefixes declarations, and add a rdf:type for the main subject. It's for testing
31+
# def route_model_output(
32+
# state: State, config: RunnableConfig
33+
# ) -> Literal["__end__", "call_model", "max_tries_reached", "tools"]:
3134
def route_model_output(state: State, config: RunnableConfig) -> Literal["__end__", "call_model", "max_tries_reached"]:
32-
# ) -> Literal["__end__", "call_model", "max_tries_reached", "tools"]:
3335
"""Determine the next node based on the model's output.
3436
3537
This function checks if the model's last message contains tool calls or if a recall is requested by validation.

src/sparql_llm/agent/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from sparql_llm.agent.graph import graph
2525
from sparql_llm.config import settings
26-
from sparql_llm.mcp_server import mcp
26+
from sparql_llm.mcp_server import get_mcp_app
2727
from sparql_llm.utils import logger
2828

2929
if settings.sentry_url:
@@ -41,6 +41,8 @@
4141
# Initialize Langfuse logs tracing CallbackHandler for Langchain https://langfuse.com/docs/integrations/langchain/example-python-langgraph
4242
langfuse_handler = [CallbackHandler(update_trace=True)] if os.getenv("LANGFUSE_SECRET_KEY") else []
4343

44+
mcp = get_mcp_app()
45+
4446

4547
@contextlib.asynccontextmanager
4648
async def lifespan(app: FastAPI) -> AsyncIterator[None]:

src/sparql_llm/mcp_server.py

Lines changed: 241 additions & 237 deletions
Large diffs are not rendered by default.

tests/text2sparql/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,11 @@ async def get_answer(question: str, dataset: str):
141141
# Validation and fixing of the generated SPARQL query
142142
num_of_tries = 0
143143
resp_msg = "\n\n# Make sure you will not repeat the mistakes below: \n"
144+
generated_sparql = ""
144145
while num_of_tries < settings.default_max_try_fix_sparql:
146+
generated_sparql = ""
145147
try:
146-
generated_sparql = ""
147148
chat_resp_md = response.model_dump()["content"]
148-
149149
generated_sparqls = extract_sparql_queries(chat_resp_md)
150150
generated_sparql = generated_sparqls[-1]["query"].strip()
151151
generated_sparql = generated_sparql.replace(ENDPOINT_URL, DOCKER_ENDPOINT_URL)

tests/text2sparql/data_store.sh

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,39 +2,48 @@
22
# This script loads RDF data files into a Virtuoso database instance.
33
#
44
# Notes:
5-
# - Expects data files in data/benchmarks/Text2SPARQL/dumps/<dataset>.
5+
# - Expects data files in data/dumps/<dataset>.
66
# - Loads files into named graphs at https://text2sparql.aksw.org/2025/<dataset>/.
77

8-
MAX_RETRIES=5
9-
VIRTUOSO_PORT=1111
10-
DBA_USER="dba"
11-
DBA_PASSWORD="dba"
12-
DATA_DIR="$(pwd)/data/benchmarks/Text2SPARQL/dumps"
13-
14-
for dataset in $(ls -1 "$DATA_DIR/"); do
15-
GRAPH_URI="https://text2sparql.aksw.org/2025/$dataset/"
16-
for file_path in "$DATA_DIR/$dataset"/*.{nt,ttl,bz2}; do
17-
[ -e "$file_path" ] || continue # Skip if no files match
18-
file_name=$(basename "$file_path")
19-
20-
retries=0
21-
while [ $retries -lt $MAX_RETRIES ]; do
22-
docker exec text2sparql-virtuoso isql $VIRTUOSO_PORT $DBA_USER $DBA_PASSWORD exec="DB.DBA.TTLP_MT(file_to_string_output('/dumps/$dataset/$file_name'), '', '$GRAPH_URI'); checkpoint;"
23-
if [ $? -eq 0 ]; then
24-
echo "✅ Successfully loaded $file_name into Virtuoso!"
25-
break
26-
else
27-
retries=$((retries + 1))
28-
echo "❌ Error loading $file_name (attempt $retries/$MAX_RETRIES). Retrying..."
29-
sleep 5
30-
fi
31-
done
32-
33-
if [ $retries -eq $MAX_RETRIES ]; then
34-
echo "❌❌ Failed to load $file_name after $MAX_RETRIES attempts."
35-
fi
36-
done
37-
38-
count=$(docker exec text2sparql-virtuoso isql $VIRTUOSO_PORT $DBA_USER $DBA_PASSWORD exec="SPARQL SELECT COUNT(*) WHERE { GRAPH <$GRAPH_URI> {?s ?p ?o} };" 2>&1 | awk '/^_*$/ { in_block=1; next } /1 Rows\./ { in_block=0; next } in_block')
39-
echo "Total triples in $dataset: $count"
40-
done
8+
docker compose exec virtuoso-dbpedia isql -U dba -P dba exec="ld_dir_all('/dumps', '*', ''); rdf_loader_run();"
9+
10+
# Check number of triples (2501 is default virtuoso init)
11+
docker compose exec virtuoso-dbpedia isql -U dba -P dba exec="SPARQL SELECT COUNT(*) WHERE { ?s ?p ?o };"
12+
13+
# Check load status
14+
docker compose exec virtuoso-dbpedia isql -U dba -P dba exec="SELECT ll_file, ll_graph, ll_state, ll_error FROM DB.DBA.LOAD_LIST;"
15+
16+
17+
# MAX_RETRIES=5
18+
# VIRTUOSO_PORT=1111
19+
# DBA_USER="dba"
20+
# DBA_PASSWORD="dba"
21+
# DATA_DIR="$(pwd)/data/dumps"
22+
23+
# for dataset in $(ls -1 "$DATA_DIR/"); do
24+
# # GRAPH_URI="https://text2sparql.aksw.org/2025/$dataset/"
25+
# for file_path in "$DATA_DIR/$dataset"/*.{nt,ttl,bz2}; do
26+
# [ -e "$file_path" ] || continue # Skip if no files match
27+
# file_name=$(basename "$file_path")
28+
29+
# retries=0
30+
# while [ $retries -lt $MAX_RETRIES ]; do
31+
# docker compose exec virtuoso-$dataset isql $VIRTUOSO_PORT $DBA_USER $DBA_PASSWORD exec="DB.DBA.TTLP_MT(file_to_string_output('/dumps/$dataset/$file_name'), '', ''); checkpoint;"
32+
# if [ $? -eq 0 ]; then
33+
# echo "✅ Successfully loaded $file_name into Virtuoso!"
34+
# break
35+
# else
36+
# retries=$((retries + 1))
37+
# echo "❌ Error loading $file_name (attempt $retries/$MAX_RETRIES). Retrying..."
38+
# sleep 5
39+
# fi
40+
# done
41+
42+
# if [ $retries -eq $MAX_RETRIES ]; then
43+
# echo "❌❌ Failed to load $file_name after $MAX_RETRIES attempts."
44+
# fi
45+
# done
46+
47+
# count=$(docker compose exec virtuoso-dbpedia isql $VIRTUOSO_PORT $DBA_USER $DBA_PASSWORD exec="SPARQL SELECT COUNT(*) WHERE { ?s ?p ?o };" 2>&1 | awk '/^_*$/ { in_block=1; next } /1 Rows\./ { in_block=0; next } in_block')
48+
# echo "Total triples in $dataset: $count"
49+
# done

tests/text2sparql/endpoint_schema.py

Lines changed: 25 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,18 @@
1717

1818

1919
class EndpointSchema:
20+
# FROM <{graph}>
2021
_CLASS_PREDICATE_QUERY = """
2122
SELECT ?class ?predicate COUNT(*) AS ?count
22-
FROM <{graph}>
23-
WHERE {{
23+
WHERE {
2424
?s a ?class ;
2525
?predicate ?o .
26-
}}
26+
}
2727
GROUP BY ?class ?predicate
2828
"""
2929

3030
_RANGE_QUERY = """
3131
SELECT ?range
32-
FROM <{graph}>
3332
WHERE {{
3433
?s a <{class_name}> ;
3534
<{predicate_name}> ?o .
@@ -49,7 +48,7 @@ class EndpointSchema:
4948
def __init__(
5049
self,
5150
endpoint_url: str,
52-
graph: str,
51+
# graph: str,
5352
limit_schema: dict[str, float],
5453
max_workers: int,
5554
force_recompute: bool,
@@ -59,7 +58,6 @@ def __init__(
5958
Fetch class and predicate information from the SPARQL endpoint.
6059
Args:
6160
endpoint_url (str): The URL of the SPARQL endpoint to connect to.
62-
graph (str): The graph URI to query within the endpoint.
6361
limit_schema (dict[str, float]): A dictionary specifying query limits.
6462
max_workers (int): The maximum number of worker threads to use for concurrent operations.
6563
Functions:
@@ -68,7 +66,7 @@ def __init__(
6866
"""
6967

7068
self._endpoint_url = endpoint_url
71-
self._graph = graph
69+
# self._graph = graph
7270
self._limit_schema = limit_schema
7371
self._max_workers = max_workers
7472
self._force_recompute = force_recompute
@@ -79,7 +77,7 @@ def _save_schema_dict(self) -> None:
7977
# Fetch counts information
8078
logger.info(f"Fetching class-predicate frequency information from {self._endpoint_url}...")
8179
schema = query_sparql(
82-
self._CLASS_PREDICATE_QUERY.format(graph=self._graph),
80+
self._CLASS_PREDICATE_QUERY,
8381
endpoint_url=self._endpoint_url,
8482
check_service_desc=False,
8583
)["results"]["bindings"]
@@ -136,10 +134,9 @@ def _save_schema_dict(self) -> None:
136134
def _retrieve_predicate_information(self, class_name: str, predicate_name: str) -> list[str]:
137135
"""Fetch ranges for a given predicate of a class"""
138136
try:
139-
range = (
137+
pred_range = (
140138
query_sparql(
141139
self._RANGE_QUERY.format(
142-
graph=self._graph,
143140
class_name=class_name,
144141
predicate_name=predicate_name,
145142
limit=self._limit_schema["top_n_ranges"],
@@ -151,9 +148,9 @@ def _retrieve_predicate_information(self, class_name: str, predicate_name: str)
151148
)
152149

153150
# Filter out unwanted ranges
154-
range = [
151+
pred_range = [
155152
r["range"]["value"]
156-
for r in range
153+
for r in pred_range
157154
if (
158155
("range" in r)
159156
and ("value" in r["range"])
@@ -162,8 +159,8 @@ def _retrieve_predicate_information(self, class_name: str, predicate_name: str)
162159
]
163160
except Exception as e:
164161
logger.warning(f"Error retrieving range for {class_name} - {predicate_name}: {e}")
165-
range = []
166-
return range
162+
pred_range = []
163+
return pred_range
167164

168165
def get_schema(self) -> pd.DataFrame:
169166
"""Load schema information from a JSON file."""
@@ -186,9 +183,7 @@ def get_schema(self) -> pd.DataFrame:
186183
def plot_heatmap(self, apply_limit: bool = True) -> None:
187184
# Fetch counts information
188185
logger.info(f"Fetching counts information from {self._endpoint_url}...")
189-
counts = query_sparql(self._CLASS_PREDICATE_QUERY.format(graph=self._graph), endpoint_url=self._endpoint_url)[
190-
"results"
191-
]["bindings"]
186+
counts = query_sparql(self._CLASS_PREDICATE_QUERY, endpoint_url=self._endpoint_url)["results"]["bindings"]
192187
counts = pd.DataFrame(counts).map(lambda x: x["value"]).assign(count=lambda df: df["count"].astype(int))
193188
counts = counts.sort_values(by="count", ascending=False)
194189

@@ -223,30 +218,29 @@ def plot_heatmap(self, apply_limit: bool = True) -> None:
223218

224219
if __name__ == "__main__":
225220
start_time = time.time()
226-
schema = EndpointSchema(
227-
endpoint_url="http://localhost:8890/sparql/",
228-
graph="https://text2sparql.aksw.org/2025/corporate/",
229-
limit_schema={
230-
"top_classes_percentile": 0,
231-
"top_n_predicates": 20,
232-
"top_n_ranges": 1,
233-
},
234-
max_workers=4,
235-
force_recompute=True,
236-
schema_path=os.path.join("data", "benchmarks", "Text2SPARQL", "schemas", "corporate_schema.json"),
237-
)
221+
# schema = EndpointSchema(
222+
# endpoint_url="http://localhost:8890/sparql/",
223+
# graph="https://text2sparql.aksw.org/2025/corporate/",
224+
# limit_schema={
225+
# "top_classes_percentile": 0,
226+
# "top_n_predicates": 20,
227+
# "top_n_ranges": 1,
228+
# },
229+
# max_workers=4,
230+
# force_recompute=True,
231+
# schema_path=os.path.join("data", "benchmarks", "Text2SPARQL", "schemas", "corporate_schema.json"),
232+
# )
238233

239234
schema = EndpointSchema(
240235
endpoint_url="http://localhost:8890/sparql/",
241-
graph="https://text2sparql.aksw.org/2025/dbpedia/",
242236
limit_schema={
243237
"top_classes_percentile": 0.90,
244238
"top_n_predicates": 20,
245239
"top_n_ranges": 1,
246240
},
247241
max_workers=4,
248242
force_recompute=True,
249-
schema_path=os.path.join("data", "benchmarks", "Text2SPARQL", "schemas", "dbpedia_schema.json"),
243+
schema_path=os.path.join("data", "dbpedia_schema.json"),
250244
)
251245

252246
# Debugging examples

0 commit comments

Comments
 (0)