
Commit ad9df2f

update to bm25std and add escaper for safety
1 parent fae5e7e commit ad9df2f

File tree

1 file changed (+16 −12 lines)

python-recipes/vector-search/02_hybrid_search.ipynb

Lines changed: 16 additions & 12 deletions
@@ -367,7 +367,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -382,6 +382,10 @@
 }
 ],
 "source": [
+"from redisvl.utils.token_escaper import TokenEscaper\n",
+"\n",
+"escaper = TokenEscaper()\n",
+"\n",
 "# list of stopwords to filter out noise from query string\n",
 "stopwords = set([\n",
 " \"a\", \"is\", \"the\", \"an\", \"and\", \"are\", \"as\", \"at\", \"be\", \"but\", \"by\", \"for\",\n",
@@ -391,8 +395,8 @@
 "\n",
 "def tokenize_query(user_query: str) -> str:\n",
 " \"\"\"Convert a raw user query to a redis full text query joined by ORs\"\"\"\n",
-" tokens = [token.strip().strip(\",\").lower() for token in user_query.split()]\n",
-" return \" | \".join([token for token in tokens if token not in stopwords])\n",
+" tokens = [escaper.escape(token.strip().strip(\",\").lower()) for token in user_query.split()]\n",
+" return \" | \".join([token for token in tokens if token and token not in stopwords])\n",
 "\n",
 "# Example\n",
 "tokenize_query(user_query)"
@@ -407,7 +411,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -438,8 +442,8 @@
 " filter_expression=f\"~({Text(text_field) % tokenize_query(user_query)})\",\n",
 " num_results=num_results,\n",
 " return_fields=[\"title\", \"description\"],\n",
-" dialect=4,\n",
-" ).scorer(\"BM25\").with_scores()"
+" dialect=2,\n",
+" ).scorer(\"BM25STD\").with_scores()"
 ]
 },
 {
@@ -540,7 +544,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": null,
 "metadata": {},
 "outputs": [
 {
@@ -581,13 +585,13 @@
 "# Build the aggregation request\n",
 "req = (\n",
 " AggregateRequest(query.query_string())\n",
-" .scorer(\"BM25\")\n",
+" .scorer(\"BM25STD\")\n",
 " .add_scores()\n",
 " .apply(cosine_similarity=\"(2 - @vector_distance)/2\", bm25_score=\"@__score\")\n",
 " .apply(hybrid_score=f\"0.3*@bm25_score + 0.7*@cosine_similarity\")\n",
 " .load(\"title\", \"description\", \"cosine_similarity\", \"bm25_score\", \"hybrid_score\")\n",
 " .sort_by(Desc(\"@hybrid_score\"), max=3)\n",
-" .dialect(4)\n",
+" .dialect(2)\n",
 ")\n",
 "\n",
 "# Run the query\n",
@@ -620,7 +624,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -634,13 +638,13 @@
 " # Build aggregation\n",
 " req = (\n",
 " AggregateRequest(query.query_string())\n",
-" .scorer(\"BM25\")\n",
+" .scorer(\"BM25STD\")\n",
 " .add_scores()\n",
 " .apply(cosine_similarity=\"(2 - @vector_distance)/2\", bm25_score=\"@__score\")\n",
 " .apply(hybrid_score=f\"{1-alpha}*@bm25_score + {alpha}*@cosine_similarity\")\n",
 " .sort_by(Desc(\"@hybrid_score\"), max=num_results)\n",
 " .load(\"title\", \"description\", \"cosine_similarity\", \"bm25_score\", \"hybrid_score\")\n",
-" .dialect(4)\n",
+" .dialect(2)\n",
 " )\n",
 "\n",
 " # Run the query\n",
