diff --git a/dataset/cache.csv b/dataset/cache.csv deleted file mode 100644 index ab54afd..0000000 --- a/dataset/cache.csv +++ /dev/null @@ -1,8 +0,0 @@ -id,text -1,"reset your password" -2,"our store hours" -3,"Halloween movies" -4,"Horror movies" -5,"Scary movies" -6,"mama ms rachel" -7,"ms rachel" diff --git a/dataset/chatgpt.csv b/dataset/chatgpt.csv deleted file mode 100644 index b16d46a..0000000 --- a/dataset/chatgpt.csv +++ /dev/null @@ -1,717 +0,0 @@ -sentence1,sentence2,label -"price of iphone 16 in the US (latest)","current US price for iPhone 16 today",0 -"price of iphone 16 in the US","iphone 16 cost in the United States",1 -"python sort list ascending","how to sort a list in python ascending",1 -"python sort list ascending","python reverse sort list",0 -"Capital of Canada","What is the capital city of Canada?",1 -"Capital of Canada","Population of Canada",0 -"download TensorFlow 2.12 wheels","download TensorFlow 2.13 wheels",0 -"GDP of Japan in 2020","Japan GDP for year 2020",1 -"GDP of Japan in 2020","Japan GDP latest",0 -"translate to Spanish: 'Good morning'","traducir a español: 'Good morning'",1 -"translate to Spanish: 'Good morning'","translate to French: 'Good morning'",0 -"site:arxiv.org Transformers review","Transformers review site:arxiv.org",1 -"site:arxiv.org Transformers review","Transformers review site:openreview.net",0 -"weather in Berlin today","today's weather in Berlin",1 -"weather in Berlin today","weather in Berlin next week",0 -"convert 10 USD to EUR","10 dollars in euros",1 -"convert 10 USD to EUR","convert 10 EUR to USD",0 -"top 5 coffee shops in Seattle","best five coffee shops in Seattle",1 -"top 5 coffee shops in Seattle","top 10 coffee shops in Seattle",0 -"Create JSON with keys title,author","Return JSON fields: title and author",1 -"Create JSON with keys title,author","Return YAML fields: title and author",0 -"apple stock price as of 2024-12-31","AAPL price on 2024-12-31",1 -"apple stock price as of 2024-12-31","AAPL price today",0 -"book a hotel in Paris for two adults","reserve a hotel room in Paris for 2 adults",1 -"book a hotel in Paris for two adults","find tourist attractions in Paris",0 -"nearest gas station","closest petrol station",1 -"nearest gas station","cheapest gas station",0 -"SQL: count rows in table orders","SQL: number of rows in orders table",1 -"SQL: count rows in table orders","SQL: list all rows in orders",0 -"NBA finals winners list","list of NBA champions",1 -"NBA finals winners list","list of NFL champions",0 -"open source license comparison MIT vs Apache","compare MIT with Apache license",1 -"open source license comparison MIT vs Apache","MIT license full text",0 -"translate this to German (informal)","translate this to German using du-form",1 -"translate this to German (informal)","translate this to German (formal)",0 -"elevation of Mount Everest","height of Everest",1 -"elevation of Mount Everest","height of K2",0 -"site:wikipedia.org pandas merge examples","pandas merge examples site:wikipedia.org",1 -"site:wikipedia.org pandas merge examples","pandas merge examples site:docs.python.org",0 -"download Ubuntu 22.04 ISO","Ubuntu 22.04 ISO download",1 -"download Ubuntu 22.04 ISO","download Ubuntu 24.04 ISO",0 -"restaurants open now near me","restaurants currently open nearby",1 -"restaurants open now near me","best restaurants near me",0 -"define photosynthesis in one sentence","give a one-sentence definition of photosynthesis",1 -"define photosynthesis in one sentence","explain photosynthesis in detail",0 -"convert 5km to 
miles","5 kilometers in miles",1 -"convert 5km to miles","5 miles in kilometers",0 -"top 3 results for 'machine learning basics'","return three results for machine learning basics",1 -"top 3 results for 'machine learning basics'","return ten results for machine learning basics",0 -"show me New York time now","current time in New York",1 -"show me New York time now","current time in Los Angeles",0 -"What is Docker? (beginner)","Explain Docker for a beginner",1 -"What is Docker? (beginner)","Explain Docker for an expert",0 -"sum of 2 and 3","compute 2+3",1 -"sum of 2 and 3","product of 2 and 3",0 -"Flights from London to Tokyo in June","June flights London→Tokyo",1 -"Flights from London to Tokyo in June","Flights from London to Tokyo in December",0 -"Sort array ascending in JavaScript","JS: sort an array ascending",1 -"Sort array ascending in JavaScript","JS: sort an array descending",0 -"company address of OpenAI","OpenAI company address",1 -"company address of OpenAI","OpenAI CEO name",0 -"Is quinoa gluten-free?","Does quinoa contain gluten?",1 -"Is quinoa gluten-free?","Calories in quinoa",0 -"Define 'entropy' (short)","Short definition of entropy",1 -"Define 'entropy' (short)","Detailed explanation of entropy with math",0 -"news about Tesla this week","latest Tesla news",1 -"news about Tesla this week","Tesla news in 2018",0 -"site:github.com 'vector search'","'vector search' on github.com",1 -"site:github.com 'vector search'","'vector search' on stackoverflow.com",0 -"download Node.js LTS","get Node LTS download",1 -"download Node.js LTS","download Node.js Current",0 -"restaurants in Paris under €30","Paris restaurants with price < 30 EUR",1 -"restaurants in Paris under €30","Paris restaurants under $30",0 -"what is HTTP status 404","explain HTTP 404",1 -"what is HTTP status 404","explain HTTP 500",0 -"nearest ATM","closest cash machine",1 -"nearest ATM","nearest bank branch",0 -"translate to Japanese: 'thank you'","'thank you' into Japanese",1 -"translate to Japanese: 'thank you'","'thank you' into Chinese",0 -"images of the Eiffel Tower","Eiffel Tower photos",1 -"images of the Eiffel Tower","Eiffel Tower ticket prices",0 -"population of Berlin 2020","Berlin population in 2020",1 -"population of Berlin 2020","Berlin population now",0 -"Open in CSV: top 10 rows of table sales","Return CSV with top 10 rows from sales",1 -"Open in CSV: top 10 rows of table sales","Return JSON with top 10 rows from sales",0 -"find PDF user manual for ThinkPad X1","download ThinkPad X1 user manual PDF",1 -"find PDF user manual for ThinkPad X1","find driver updates for ThinkPad X1",0 -"weather in Sydney tomorrow","Sydney forecast for tomorrow",1 -"weather in Sydney tomorrow","Sydney weather today",0 -"site:docs.python.org asyncio tutorial","asyncio tutorial site:docs.python.org",1 -"site:docs.python.org asyncio tutorial","asyncio tutorial site:realpython.com",0 -"top 5 sci-fi books of 2023","best five science fiction books 2023",1 -"top 5 sci-fi books of 2023","best five science fiction books 2024",0 -"open-source vector DBs list","list of open source vector databases",1 -"open-source vector DBs list","compare managed vector DB services",0 -"show me JPEG to PNG converter","JPEG→PNG converter",1 -"show me JPEG to PNG converter","JPEG→WebP converter",0 -"currency exchange rate USD to JPY now","current USD/JPY rate",1 -"currency exchange rate USD to JPY now","USD/JPY rate yesterday",0 -"create a table with two columns name,age","make a 2-column table: name and age",1 -"create a table with two columns 
name,age","make a table with columns name,age,city",0 -"SQL: list orders where status='shipped'","SQL: select orders with status shipped",1 -"SQL: list orders where status='shipped'","SQL: count orders where status='shipped'",0 -"define 'variance' in statistics","statistics: definition of variance",1 -"define 'variance' in statistics","compute sample variance of [1,2,3]",0 -"translate to English: '¿Dónde está la biblioteca?'","Translate '¿Dónde está la biblioteca?' to English",1 -"translate to English: '¿Dónde está la biblioteca?'","Translate '¿Dónde está la biblioteca?' to French",0 -"movie times for 'Oppenheimer' in Boston","Boston showtimes for Oppenheimer",1 -"movie times for 'Oppenheimer' in Boston","movie times for 'Barbie' in Boston",0 -"NYC subway map","New York City subway map",1 -"NYC subway map","NYC bus map",0 -"latest stable Kubernetes version","current stable Kubernetes release",1 -"latest stable Kubernetes version","Kubernetes version in 2021",0 -"Chocolate chip cookie recipe (metric)","cookie recipe in metric units",1 -"Chocolate chip cookie recipe (metric)","cookie recipe in US cups",0 -"find me 'site:gov.uk passport renewal'","passport renewal site:gov.uk",1 -"find me 'site:gov.uk passport renewal'","passport renewal site:usa.gov",0 -"Linux check disk space command","how to check disk usage on Linux",1 -"Linux check disk space command","how to check memory usage on Linux",0 -"Who is the president of France?","Name the French president",1 -"Who is the president of France?","Name the prime minister of France",0 -"Top universities in Germany 2024 ranking","best German universities ranking 2024",1 -"Top universities in Germany 2024 ranking","best German universities ranking 2022",0 -"Restaurants in Tokyo that are open late","late-night restaurants in Tokyo",1 -"Restaurants in Tokyo that are open late","breakfast restaurants in Tokyo",0 -"JSON: return fields id and url only","Return only id and url as JSON",1 -"JSON: return fields id and url only","Return id, url, and title as JSON",0 -"Compute median of list [1,3,5,7]","Find the median for [1,3,5,7]",1 -"Compute median of list [1,3,5,7]","Find the mean for [1,3,5,7]",0 -"Give me synonyms of 'happy'","List synonyms for happy",1 -"Give me synonyms of 'happy'","Give antonyms of happy",0 -"site:bbc.com climate change article","climate change article site:bbc.com",1 -"site:bbc.com climate change article","climate change article site:cnn.com",0 -"'machine learning' exact phrase search","search exact phrase 'machine learning'",1 -"'machine learning' exact phrase search","search machine learning without quotes",0 -"translate this to Portuguese (pt-BR)","translate to Brazilian Portuguese",1 -"translate this to Portuguese (pt-BR)","translate to European Portuguese (pt-PT)",0 -"time in Tokyo when it's noon in London","convert time London noon to Tokyo",1 -"time in Tokyo when it's noon in London","time in Tokyo now",0 -"find CSV download for unemployment data 2023","get 2023 unemployment data as CSV",1 -"find CSV download for unemployment data 2023","get 2023 unemployment data as JSON",0 -"Explain PageRank in 3 bullets","3 bullet explanation of PageRank",1 -"Explain PageRank in 3 bullets","Explain PageRank with math derivation",0 -"nearest pharmacy open 24 hours","24/7 pharmacy near me",1 -"nearest pharmacy open 24 hours","pharmacy near me with COVID vaccines",0 -"post office hours today in Chicago","Chicago post office hours today",1 -"post office hours today in Chicago","Chicago post office hours tomorrow",0 -"find 'Deep Learning' book 
PDF (legal)","locate a legal PDF of 'Deep Learning'",1 -"find 'Deep Learning' book PDF (legal)","find torrents for 'Deep Learning' book",0 -"temperature in Celsius for New York now","current temperature in New York (°C)",1 -"temperature in Celsius for New York now","current temperature in New York (°F)",0 -"KNeighborsClassifier sklearn example","example for sklearn KNeighborsClassifier",1 -"KNeighborsClassifier sklearn example","example for sklearn RandomForestClassifier",0 -"Top 3 news about AI today","three AI headlines today",1 -"Top 3 news about AI today","AI headlines this month",0 -"Recipe for pesto (nut-free)","nut-free pesto recipe",1 -"Recipe for pesto (nut-free)","pesto recipe with pine nuts",0 -"exchange rate GBP to EUR at market close","GBP/EUR at yesterday's close",0 -"exchange rate GBP to EUR at market close","GBP to EUR market close rate",1 -"How to reset a Windows 11 password","steps to reset Windows 11 password",1 -"How to reset a Windows 11 password","How to reset a Mac password",0 -"top 10 universities worldwide QS 2025","QS World University Rankings 2025 top 10",1 -"top 10 universities worldwide QS 2025","QS World University Rankings 2023 top 10",0 -"news about OpenAI CEO today","OpenAI CEO news today",1 -"news about OpenAI CEO today","OpenAI CEO biography",0 -"convert Markdown to HTML","turn Markdown into HTML",1 -"convert Markdown to HTML","convert HTML to Markdown",0 -"clinical trial NCT number lookup","find study by NCT ID",1 -"clinical trial NCT number lookup","find study by EudraCT number",0 -"return top 5 results sorted by date","give five newest results",1 -"return top 5 results sorted by date","give five most relevant results",0 -"define 'precision' vs 'recall' briefly","briefly define precision and recall",1 -"define 'precision' vs 'recall' briefly","compute precision/recall for given data",0 -"restaurant reservations for tonight at 7pm","book a table tonight at 7 pm",1 -"restaurant reservations for tonight at 7pm","find restaurants open tonight",0 -"British English spelling for 'color'","Write 'colour' in British English",1 -"British English spelling for 'color'","Write 'color' in American English",0 -"download latest NVIDIA driver for Windows","get newest NVIDIA Windows driver",1 -"download latest NVIDIA driver for Windows","get newest AMD Windows driver",0 -"find 'site:europa.eu GDPR guidelines pdf'","GDPR guidelines pdf site:europa.eu",1 -"find 'site:europa.eu GDPR guidelines pdf'","GDPR guidelines html site:europa.eu",0 -"SQL: SELECT id,name FROM users LIMIT 10","SQL: return id and name for first 10 users",1 -"SQL: SELECT id,name FROM users LIMIT 10","SQL: SELECT id,name,email FROM users LIMIT 10",0 -"calculate mortgage payment for $500k at 5%","monthly payment on $500k mortgage at 5 percent",1 -"calculate mortgage payment for $500k at 5%","monthly payment on $600k mortgage at 5 percent",0 -"show me code to read a CSV in Python","Python code to read CSV",1 -"show me code to read a CSV in Python","R code to read CSV",0 -"who won the FIFA World Cup 2018","FIFA 2018 champion",1 -"who won the FIFA World Cup 2018","FIFA 2022 champion",0 -"top JavaScript frameworks 2024","best JS frameworks of 2024",1 -"top JavaScript frameworks 2024","best JS frameworks of 2021",0 -"nearest emergency room","closest ER",1 -"nearest emergency room","closest urgent care",0 -"translate to Chinese (Simplified): 'library'","translate 'library' into zh-CN",1 -"translate to Chinese (Simplified): 'library'","translate 'library' into zh-TW",0 -"Return plaintext only, no links: explain 
TLS","Explain TLS in plaintext with no links",1 -"Return plaintext only, no links: explain TLS","Explain TLS and include links",0 -"site:reddit.com best budget laptops","best budget laptops site:reddit.com",1 -"site:reddit.com best budget laptops","best budget laptops site:theverge.com",0 -"current EUR inflation rate (y/y)","year-over-year inflation in the Eurozone now",1 -"current EUR inflation rate (y/y)","Eurozone inflation rate in 2019",0 -"OpenAPI spec example (YAML)","example OpenAPI YAML",1 -"OpenAPI spec example (YAML)","example OpenAPI JSON",0 -"show top 3 restaurants with rating ≥4.5","list three restaurants rated at least 4.5",1 -"show top 3 restaurants with rating ≥4.5","list five restaurants rated at least 4.5",0 -"What is the derivative of sin(x)?","derivative of sine x",1 -"What is the derivative of sin(x)?","integral of sin(x)",0 -"Find contact email for Acme Corp","Acme Corporation contact email",1 -"Find contact email for Acme Corp","Acme Corporation phone number",0 -"Give two pros and two cons of remote work","list 2 pros and 2 cons of remote work",1 -"Give two pros and two cons of remote work","list three pros and three cons of remote work",0 -"Return CSV with columns name,price,currency","CSV output: name, price, currency",1 -"Return CSV with columns name,price,currency","JSON output: name, price, currency",0 -"convert 100°F to Celsius","100 Fahrenheit in Celsius",1 -"convert 100°F to Celsius","100 Celsius in Fahrenheit",0 -"best hiking trails near Denver","top hiking trails near Denver",1 -"best hiking trails near Denver","best camping sites near Denver",0 -"Explain 'hash map' to a 10-year-old","ELI5 explanation of hash maps",1 -"Explain 'hash map' to a 10-year-old","Formal academic definition of hash maps",0 -"site:imperial.ac.uk machine learning syllabus","machine learning syllabus site:imperial.ac.uk",1 -"site:imperial.ac.uk machine learning syllabus","machine learning syllabus site:ox.ac.uk",0 -"latest NVIDIA earnings report","most recent NVIDIA earnings",1 -"latest NVIDIA earnings report","NVIDIA earnings in 2020",0 -"nearest EV charging station","closest EV charger",1 -"nearest EV charging station","cheapest EV charging station",0 -"Who wrote 'Pride and Prejudice'?","Author of 'Pride and Prejudice'",1 -"Who wrote 'Pride and Prejudice'?","Main characters of 'Pride and Prejudice'",0 -"give me three citations about GPT-4","three references about GPT-4",1 -"give me three citations about GPT-4","ten references about GPT-4",0 -"weather in Sofia this weekend","Sofia weekend weather forecast",1 -"weather in Sofia this weekend","Sofia weather next month",0 -"find PDF 'AWS Well-Architected Framework'","download 'AWS Well-Architected Framework' PDF",1 -"find PDF 'AWS Well-Architected Framework'","download 'Azure Well-Architected Framework' PDF",0 -"Return only titles of top 5 results","Only return titles for five results",1 -"Return only titles of top 5 results","Return titles and URLs for five results",0 -"Is 'color' spelled 'colour' in the UK?","British spelling of 'color' is 'colour'?",1 -"Is 'color' spelled 'colour' in the UK?","American spelling of 'colour' is 'color'?",0 -"translate to Italian and keep punctuation","Italian translation preserving punctuation",1 -"translate to Italian and keep punctuation","Italian translation without punctuation",0 -"Stack Overflow question on Python f-strings","Python f-strings question on Stack Overflow",1 -"Stack Overflow question on Python f-strings","Reddit discussion on Python f-strings",0 -"top 3 papers on retrieval-augmented 
generation","three best RAG papers",1 -"top 3 papers on retrieval-augmented generation","five best RAG papers",0 -"return headlines only, no summaries","give only the headlines",1 -"return headlines only, no summaries","give headlines with summaries",0 -"nearest public library","closest public library",1 -"nearest public library","public library opening hours",0 -"Explain RSA in two sentences","two-sentence explanation of RSA",1 -"Explain RSA in two sentences","detailed explanation of RSA with equations",0 -"site:who.int malaria factsheet pdf","malaria factsheet pdf site:who.int",1 -"python memory leak in requests","memory leak in python requests",1 -"error code 504 on upload","why do I get 504 when uploading",1 -"tag:postgres in:title index strategy","in:title index strategy tag:postgres",1 -"author:alice docker compose tips","docker compose tips by author:alice",1 -"best IDE for Rust beginners","good starter IDE for Rust",1 -"site:forum.example.com graphql timeout","graphql timeout site:forum.example.com",1 -"\"connection reset by peer\" on deploy","deploy fails with \"connection reset by peer\"",1 -"tag:kubernetes sort:recent ingress tls","ingress tls tag:kubernetes sort:recent",1 -"in:title dark mode css","css dark mode in:title",1 -"how to mute a thread","mute a discussion thread",1 -"compare SSD vs HDD longevity","SSD and HDD longevity comparison",1 -"after:2025-01-01 tag:release-notes","tag:release-notes after:2025-01-01",1 -"markdown table alignment help","how to align tables in markdown",1 -"android studio emulator slow on M1","emulator is slow on M1 in android studio",1 -"windows 11 clipboard history not working","clipboard history broken on windows 11",1 -"gpu passthrough on proxmox guide","proxmox gpu passthrough how-to",1 -"nx monorepo caching tips","tips for caching in nx monorepo",1 -"vim search and replace across files","find and replace across files in vim",1 -"latex figure not centered","center a figure in latex",1 -"ssh key permission denied troubleshooting","permission denied (publickey) fix",1 -"top 5 posts about log aggregation","show five best posts on log aggregation",1 -"is:solved selenium click intercepted","selenium element click intercepted is:solved",1 -"how to revoke jwt tokens","revoking JWT tokens best practices",1 -"rebase vs merge explanation","explain rebase versus merge",1 -"npm audit false positives","false positives in npm audit",1 -"how to backup mariaDB daily","daily backup for MariaDB",1 -"ffmpeg trim without re-encoding","trim video without reencoding ffmpeg",1 -"c# nullable reference types overview","overview of nullable reference types in c#",1 -"jira automation send slack message","send slack from jira automation",1 -"git ignore nested node_modules","ignore nested node_modules in git",1 -"python memory leak in requests","python memory leak in aiohttp",0 -"error code 504 on upload","error code 502 on upload",0 -"tag:postgres in:title index strategy","tag:mysql in:title index strategy",0 -"author:alice docker compose tips","author:bob docker compose tips",0 -"best IDE for Rust beginners","best IDE for Go beginners",0 -"site:forum.example.com graphql timeout","graphql timeout site:docs.example.com",0 -"\"connection reset by peer\" on deploy","\"broken pipe\" on deploy",0 -"tag:kubernetes sort:recent ingress tls","tag:kubernetes sort:top ingress tls",0 -"in:title dark mode css","in:body dark mode css",0 -"how to mute a thread","unmute a discussion thread",0 -"compare SSD vs HDD longevity","compare SSD vs NVMe longevity",0 -"after:2025-01-01 
tag:release-notes","before:2025-01-01 tag:release-notes",0 -"markdown table alignment help","markdown table alignment help (Spanish)",0 -"android studio emulator slow on M1","android studio emulator slow on Windows",0 -"windows 11 clipboard history not working","windows 10 clipboard history not working",0 -"gpu passthrough on proxmox guide","gpu passthrough on vmware guide",0 -"nx monorepo caching tips","bazel monorepo caching tips",0 -"vim search and replace across files","emacs search and replace across files",0 -"latex figure not centered","word figure not centered",0 -"ssh key permission denied troubleshooting","ssh key permission denied on GitLab CI",0 -"top 5 posts about log aggregation","show top 20 posts on log aggregation",0 -"is:solved selenium click intercepted","selenium click intercepted is:unsolved",0 -"how to revoke jwt tokens","how to refresh jwt tokens",0 -"rebase vs merge explanation","squash vs merge explanation",0 -"npm audit false positives","yarn audit false positives",0 -"how to backup mariaDB daily","how to backup PostgreSQL daily",0 -"ffmpeg trim without re-encoding","ffmpeg resize without re-encoding",0 -"c# nullable reference types overview","f# nullable reference types overview",0 -"jira automation send slack message","jira automation send email",0 -"git ignore nested node_modules","git ignore .venv folders",0 -"find posts about home lab cooling","home lab cooling tips",1 -"setup wireguard on ubuntu 24.04","wireguard setup on ubuntu 24.04",1 -"tag:career how to negotiate salary","salary negotiation tag:career",1 -"delete local git branch safely","safely delete a local git branch",1 -"docker prune unused images command","remove unused docker images command",1 -"typescript narrow union by in operator","narrow unions with 'in' operator typescript",1 -"postgres vacuum settings explained","explain autovacuum settings in postgres",1 -"kafka consumer lag monitoring","monitoring kafka consumer lag",1 -"apple silicon brew path issues","brew path problems on apple silicon",1 -"nuxt vs next for content sites","next vs nuxt for content websites",1 -"find posts about home lab cooling","find posts about home **heating**",0 -"setup wireguard on ubuntu 24.04","setup wireguard on debian 12",0 -"tag:career how to negotiate salary","tag:career how to negotiate offers **in Germany**",0 -"delete local git branch safely","delete **remote** git branch safely",0 -"docker prune unused images command","docker **system** prune all command",0 -"typescript narrow union by in operator","typescript narrow union by **typeof**",0 -"postgres vacuum settings explained","postgres **wal** settings explained",0 -"kafka consumer lag monitoring","kafka **partition** rebalancing",0 -"apple silicon brew path issues","intel mac brew path issues",0 -"nuxt vs next for content sites","nuxt vs next for **ecommerce** sites",0 -"quote: \"strict mode\" causes duplicate renders react 18","react 18 duplicate renders with \"strict mode\"",1 -"site:forum.example.com tag:security password hashing","password hashing tag:security site:forum.example.com",1 -"after:2025-06-01 sort:recent tag:release","tag:release sort:recent after:2025-06-01",1 -"in:title migrate from circleci to github actions","migrate from circleci to github actions in:title",1 -"terraform state file locking s3","s3 terraform state locking",1 -"monitor linux disk io with iostat","iostat monitor disk io on linux",1 -"convert wav to mp3 with ffmpeg","ffmpeg convert wav to mp3",1 -"set up cron weekly on ubuntu","ubuntu cron weekly setup",1 -"python list 
comprehension if else example","if else in python list comprehension example",1 -"regex for email validation examples","examples of email validation regex",1 -"quote: \"strict mode\" causes duplicate renders react 18","react 18 strict mode performance",0 -"site:forum.example.com tag:security password hashing","site:forum.example.com tag:security **oauth**",0 -"after:2025-06-01 sort:recent tag:release","after:2024-06-01 sort:recent tag:release",0 -"in:title migrate from circleci to github actions","in:title migrate from **travis** to github actions",0 -"terraform state file locking s3","terraform state file locking **gcs**",0 -"monitor linux disk io with iostat","monitor linux disk io with **sar**",0 -"convert wav to mp3 with ffmpeg","convert wav to **flac** with ffmpeg",0 -"set up cron weekly on ubuntu","set up cron **daily** on ubuntu",0 -"python list comprehension if else example","python list comprehension **for if** example",0 -"regex for email validation examples","regex for **URL** validation examples",0 -"how to export from notion to markdown","export notion pages to markdown",1 -"install tailscale on synology nas","synology install tailscale",1 -"capture network packets with tcpdump basics","tcpdump basics to capture packets",1 -"explain CAP theorem like I'm five","ELI5 CAP theorem explanation",1 -"migrate from yarn to pnpm steps","steps to migrate from yarn to pnpm",1 -"zsh autocomplete not working fix","fix zsh autocomplete not working",1 -"sqlite foreign key constraints on by default","enable sqlite foreign keys by default",1 -"find python jobs board threads","threads about python job boards",1 -"eslint config for monorepo example","example eslint config for monorepo",1 -"k8s liveness vs readiness probe difference","difference between liveness and readiness probes k8s",1 -"how to export from notion to markdown","how to export from **confluence** to markdown",0 -"install tailscale on synology nas","install tailscale on **qnap** nas",0 -"capture network packets with tcpdump basics","capture network packets with **wireshark** basics",0 -"explain CAP theorem like I'm five","explain **PACELC** theorem like I'm five",0 -"migrate from yarn to pnpm steps","migrate from yarn to **npm** steps",0 -"zsh autocomplete not working fix","bash autocomplete not working fix",0 -"sqlite foreign key constraints on by default","postgres foreign key constraints on by default",0 -"find python jobs board threads","find **java** jobs board threads",0 -"eslint config for monorepo example","tsconfig for monorepo example",0 -"k8s liveness vs readiness probe difference","k8s **startup** vs readiness probe difference",0 -"show me posts about remote work burnout","remote work burnout discussions",1 -"macOS firewall block specific app","block specific app with macOS firewall",1 -"how to rotate api keys safely","rotate API keys safely",1 -"pricing discussion for cloud storage","cloud storage pricing discussion",1 -"best practices for code reviews","good code review practices",1 -"where to learn wasm basics","resources to learn wasm basics",1 -"python dataclass default factory examples","examples of dataclass default_factory",1 -"what is blue/green deployment","blue green deployment explained",1 -"how to paginate REST API responses","paginate responses in REST API",1 -"start a homelab cheaply","cheap homelab getting started",1 -"show me posts about remote work burnout","show me posts about remote **hiring**",0 -"macOS firewall block specific app","windows firewall block specific app",0 -"how to rotate api keys 
safely","how to **store** api keys safely",0 -"pricing discussion for cloud storage","pricing discussion for **object** storage with egress costs",0 -"best practices for code reviews","best practices for **pair programming**",0 -"where to learn wasm basics","where to learn **webgpu** basics",0 -"python dataclass default factory examples","python dataclass **slots** examples",0 -"what is blue/green deployment","what is **canary** deployment",0 -"how to paginate REST API responses","how to **filter** REST API responses",0 -"start a homelab cheaply","start a homelab **for AI** cheaply",0 -"in:title \"rate limit\" github api","github api rate limit in:title",1 -"author:carol rust lifetimes tutorial","rust lifetimes tutorial by author:carol",1 -"tag:design-system button accessibility","button accessibility tag:design-system",1 -"search exact phrase \"null reference\"","\"null reference\" exact phrase search",1 -"how to escape quotes in json","escaping quotes in json",1 -"set up private npm registry","private npm registry setup",1 -"choose a linux filesystem for ssd","best linux filesystem for ssd",1 -"websocket reconnect backoff strategy","backoff strategy for websocket reconnect",1 -"compile nginx with brotli module","nginx compile with brotli",1 -"graphql vs rest caching strategies","caching strategies for graphql vs rest",1 -"in:title \"rate limit\" github api","in:title \"rate limit\" **twitter** api",0 -"author:carol rust lifetimes tutorial","author:carol rust **traits** tutorial",0 -"tag:design-system button accessibility","tag:design-system **input** accessibility",0 -"search exact phrase \"null reference\"","search exact phrase \"null pointer\"",0 -"how to escape quotes in json","how to escape quotes in **yaml**",0 -"set up private npm registry","set up **docker** private registry",0 -"choose a linux filesystem for ssd","choose a linux filesystem for **nas**",0 -"websocket reconnect backoff strategy","http retry backoff strategy",0 -"compile nginx with brotli module","compile **apache** with brotli module",0 -"graphql vs rest caching strategies","graphql vs **grpc** caching strategies",0 -"find posts with code blocks about regex","posts about regex with code blocks has:code",1 -"search guides for migrating to ipv6","guides for migrating to ipv6",1 -"tips for reducing docker image size","reduce docker image size tips",1 -"how to debug 100% cpu in nodejs","debug high cpu in nodejs",1 -"ios push notifications not delivered","apns not delivering push notifications",1 -"split monolith into microservices pitfalls","pitfalls splitting monolith into microservices",1 -"kotlin coroutines vs threads explanation","explain kotlin coroutines vs threads",1 -"cicd for python with github actions","github actions cicd for python",1 -"raspberry pi headless setup wifi","headless raspberry pi wifi setup",1 -"sqlite full text search tutorial","tutorial for sqlite fts",1 -"find posts with code blocks about regex","posts about regex with **images** has:image",0 -"search guides for migrating to ipv6","search guides for migrating to **ipv4**",0 -"tips for reducing docker image size","tips for reducing **vm** image size",0 -"how to debug 100% cpu in nodejs","how to debug 100% **memory** in nodejs",0 -"ios push notifications not delivered","**firebase** push notifications not delivered",0 -"split monolith into microservices pitfalls","split microservices into **monolith** pitfalls",0 -"kotlin coroutines vs threads explanation","kotlin **flows** vs threads explanation",0 -"cicd for python with github actions","cicd 
for python with **gitlab**",0 -"raspberry pi headless setup wifi","raspberry pi headless setup **ethernet only**",0 -"sqlite full text search tutorial","postgres full text search tutorial",0 -"sort:recent tag:announcements outage report","outage report tag:announcements sort:recent",1 -"compare password managers thread","thread comparing password managers",1 -"windows terminal transparent background","make windows terminal background transparent",1 -"helm chart testing best practices","best practices for testing helm charts",1 -"vue 3 composition api intro","introduction to vue 3 composition api",1 -"python virtualenv vs conda which to use","which to use: python virtualenv or conda",1 -"learn git submodules the hard way","git submodules guide",1 -"redis pub/sub vs streams use cases","use cases for redis pub/sub vs streams",1 -"serverless cold start mitigation","mitigating serverless cold starts",1 -"find threads on burnout prevention","threads on preventing burnout",1 -"sort:recent tag:announcements outage report","sort:top tag:announcements outage report",0 -"compare password managers thread","compare **vpn providers** thread",0 -"windows terminal transparent background","windows terminal **blur** background",0 -"helm chart testing best practices","helm chart **signing** best practices",0 -"vue 3 composition api intro","react hooks intro",0 -"python virtualenv vs conda which to use","python virtualenv **inside docker** which to use",0 -"learn git submodules the hard way","learn git **subtrees** the hard way",0 -"redis pub/sub vs streams use cases","kafka pub/sub vs streams use cases",0 -"serverless cold start mitigation","serverless cold start **measurement**",0 -"find threads on burnout prevention","find threads on **on-call** burnout",0 -"how to write ADRs effectively","effective architecture decision records",1 -"linux ulimit soft vs hard","difference between soft and hard ulimit",1 -"python logging to json examples","examples of logging to json in python",1 -"monitoring with prometheus histograms","using prometheus histograms",1 -"convert csv to parquet with python","python convert csv to parquet",1 -"git lfs migrate existing repository","migrate existing repo to git lfs",1 -"compile openjdk from source tips","tips to compile openjdk from source",1 -"explain oauth2 device flow","oauth2 device code flow explained",1 -"ansible vault best practices","best practices for ansible vault",1 -"compare pgdump and pg_dumpall","pgdump vs pg_dumpall comparison",1 -"how to write ADRs effectively","how to write **RFCs** effectively",0 -"linux ulimit soft vs hard","linux **nice** values explained",0 -"python logging to json examples","python logging to **syslog** examples",0 -"monitoring with prometheus histograms","monitoring with prometheus **summaries**",0 -"convert csv to parquet with python","convert csv to parquet with **java**",0 -"git lfs migrate existing repository","git lfs **clean** existing repository",0 -"compile openjdk from source tips","compile **openj9** from source tips",0 -"explain oauth2 device flow","explain oauth2 **implicit** flow",0 -"ansible vault best practices","ansible **collections** best practices",0 -"compare pgdump and pg_dumpall","compare **mysqldump** and pg_dumpall",0 -"how to search only titles for docker networking","search titles only for docker networking",1 -"find posts with more than 50 upvotes about sso","posts about sso votes:>50",1 -"threads about migrating from svn to git","migrating from svn to git threads",1 -"how to mark a post as duplicate","mark 
thread as duplicate guide",1 -"find unanswered questions about grpc","replies:0 about grpc",1 -"where to report bugs to moderators","report bugs to moderators",1 -"download attachments from a thread","how to download thread attachments",1 -"how to restore deleted comments","restore deleted comments guide",1 -"explain event sourcing with examples","event sourcing explained with examples",1 -"compare json schema and protobuf","json schema vs protobuf compared",1 -"how to search only titles for docker networking","how to search only **bodies** for docker networking",0 -"find posts with more than 50 upvotes about sso","posts about sso votes:>10",0 -"threads about migrating from svn to git","threads about migrating from **git** to **svn**",0 -"how to mark a post as duplicate","how to mark a post as **solved**",0 -"find unanswered questions about grpc","replies:>0 about grpc",0 -"where to report bugs to moderators","where to report **spam** to moderators",0 -"download attachments from a thread","download **images only** from a thread",0 -"how to restore deleted comments","how to restore deleted **posts**",0 -"explain event sourcing with examples","explain **CQRS** with examples",0 -"compare json schema and protobuf","compare json schema and **avro**",0 -"posts about linux laptops under $1000","linux laptops under $1000 posts",1 -"recommend a quiet mechanical keyboard","quiet mechanical keyboard recommendations",1 -"how to calibrate a monitor for design","monitor calibration for design",1 -"3d printing beginner mistakes","beginner mistakes in 3d printing",1 -"bike commuting in rainy cities tips","tips for bike commuting in rain",1 -"budget microphones for podcasting","podcasting budget microphones",1 -"pet-friendly travel checklist","checklist for pet friendly travel",1 -"best books on distributed systems","top books on distributed systems",1 -"recipes for quick weeknight dinners","quick weeknight dinner recipes",1 -"board games for 2 players strategy","strategy board games for two players",1 -"posts about linux laptops under $1000","posts about linux laptops under **$1500**",0 -"recommend a quiet mechanical keyboard","recommend a **wireless** mechanical keyboard",0 -"how to calibrate a monitor for design","how to calibrate a monitor for **gaming**",0 -"3d printing beginner mistakes","3d printing **advanced** mistakes",0 -"bike commuting in rainy cities tips","bike commuting in **snowy** cities tips",0 -"budget microphones for podcasting","budget microphones for **streaming**",0 -"pet-friendly travel checklist","pet-friendly **air travel** checklist",0 -"best books on distributed systems","best books on **operating systems**",0 -"recipes for quick weeknight dinners","recipes for quick **vegetarian** dinners",0 -"board games for 2 players strategy","board games for **party** games",0 -"how to embed images in markdown","embed images in markdown",1 -"vpn keeps disconnecting on windows","windows vpn keeps disconnecting",1 -"ssh agent forwarding security risks","security risks of ssh agent forwarding",1 -"how to enable http3 in nginx","enable http3 in nginx",1 -"cuda out of memory tips","tips for cuda out of memory",1 -"mac shortcuts for window management","window management shortcuts mac",1 -"postgres query plan explanation basics","basics of explaining postgres query plan",1 -"golang context cancellation patterns","patterns for golang context cancellation",1 -"react suspense data fetching examples","examples of react suspense data fetching",1 -"docker buildkit secrets usage","using docker buildkit 
secrets",1 -"how to embed images in markdown","how to **resize** images in markdown",0 -"vpn keeps disconnecting on windows","vpn keeps disconnecting on **mac**",0 -"ssh agent forwarding security risks","ssh **X11** forwarding security risks",0 -"how to enable http3 in nginx","how to enable **http2** in nginx",0 -"cuda out of memory tips","cuda **illegal memory access** tips",0 -"mac shortcuts for window management","mac shortcuts for **clipboard**",0 -"postgres query plan explanation basics","postgres query **index** explanation basics",0 -"golang context cancellation patterns","golang context **timeouts** patterns",0 -"react suspense data fetching examples","react suspense **concurrent ui** examples",0 -"docker buildkit secrets usage","docker buildkit **cache mounts** usage",0 -"find threads by author:dan about ssh","threads about ssh by author:dan",1 -"list of pinned posts in devops tag","pinned posts list tag:devops",1 -"restore a locked account steps","steps to restore a locked account",1 -"rss feed for security tag","security tag rss feed",1 -"open source licenses comparison","comparison of open source licenses",1 -"disable inline code linting in vscode","turn off inline code linting vscode",1 -"run multiple versions of python side by side","run multiple python versions side by side",1 -"explain database sharding vs partitioning","sharding vs partitioning explained",1 -"regex backreference examples","examples of regex backreferences",1 -"rate limit strategies with redis","rate limiting strategies with redis",1 -"find threads by author:dan about ssh","find threads by author:**ann** about ssh",0 -"list of pinned posts in devops tag","list of pinned posts in **security** tag",0 -"restore a locked account steps","restore a **banned** account steps",0 -"rss feed for security tag","rss feed for **privacy** tag",0 -"open source licenses comparison","open source **copyleft** licenses comparison",0 -"disable inline code linting in vscode","disable inline code linting in **jetbrains**",0 -"run multiple versions of python side by side","run multiple versions of **node** side by side",0 -"explain database sharding vs partitioning","explain database **replication** vs partitioning",0 -"regex backreference examples","regex **lookbehind** examples",0 -"rate limit strategies with redis","rate limit strategies with **nginx**",0 -"how to filter search by solved threads only","filter to solved threads only",1 -"search within this thread only","search only in this thread",1 -"compare two-factor and passkeys","compare passkeys with two-factor",1 -"docker compose v2 migration notes","notes on migrating to docker compose v2",1 -"find tutorials for pandas merge","pandas merge tutorials",1 -"secure cookies across subdomains","secure cookies across subdomains setup",1 -"setting up oidc with keycloak","keycloak oidc setup",1 -"optimize images for webp","optimize images to webp",1 -"monitor node memory leaks tools","tools for monitoring node memory leaks",1 -"apache vs nginx for static files","nginx vs apache for static files",1 -"how to filter search by solved threads only","how to filter search by **unsolved** threads only",0 -"search within this thread only","search within this **category** only",0 -"compare two-factor and passkeys","compare **sms 2fa** and passkeys",0 -"docker compose v2 migration notes","docker **swarm** migration notes",0 -"find tutorials for pandas merge","find tutorials for pandas **join**",0 -"secure cookies across subdomains","secure cookies across **multiple top-level domains**",0 
-"setting up oidc with keycloak","setting up **saml** with keycloak",0 -"optimize images for webp","optimize images for **avif**",0 -"monitor node memory leaks tools","monitor node **event loop** lag tools",0 -"apache vs nginx for static files","apache vs nginx for **php-fpm**",0 -"find moderator announcements in July 2025","moderator announcements July 2025",1 -"threads discussing GDPR updates 2025","GDPR updates 2025 discussions",1 -"best laptops threads 2025 edition","2025 edition best laptops threads",1 -"tag:beta-features feedback requested","feedback requested tag:beta-features",1 -"changelog for forum v3.2","forum v3.2 changelog",1 -"security incidents after:2025-05-01","after:2025-05-01 security incidents",1 -"dark theme accessibility audit results","accessibility audit results dark theme",1 -"posts with polls about remote work","posts with polls about remote work has:poll",1 -"threads with accepted answers on grpc","is:solved grpc threads",1 -"where to find site rules and etiquette","site rules and etiquette",1 -"find moderator announcements in July 2025","find moderator announcements in **June** 2025",0 -"threads discussing GDPR updates 2025","threads discussing **CCPA** updates 2025",0 -"best laptops threads 2025 edition","best **phones** threads 2025 edition",0 -"tag:beta-features feedback requested","tag:beta-features feedback requested **in Spanish**",0 -"changelog for forum v3.2","changelog for forum **v3.3**",0 -"security incidents after:2025-05-01","security incidents **before:2025-05-01**",0 -"dark theme accessibility audit results","**light** theme accessibility audit results",0 -"posts with polls about remote work","posts with **surveys** about remote work",0 -"threads with accepted answers on grpc","threads with accepted answers on **graphql**",0 -"where to find site rules and etiquette","where to find **api** rules and etiquette",0 -"search exact phrase \"cannot allocate memory\"","\"cannot allocate memory\" exact phrase",1 -"tag:ai policy discussion","policy discussion tag:ai",1 -"compare markdown editors used by the community","community-used markdown editors comparison",1 -"recommend beginner kubernetes courses","beginner kubernetes course recommendations",1 -"find guides for migrating email providers","guides for migrating email providers",1 -"threads about passwordless auth ux","passwordless auth UX threads",1 -"list weekend projects under 4 hours","weekend projects under 4 hours list",1 -"tips for silent pc builds","silent pc build tips",1 -"monitoring uptime with blackbox exporter","blackbox exporter for uptime monitoring",1 -"npm workspaces with turborepo gotchas","gotchas using npm workspaces with turborepo",1 -"search exact phrase \"cannot allocate memory\"","exact phrase \"out of memory\"",0 -"tag:ai policy discussion","tag:ai **ethics** discussion",0 -"compare markdown editors used by the community","compare markdown editors used by the **design** community",0 -"recommend beginner kubernetes courses","recommend **advanced** kubernetes courses",0 -"find guides for migrating email providers","find guides for migrating **cloud** providers",0 -"threads about passwordless auth ux","threads about **2fa** auth ux",0 -"list weekend projects under 4 hours","list weekend projects under **8** hours",0 -"tips for silent pc builds","tips for **rgb** pc builds",0 -"monitoring uptime with blackbox exporter","monitoring uptime with **ping** only",0 -"npm workspaces with turborepo gotchas","pnpm workspaces with turborepo gotchas",0 -"find threads linking to arXiv 
papers","threads linking to arXiv has:link",1 -"how to archive a thread","archive a discussion thread",1 -"view profile change history","profile change history view",1 -"enable two-factor for forum account","enable two factor authentication for forum account",1 -"how to create a community wiki post","create a community wiki post",1 -"pin a helpful answer to top","pin a helpful answer",1 -"export my post history as csv","export my post history to csv",1 -"request username change policy","policy for username changes",1 -"appeal a moderation decision","appeal moderation decision",1 -"find threads with image attachments only","threads with image attachments only has:image",1 -"find threads linking to arXiv papers","find threads linking to **SSRN** papers",0 -"how to archive a thread","how to **lock** a thread",0 -"view profile change history","view profile change history **api**",0 -"enable two-factor for forum account","disable two factor for forum account",0 -"how to create a community wiki post","how to create a **private** post",0 -"pin a helpful answer to top","pin a helpful answer **for 7 days**",0 -"export my post history as csv","export my post history as **json**",0 -"request username change policy","request **account deletion** policy",0 -"appeal a moderation decision","appeal a moderation decision **template**",0 -"find threads with image attachments only","find threads with **video** attachments only",0 -"Best horror movies from the 1990s","Top 1990s horror films list",1 -"Best horror movies from the 1990s","Best horror movies from the 2000s",0 -"Advice on handling exam stress","Tips for managing stress during exams",1 -"Advice on handling exam stress","Advice for dealing with stress at work",0 -"Where to watch Studio Ghibli films online","Online streaming options for Studio Ghibli",1 -"Where to watch Studio Ghibli films online","Where to watch Pixar movies online",0 -"Healthy dinner recipes under 500 calories","Low-calorie dinner recipes under 500 calories",1 -"Healthy dinner recipes under 500 calories","Healthy dinner recipes under 800 calories",0 -"How to make long-distance relationships work","Advice on maintaining long-distance relationships",1 -"How to make long-distance relationships work","How to end a long-distance relationship peacefully",0 -"Best novels for teenagers","Top books recommended for teens",1 -"Best novels for teenagers","Best novels for senior citizens",0 -"Top fantasy TV shows to binge","Good fantasy series to watch",1 -"Top fantasy TV shows to binge","Top science fiction TV shows to binge",0 -"Scholarships available for international students","Funding opportunities for international students",1 -"Scholarships available for international students","Internship opportunities for international students",0 -"Side effects of drinking too much coffee","Health risks of excessive coffee consumption",1 -"Side effects of drinking too much coffee","Health benefits of drinking coffee daily",0 -"Budget travel tips for Europe","Cheap travel advice for Europe",1 -"Budget travel tips for Europe","Luxury travel tips for Europe",0 -"How to build confidence in public speaking","Tips for speaking confidently in public",1 -"How to build confidence in public speaking","How to overcome shyness at parties",0 -"Top romantic comedies on Netflix","Best rom-coms available on Netflix",1 -"Top romantic comedies on Netflix","Top action movies on Netflix",0 -"How to manage time better in college","College tips for effective time management",1 -"How to manage time better in college","How 
to choose the right college major",0 -"Famous paintings from the Renaissance","Renaissance art masterpieces",1 -"Famous paintings from the Renaissance","Modern digital art techniques",0 -"How to prepare for a first job interview","First job interview preparation tips",1 -"How to prepare for a first job interview","Best outfits for weddings and parties",0 -"Best hiking trails in California","Top California hikes",1 -"Best hiking trails in California","Best skiing spots in California",0 -"How to save money as a student","Budgeting advice for college students",1 -"How to save money as a student","Best part-time jobs for students",0 -"Top mysteries with unexpected endings","Best plot-twist mystery novels",1 -"Top mysteries with unexpected endings","Best romance novels of 2024",0 -"How to deal with roommate conflicts","Advice for handling conflicts with roommates",1 -"How to deal with roommate conflicts","How to find a new roommate quickly",0 -"Top romantic destinations in Europe","Most romantic places to visit in Europe",1 -"Top romantic destinations in Europe","Best budget destinations in Europe",0 -"Simple meditation techniques for beginners","Easy meditation methods for newbies",1 -"Simple meditation techniques for beginners","Advanced meditation retreats in India",0 -"Best true crime documentaries","Top documentaries about real crimes",1 -"Best true crime documentaries","Best comedy specials on Netflix",0 -"Tips for staying motivated while studying","Ways to keep motivated during study sessions",1 -"Tips for staying motivated while studying","How to get better grades without studying",0 -"Popular pet breeds for families","Best pets for family households",1 -"Popular pet breeds for families","Exotic reptiles as pets",0 -"How to write a strong college essay","Tips for writing a great college application essay",1 -"How to write a strong college essay","How to write a professional business report",0 -"Best pizza toppings combinations","Most popular pizza topping choices",1 -"Best pizza toppings combinations","Best pasta sauce recipes",0 -"Movies that make you cry","Films guaranteed to make you emotional",1 -"Movies that make you cry","Movies that will make you laugh out loud",0 -"Tips for adjusting to life in a new city","Advice on adapting after moving to a new city",1 -"Tips for adjusting to life in a new city","How to choose which city to move to",0 -"Best motivational TED talks","Top inspirational TED presentations",1 -"Best motivational TED talks","Best horror movie trailers",0 -"How to plan a road trip with friends","Planning tips for a group road trip",1 -"How to plan a road trip with friends","How to plan a solo backpacking trip",0 -"Books similar to Harry Potter","Novels like the Harry Potter series",1 -"Books similar to Harry Potter","Books about space exploration",0 -"Signs of burnout at work","Symptoms of workplace burnout",1 -"Signs of burnout at work","Benefits of a productive workplace",0 -"Top musicals of all time","Best Broadway musicals ever made",1 -"Top musicals of all time","Top stand-up comedy shows of all time",0 -"How to overcome fear of failure","Tips for dealing with failure anxiety",1 -"How to overcome fear of failure","How to celebrate success effectively",0 -"Best documentaries about nature","Top nature documentaries to watch",1 -"Best documentaries about nature","Best romantic comedies on Hulu",0 -"How to learn a new language quickly","Fastest ways to pick up a new language",1 -"How to learn a new language quickly","How to improve your English writing style",0 
-"Movies with the best soundtracks","Films famous for great music scores",1 -"Movies with the best soundtracks","Movies with great fight scenes",0 -"Tips for dealing with procrastination","How to stop procrastinating effectively",1 -"Tips for dealing with procrastination","How to develop creativity skills",0 -"Best podcasts for true crime fans","Top true crime podcast recommendations",1 -"Best podcasts for true crime fans","Best travel vlogs to watch",0 -"How to build healthy morning routines","Creating productive morning habits",1 -"How to build healthy morning routines","How to build healthy evening routines",0 -"Best comedies of the 1980s","Top 1980s comedy films",1 -"Best comedies of the 1980s","Best action films of the 1980s",0 -"How to manage social anxiety","Tips for overcoming social anxiety",1 -"How to manage social anxiety","How to improve public speaking humor",0 -"Best board games for families","Top family board game recommendations",1 -"Best board games for families","Best video games for teenagers",0 diff --git a/dataset/queries.csv b/dataset/queries.csv deleted file mode 100644 index 1ddfb1f..0000000 --- a/dataset/queries.csv +++ /dev/null @@ -1,8 +0,0 @@ -id,text -1,"how do I reset my password?" -2,"store hours on sunday" -3,"Halloween movies" -4,"Horror movies" -5,"Scary movies" -6,"mama ms rachel" -7,"ms rachel" diff --git a/evaluation.py b/evaluation.py index 58acec6..e3d124c 100644 --- a/evaluation.py +++ b/evaluation.py @@ -253,6 +253,18 @@ def main(args): parser.add_argument( "--redis_batch_size", type=int, default=256, help="Batch size for Redis vector operations (default: 256)" ) + parser.add_argument( + "--cross_encoder_model", + type=str, + default=None, + help="Name of the cross-encoder model to use for reranking (default: None)", + ) + parser.add_argument( + "--rerank_k", + type=int, + default=10, + help="Number of candidates to rerank (default: 10)", + ) args = parser.parse_args() main(args) diff --git a/pyproject.toml b/pyproject.toml index 086a7f3..4bc1f7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,10 +7,13 @@ requires-python = ">=3.13" dependencies = [ "accelerate>=1.11.0", "boto3>=1.40.18", + "google-genai>=1.0.0", "llm-sim-eval==0.3.0", "matplotlib>=3.10.6", "numpy>=2.0.0", + "openai>=1.0.0", "pandas>=2.3.2", + "peft>=0.18.1", "pytest>=8.4.2", "redis>=6.4.0", "redisvl>=0.10.0", @@ -30,13 +33,14 @@ only-include = [ ] [[tool.uv.index]] -url = "https://artifactory.dev.redislabs.com/artifactory/api/pypi/cloud-pypi-local/simple" +url = "" [dependency-groups] dev = [ "black>=25.1.0", "isort>=6.0.1", "pylint>=3.3.7", + "radon>=6.0.1", ] [tool.black] diff --git a/run_benchmark.py b/run_benchmark.py index e693231..577d61f 100644 --- a/run_benchmark.py +++ b/run_benchmark.py @@ -56,6 +56,14 @@ def main(): parser.add_argument("--redis_index_name", type=str, default="idx_cache_match") parser.add_argument("--redis_doc_prefix", type=str, default="cache:") parser.add_argument("--redis_batch_size", type=int, default=256) + parser.add_argument( + "--cross_encoder_models", + type=str, + nargs="*", + default=None, + help="List of cross-encoder models (optional). 
If not provided, only bi-encoder is used.", + ) + parser.add_argument("--rerank_k", type=int, default=10, help="Number of candidates to rerank.") args = parser.parse_args() @@ -101,76 +109,96 @@ def main(): for model_name in args.models: print(f"\n Model: {model_name}") - # Sanitize model name for directory structure - safe_model_name = model_name.replace("/", "_") + # Prepare list of cross-encoders to iterate over (None = no reranking) + ce_models = args.cross_encoder_models if args.cross_encoder_models else [None] + + for ce_model_name in ce_models: + if ce_model_name: + print(f" Cross-Encoder: {ce_model_name}") + else: + print(f" Cross-Encoder: None (Bi-Encoder only)") + + # Sanitize model name for directory structure + safe_model_name = model_name.replace("/", "_") + + for run_i in range(1, args.n_runs + 1): + print(f" Run {run_i}/{args.n_runs}...") + + # 1. Bootstrapping Logic + # Sample 80% of the universe + run_universe = full_df.sample( + frac=args.sample_ratio, random_state=run_i + ) # Use run_i as seed for reproducibility per run - for run_i in range(1, args.n_runs + 1): - print(f" Run {run_i}/{args.n_runs}...") + # Split into Queries (n_samples) and Cache (remainder) + if len(run_universe) <= args.n_samples: + print( + f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping." + ) + continue - # 1. Bootstrapping Logic - # Sample 80% of the universe - run_universe = full_df.sample( - frac=args.sample_ratio, random_state=run_i - ) # Use run_i as seed for reproducibility per run + queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000) + cache = run_universe.drop(queries.index) - # Split into Queries (n_samples) and Cache (remainder) - if len(run_universe) <= args.n_samples: - print( - f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping." + # Shuffle cache + cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True) + queries = queries.reset_index(drop=True) + + # 2. Construct Output Path + timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") + + # Include cross-encoder in output path if used + model_dir_name = safe_model_name + if ce_model_name: + safe_cross_encoder_name = ce_model_name.replace("/", "_") + model_dir_name = f"{safe_model_name}_rerank_{safe_cross_encoder_name}" + + run_output_dir = os.path.join( + args.output_dir, dataset_name, model_dir_name, f"run_{run_i}", timestamp ) - continue - - queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000) - cache = run_universe.drop(queries.index) - - # Shuffle cache - cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True) - queries = queries.reset_index(drop=True) - - # 2. Construct Output Path - timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") - run_output_dir = os.path.join(args.output_dir, dataset_name, safe_model_name, f"run_{run_i}", timestamp) - os.makedirs(run_output_dir, exist_ok=True) - - # 3. 
Prepare Args for Evaluation - eval_args = BenchmarkArgs( - query_log_path=dataset_path, # Not strictly used by logic below but good for reference - sentence_column=args.sentence_column, - output_dir=run_output_dir, - n_samples=args.n_samples, - model_name=model_name, - cache_path=None, - full=args.full, - llm_name=args.llm_name, - llm_model=llm_classifier, - sweep_steps=200, # Default - use_redis=args.use_redis, - redis_url=args.redis_url, - redis_index_name=args.redis_index_name, - redis_doc_prefix=args.redis_doc_prefix, - redis_batch_size=args.redis_batch_size, - # device defaults to code logic - ) - - # 4. Run Evaluation - try: - print(" Matching...") - if args.use_redis: - queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args) - else: - queries_matched = run_matching(queries.copy(), cache.copy(), eval_args) - - print(" Evaluating...") - if args.full: - run_full_evaluation(queries_matched, eval_args) - else: - run_chr_analysis(queries_matched, eval_args) - - except Exception as e: - print(f" Error in run {run_i}: {e}") - import traceback - - traceback.print_exc() + os.makedirs(run_output_dir, exist_ok=True) + + # 3. Prepare Args for Evaluation + eval_args = BenchmarkArgs( + query_log_path=dataset_path, # Not strictly used by logic below but good for reference + sentence_column=args.sentence_column, + output_dir=run_output_dir, + n_samples=args.n_samples, + model_name=model_name, + cache_path=None, + full=args.full, + llm_name=args.llm_name, + llm_model=llm_classifier, + sweep_steps=200, # Default + use_redis=args.use_redis, + redis_url=args.redis_url, + redis_index_name=args.redis_index_name, + redis_doc_prefix=args.redis_doc_prefix, + redis_batch_size=args.redis_batch_size, + cross_encoder_model=ce_model_name, + rerank_k=args.rerank_k, + # device defaults to code logic + ) + + # 4. 
Run Evaluation + try: + print(" Matching...") + if args.use_redis: + queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args) + else: + queries_matched = run_matching(queries.copy(), cache.copy(), eval_args) + + print(" Evaluating...") + if args.full: + run_full_evaluation(queries_matched, eval_args) + else: + run_chr_analysis(queries_matched, eval_args) + + except Exception as e: + print(f" Error in run {run_i}: {e}") + import traceback + + traceback.print_exc() print("\nBenchmark completed.") diff --git a/run_benchmark.sh b/run_benchmark.sh new file mode 100644 index 0000000..3a1b676 --- /dev/null +++ b/run_benchmark.sh @@ -0,0 +1,15 @@ +# Example usage: +uv run run_benchmark.py \ + --dataset_dir "dataset" \ + --output_dir "limitations-experiments-gte-modernbert-base-lora" \ + --models "redis/model-a-baseline" "redis/model-b-structured" \ + --dataset_names "sentencepairs_v3_unique_sentences.csv" "vizio_unique_medium.csv"\ + --sentence_column "sentence" \ + --n_runs 3 \ + --n_samples 16384 \ + --sample_ratio 0.8 \ + --llm_name "tensoropera/Fox-1-1.6B" \ + --full \ + --use_redis \ + # --cross_encoder_models "gemini/text-embedding-001" \ + # --rerank_k 1 \ No newline at end of file diff --git a/run_chr_analysis.sh b/run_chr_analysis.sh index 8dd4fdd..4a7802d 100644 --- a/run_chr_analysis.sh +++ b/run_chr_analysis.sh @@ -1,44 +1,100 @@ uv run evaluation.py \ - --query_log_path ./dataset/queries.csv \ - --cache_path ./dataset/cache.csv \ - --sentence_column text \ - --output_dir ./outputs \ - --n_samples 100 \ - --model_name "redis/langcache-embed-v3.1" - -uv run evaluation.py \ - --query_log_path ./dataset/queries.csv \ - --cache_path ./dataset/cache.csv \ - --sentence_column text \ - --output_dir ./outputs \ - --n_samples 100 \ - --model_name "redis/langcache-embed-v3.1" \ - --use_redis + --query_log_path dataset/mangoes_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./mangoes/v2 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v2" \ + --full \ uv run evaluation.py \ - --query_log_path ./dataset/queries.csv \ - --cache_path ./dataset/cache.csv \ - --sentence_column text \ - --output_dir ./outputs \ - --n_samples 100 \ - --model_name "redis/langcache-embed-v3.1" \ - --full + --query_log_path dataset/mangoes_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./mangoes/v3 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v3" \ + --full \ uv run evaluation.py \ - --query_log_path ./dataset/queries.csv \ - --cache_path ./dataset/cache.csv \ - --sentence_column text \ - --output_dir ./outputs \ - --n_samples 100 \ + --query_log_path dataset/mangoes_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./mangoes/v3.1 \ + --n_samples 1000 \ --model_name "redis/langcache-embed-v3.1" \ --full \ - --use_redis uv run evaluation.py \ - --query_log_path ./dataset/chatgpt.csv \ - --sentence_column sentence2 \ - --output_dir ./outputs \ - --n_samples 20 \ + --query_log_path dataset/mangoes_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./mangoes/v1 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v1" \ + --full + + +# ================================ + +uv run evaluation.py \ + --query_log_path dataset/chatgpt_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./rado_synthetic/v1 \ + --n_samples 500 \ + --model_name "redis/langcache-embed-v1" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/chatgpt_unique_sentences.csv \ + --sentence_column 
sentence \ + --output_dir ./rado_synthetic/v2 \ + --n_samples 500 \ + --model_name "redis/langcache-embed-v2" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/chatgpt_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./rado_synthetic/v3 \ + --n_samples 500 \ + --model_name "redis/langcache-embed-v3" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/chatgpt_unique_sentences.csv \ + --sentence_column sentence \ + --output_dir ./rado_synthetic/v3.1 \ + --n_samples 500 \ --model_name "redis/langcache-embed-v3.1" \ - --full \ - --use_redis + --full + +# ================================ + +uv run evaluation.py \ + --query_log_path dataset/vizio_unique_sentences.csv \ + --sentence_column transcription \ + --output_dir ./vizio/v1 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v1" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/vizio_unique_sentences.csv \ + --sentence_column transcription \ + --output_dir ./vizio/v2 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v2" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/vizio_unique_sentences.csv \ + --sentence_column transcription \ + --output_dir ./vizio/v3 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v3" \ + --full + +uv run evaluation.py \ + --query_log_path dataset/vizio_unique_sentences.csv \ + --sentence_column transcription \ + --output_dir ./vizio/v3.1 \ + --n_samples 1000 \ + --model_name "redis/langcache-embed-v3.1" \ + --full \ No newline at end of file diff --git a/scripts/calculate_precision_at_1.py b/scripts/calculate_precision_at_1.py new file mode 100644 index 0000000..f175496 --- /dev/null +++ b/scripts/calculate_precision_at_1.py @@ -0,0 +1,210 @@ +import argparse +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from utils import crawl_results + +def calculate_precision_at_1(csv_path): + # Reads llm_as_a_judge_results.csv + try: + df = pd.read_csv(csv_path) + if "actual_label" not in df.columns: + return None + # actual_label should be 0 or 1 + return df["actual_label"].mean() + except Exception as e: + print(f"Error reading {csv_path}: {e}") + return None + +def get_model_short_name(model_name): + if "gte-modernbert" in model_name: + return "ModernBERT" + elif "v1" in model_name: + return "v1" + elif "v2" in model_name: + return "v2" + elif "v3.1" in model_name: + return "v3-small" + return model_name.split("/")[-1] + +def print_latex_table( + title, metric_key, dataset_names, sorted_models, model_short_names, data_map, label, minimize=False +): + print("\\begin{table*}[h]") + print("\\centering") + + col_def = "l" + "c" * len(sorted_models) + print(f"\\begin{{tabular}}{{{col_def}}}") + print("\\toprule") + + header = "Dataset" + for model in sorted_models: + short_name = model_short_names[model] + header += f" & {short_name}" + print(f"{header} \\\\") + print("\\midrule") + + for dataset in dataset_names: + ds_name = dataset.replace("_test.csv", "").replace("_", "\\_") + row_str = ds_name + + # Calculate best mean for highlighting + means = [] + for model in sorted_models: + stats = data_map[dataset].get(model, {metric_key: {"mean": None}}) + val = stats[metric_key]["mean"] + if val is not None: + means.append(val) + + best_mean = None + if means: + if minimize: + best_mean = min(means) + else: + best_mean = max(means) + + for model in sorted_models: + stats = data_map[dataset].get( + model, {metric_key: {"mean": None, "std": None}} + ) + mean = stats[metric_key]["mean"] + 
std = stats[metric_key]["std"] + + if mean is None: + row_str += " & -" + else: + cell_str = f"{mean:.3f}" + if std is not None: + cell_str += f" \\pm {std:.3f}" + + # Highlight best + if best_mean is not None and abs(mean - best_mean) < 1e-6: + row_str += f" & \\textbf{{{cell_str}}}" + else: + row_str += f" & {cell_str}" + + print(f"{row_str} \\\\") + print("\\bottomrule") + print("\\end{tabular}") + print(f"\\caption{{{title}}}") + print(f"\\label{{{label}}}") + print("\\end{table*}") + print("\n") + +def main(): + parser = argparse.ArgumentParser("Usage: python calculate_precision_at_1.py --base_dir ") + parser.add_argument("--base_dir", type=str, required=False, default="complete_benchmark_results") + args = parser.parse_args() + + base_dir = args.base_dir + benchmark_map = crawl_results(base_dir) + + if not benchmark_map: + print("No results found.") + return + + dataset_names = sorted(list(benchmark_map.keys())) + + # Collect all models + all_models = set() + for ds in benchmark_map: + for model in benchmark_map[ds]: + all_models.add(model) + sorted_models = sorted(list(all_models)) + model_short_names = {m: get_model_short_name(m) for m in sorted_models} + + data_map = {} + + # 1. Collect Data + print(f"{'Dataset':<30} | {'Model':<30} | Precision@1") + print("-" * 80) + + for dataset in dataset_names: + data_map[dataset] = {} + for model in sorted_models: + run_paths = benchmark_map[dataset].get(model, []) + precisions = [] + + for run_path in run_paths: + details_path = os.path.join(run_path, "llm_as_a_judge_results.csv") + if os.path.exists(details_path): + p1 = calculate_precision_at_1(details_path) + if p1 is not None: + precisions.append(p1) + + stats = {"mean": None, "std": None} + if precisions: + mean_val = np.mean(precisions) + std_val = np.std(precisions) if len(precisions) > 1 else 0.0 + stats["mean"] = mean_val + stats["std"] = std_val + + print(f"{dataset:<30} | {model:<30} | {mean_val:.4f} ± {std_val:.4f}") + else: + pass # No data for this model/dataset + + data_map[dataset][model] = {"precision": stats} + + print("\n" + "="*80 + "\n") + + # 2. Print Latex Table + print_latex_table( + title="Precision@1 (Mean $\\pm$ Std, $\\uparrow$)", + metric_key="precision", + dataset_names=dataset_names, + sorted_models=sorted_models, + model_short_names=model_short_names, + data_map=data_map, + label="tab:precision_at_1", + minimize=False + ) + + # 3. 
Plot HBar + for dataset in dataset_names: + dataset_full_path = os.path.join(base_dir, dataset) + if not os.path.exists(dataset_full_path): + continue + + models_in_ds = [] + means = [] + stds = [] + + for model in sorted_models: + stats = data_map[dataset][model]["precision"] + if stats["mean"] is not None: + models_in_ds.append(model_short_names[model]) + means.append(stats["mean"]) + stds.append(stats["std"]) + + if not models_in_ds: + continue + + # Sort by mean precision + zipped = sorted(zip(means, stds, models_in_ds)) + means_sorted, stds_sorted, models_sorted = zip(*zipped) + + plt.figure(figsize=(10, max(4, len(models_sorted) * 0.8 + 2))) + y_pos = np.arange(len(models_sorted)) + + plt.barh(y_pos, means_sorted, xerr=stds_sorted, align='center', alpha=0.8, capsize=5) + plt.yticks(y_pos, models_sorted) + plt.xlabel('Precision@1') + plt.title(f'Precision@1 for {dataset}') + plt.xlim(0, 1.05) + plt.grid(axis='x', alpha=0.3) + + # Add values to bars + for i, v in enumerate(means_sorted): + plt.text(v + 0.01, i, f"{v:.3f}", va='center') + + plt.tight_layout() + output_path = os.path.join(dataset_full_path, "precision_at_1.png") + print(f"Saving plot to {output_path}") + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + +if __name__ == "__main__": + main() + diff --git a/scripts/plot_multiple_precision_vs_cache_hit_ratio.py b/scripts/plot_multiple_precision_vs_cache_hit_ratio.py index 0a693db..7ba8206 100644 --- a/scripts/plot_multiple_precision_vs_cache_hit_ratio.py +++ b/scripts/plot_multiple_precision_vs_cache_hit_ratio.py @@ -1,12 +1,101 @@ import argparse import os +from collections import defaultdict import matplotlib.pyplot as plt +import matplotlib.colors as mcolors import numpy as np import pandas as pd from utils import crawl_results +def extract_retriever_name(model_name): + """Extract retriever name from model name (part before '_rerank_').""" + if '_rerank_' in model_name: + return model_name.split('_rerank_')[0] + return model_name + + +def extract_reranker_name(model_name): + """Extract reranker name from model name (part after '_rerank_'), or None if no reranker.""" + if '_rerank_' in model_name: + return model_name.split('_rerank_')[1] + return None + + +def darken_color(color, factor): + """ + Darken a color by a given factor (0 = original, 1 = black). + factor should be between 0 and 1. + """ + rgb = mcolors.to_rgb(color) + darkened = tuple(c * (1 - factor) for c in rgb) + return darkened + + +def get_retriever_color_map(model_names): + """ + Create a color mapping for retrievers and their cross-encoder variants. 
+ Returns: dict mapping model_name -> color + """ + # Group models by retriever + retriever_groups = defaultdict(list) + for model_name in model_names: + retriever = extract_retriever_name(model_name) + retriever_groups[retriever].append(model_name) + + # Sort retrievers for consistent ordering + sorted_retrievers = sorted(retriever_groups.keys()) + + # Use a colorful palette with good distinction + base_colors = [ + '#e6194B', # Red + '#3cb44b', # Green + '#4363d8', # Blue + '#f58231', # Orange + '#911eb4', # Purple + '#42d4f4', # Cyan + '#f032e6', # Magenta + '#bfef45', # Lime + '#fabed4', # Pink + '#469990', # Teal + '#dcbeff', # Lavender + '#9A6324', # Brown + '#fffac8', # Beige + '#800000', # Maroon + '#aaffc3', # Mint + ] + + color_map = {} + + for i, retriever in enumerate(sorted_retrievers): + base_color = base_colors[i % len(base_colors)] + models_in_group = retriever_groups[retriever] + + # Sort models within group: base retriever first, then rerankers alphabetically + def sort_key(m): + reranker = extract_reranker_name(m) + if reranker is None: + return (0, '') # Base retriever comes first + return (1, reranker) + + models_in_group.sort(key=sort_key) + + # Assign colors with increasing darkness + n_models = len(models_in_group) + for j, model_name in enumerate(models_in_group): + if n_models == 1: + # Only base retriever, use base color + color_map[model_name] = base_color + else: + # Darken progressively: base is brightest, last reranker is darkest + # factor ranges from 0 (base) to ~0.6 (darkest reranker) + darken_factor = j * 0.5 / (n_models - 1) if n_models > 1 else 0 + color_map[model_name] = darken_color(base_color, darken_factor) + + return color_map, sorted_retrievers + + def main(): parser = argparse.ArgumentParser( "Usage: python plot_multiple_precision_vs_cache_hit_ratio.py --base_dir " @@ -26,8 +115,13 @@ def main(): dataset_full_path = os.path.join(base_dir, dataset_name) if not os.path.exists(dataset_full_path): continue - fig, ax = plt.subplots(figsize=(10, 7)) - colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] + + # CHANGED: Create two subplots: one for curves, one for the AUC bar chart + fig, (ax_main, ax_bar) = plt.subplots(1, 2, figsize=(22, 10), gridspec_kw={'width_ratios': [2, 1]}) + + # Build color map based on retriever grouping + all_model_names = list(model_data.keys()) + color_map, sorted_retrievers = get_retriever_color_map(all_model_names) # Get base rate from first valid run to compute theoretical curves base_rate = None @@ -48,6 +142,9 @@ def main(): if base_rate is not None: break + # Theoretical AUCs storage + theory_aucs = {} + # Plot theoretical curves if base_rate is not None: # Theoretical Perfect (Uniform Negatives) @@ -56,16 +153,30 @@ def main(): x_uniform = np.concatenate(([0], x_uniform)) y_uniform = np.concatenate(([1], y_uniform)) auc_uniform = base_rate * (1 - np.log(base_rate)) - ax.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}") + ax_main.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}") + theory_aucs['Uniform'] = auc_uniform # Theoretical Perfect (Zero Negatives) x_zeros = [0, base_rate, 1.0] y_zeros = [1.0, 1.0, base_rate] auc_zeros = base_rate + 0.5 * (1 - base_rate**2) - ax.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}") + ax_main.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}") + theory_aucs['ZeroNegs'] = auc_zeros + 
+ # Sort models by retriever group, then by reranker + def model_sort_key(m): + retriever = extract_retriever_name(m) + reranker = extract_reranker_name(m) + if reranker is None: + return (retriever, 0, '') + return (retriever, 1, reranker) + + sorted_models = sorted(model_data.keys(), key=model_sort_key) + + # Store data for the bar plot + auc_records = [] - sorted_models = sorted(model_data.keys()) - for i, model_name in enumerate(sorted_models): + for model_name in sorted_models: run_paths = model_data[model_name] precisions_interp = [] aucs_pchr = [] @@ -81,7 +192,7 @@ def main(): try: df = pd.read_csv(csv_path) - # Remove the last row because it's it's always precision = 1.0 + # Remove the last row because it's always precision = 1.0 df = df.iloc[:-1] x_chr = df["cache_hit_ratio"].values @@ -95,7 +206,13 @@ def main(): p_interp = np.interp(common_chr, x_chr, y_prec) precisions_interp.append(p_interp) - aucs_pchr.append(np.trapezoid(p_interp, common_chr)) + # Use numpy.trapezoid (NumPy 2.0) or numpy.trapz (older) + try: + auc_val = np.trapezoid(p_interp, common_chr) + except AttributeError: + auc_val = np.trapz(p_interp, common_chr) + + aucs_pchr.append(auc_val) valid_runs += 1 except Exception as e: @@ -109,17 +226,60 @@ def main(): mean_auc_pchr = np.mean(aucs_pchr) std_auc_pchr = np.std(aucs_pchr) if valid_runs > 1 else 0.0 - color = colors[i % len(colors)] + # Get color from the retriever-based color map + color = color_map[model_name] - label_chr = f"{model_name}, AUC: {mean_auc_pchr:.3f} ± {std_auc_pchr:.3f}" - ax.plot(common_chr, mean_p_chr, label=label_chr, color=color) + ax_main.plot(common_chr, mean_p_chr, label=model_name, color=color) if valid_runs > 1: - ax.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2) - ax.set_xlabel("Cache Hit Ratio") - ax.set_ylabel("Precision") - ax.set_title("Precision vs Cache Hit Ratio") - ax.grid(True) - ax.legend() + ax_main.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2) + + # Save data for bar chart + auc_records.append({ + 'name': model_name, + 'mean': mean_auc_pchr, + 'std': std_auc_pchr, + 'color': color, + 'retriever': extract_retriever_name(model_name) + }) + + # --- Configure Main Curve Plot --- + ax_main.set_xlabel("Cache Hit Ratio") + ax_main.set_ylabel("Precision") + ax_main.set_title("Precision vs Cache Hit Ratio") + ax_main.grid(True) + ax_main.legend() + + # --- Configure Bar Chart --- + if auc_records: + # Sort by mean AUC (ascending so best is at top) + auc_records.sort(key=lambda x: x['mean'], reverse=False) + + names = [r['name'] for r in auc_records] + means = [r['mean'] for r in auc_records] + stds = [r['std'] for r in auc_records] + bar_colors = [r['color'] for r in auc_records] + y_pos = np.arange(len(names)) + + ax_bar.barh(y_pos, means, xerr=stds, color=bar_colors, align='center', capsize=5, alpha=0.8) + ax_bar.set_yticks(y_pos) + ax_bar.set_yticklabels(names) + + # Add AUC values as text labels on the bars + for i, (mean, std) in enumerate(zip(means, stds)): + ax_bar.text(mean + std + 0.05, i, f'{mean:.3f} ± {std:.3f}', va='center', ha='left', fontsize=12) + ax_bar.set_xlabel("AUC") + ax_bar.set_title("AUC Comparison") + ax_bar.grid(axis='x', linestyle='--', alpha=0.7) + + # Add theoretical lines to bar chart + if 'Uniform' in theory_aucs: + ax_bar.axvline(theory_aucs['Uniform'], color='black', linestyle='--', alpha=0.7) + if 'ZeroNegs' in theory_aucs: + ax_bar.axvline(theory_aucs['ZeroNegs'], color='black', linestyle=':', 
alpha=0.7) + + # Set x-limits to focus on relevant area if needed, or 0-1 + # ax_bar.set_xlim(0, 1.05) + fig.suptitle(f"Performance on {dataset_name.split('_')[0]}") plt.tight_layout() output_path = os.path.join(dataset_full_path, "precision_vs_cache_hit_ratio.png") @@ -129,4 +289,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/scripts/plot_precision_over_threshold.py b/scripts/plot_precision_over_threshold.py new file mode 100644 index 0000000..5300ed6 --- /dev/null +++ b/scripts/plot_precision_over_threshold.py @@ -0,0 +1,119 @@ +import argparse +import os +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from utils import crawl_results + +def main(): + parser = argparse.ArgumentParser("Usage: python plot_precision_over_threshold.py --base_dir ") + parser.add_argument("--base_dir", type=str, required=False, default="complete_benchmark_results") + args = parser.parse_args() + + base_dir = args.base_dir + benchmark_map = crawl_results(base_dir) + + if not benchmark_map: + print("No results found.") + return + + for dataset_name, model_data in benchmark_map.items(): + print(f"Processing {dataset_name}...") + + dataset_full_path = os.path.join(base_dir, dataset_name) + if not os.path.exists(dataset_full_path): + continue + + n = len(model_data) + cols = 3 + rows = (n + cols - 1) // cols + + if n == 1: + rows = 1 + cols = 1 + elif n == 0: + continue + + # Create a single plot for this dataset + plt.figure(figsize=(12, 8)) + plt.title(f"Precision vs Threshold for {dataset_name}", fontsize=16) + + sorted_models = sorted(model_data.keys()) + + # Common thresholds for interpolation + common_thresholds = np.linspace(0.5, 1, 200) + + # Use a qualitative colormap with many distinct colors (tab20) + # and cycle through them if we have more models than colors. 
+ cmap = plt.get_cmap("tab20") + + # Define some linestyles to help distinguish further + linestyles = ['-', '--', '-.', ':'] + + for idx, model_name in enumerate(sorted_models): + # improved color selection: + color = cmap(idx % 20) + linestyle = linestyles[(idx // 20) % len(linestyles)] + + run_paths = model_data[model_name] + try: + all_precisions = [] + valid_runs = 0 + + for run_path in run_paths: + sweep_path = os.path.join(run_path, "threshold_sweep_results.csv") + if not os.path.exists(sweep_path): + continue + + try: + df = pd.read_csv(sweep_path) + if "threshold" not in df.columns or "precision" not in df.columns: + continue + + # Sort by threshold + df = df.sort_values("threshold") + + x = df["threshold"].values + y = df["precision"].values + + y_interp = np.interp(common_thresholds, x, y) + all_precisions.append(y_interp) + valid_runs += 1 + + except Exception as e: + print(f"Error processing run {run_path}: {e}") + + if valid_runs == 0: + continue + + mean_precision = np.mean(all_precisions, axis=0) + std_precision = np.std(all_precisions, axis=0) if valid_runs > 1 else np.zeros_like(mean_precision) + + plt.plot(common_thresholds, mean_precision, label=model_name, color=color, linestyle=linestyle, linewidth=2) + plt.fill_between( + common_thresholds, + np.maximum(0, mean_precision - std_precision), + np.minimum(1, mean_precision + std_precision), + alpha=0.1, + color=color + ) + + except Exception as e: + print(f"Error plotting {model_name}: {e}") + + plt.xlabel("Threshold") + plt.ylabel("Precision") + plt.ylim(0, 1.05) + plt.xlim(0.5, 1.0) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left') + plt.grid(True, alpha=0.3) + plt.tight_layout() + + output_path = os.path.join(dataset_full_path, "precision_over_threshold.png") + print(f"Saving plot to {output_path}") + plt.savefig(output_path, dpi=300, bbox_inches="tight") + plt.close() + +if __name__ == "__main__": + main() + diff --git a/src/customer_analysis/__init__.py b/src/customer_analysis/__init__.py index 4cd24eb..78e0e58 100644 --- a/src/customer_analysis/__init__.py +++ b/src/customer_analysis/__init__.py @@ -6,7 +6,14 @@ run_matching, run_matching_redis, ) -from .embedding_interface import NeuralEmbedding +from .embedding_interface import EmbeddingModel, get_embedding_model +from .embedding_providers import ( + EmbeddingProvider, + GeminiProvider, + HuggingFaceProvider, + OpenAIProvider, + get_embedding_provider, +) from .file_handler import FileHandler from .metrics_util import ( calculate_f_beta_score, @@ -17,21 +24,37 @@ from .plotting import generate_plots, plot_cache_hit_ratio from .query_engine import RedisVectorIndex from .s3_util import s3_upload_dataframe_csv, s3_upload_matplotlib_png +from .similarity_matcher import SimilarityMatcher __all__ = [ + # File handling "FileHandler", - "NeuralEmbedding", + # Embedding + "EmbeddingModel", + "get_embedding_model", + # Providers + "EmbeddingProvider", + "HuggingFaceProvider", + "OpenAIProvider", + "GeminiProvider", + "get_embedding_provider", + "SimilarityMatcher", + # Data processing "load_data", "postprocess_results_for_metrics", "run_matching", "run_matching_redis", + # Metrics "evaluate_threshold_on_results", "sweep_thresholds_on_results", "calculate_f_beta_score", "calculate_metrics", + # Plotting "generate_plots", "plot_cache_hit_ratio", + # S3 utilities "s3_upload_dataframe_csv", "s3_upload_matplotlib_png", + # Redis "RedisVectorIndex", ] diff --git a/src/customer_analysis/data_processing.py b/src/customer_analysis/data_processing.py index 450b01d..cfeed75 
100644 --- a/src/customer_analysis/data_processing.py +++ b/src/customer_analysis/data_processing.py @@ -1,7 +1,8 @@ import pandas as pd +import numpy as np import torch -from src.customer_analysis.embedding_interface import NeuralEmbedding +from src.customer_analysis.embedding_interface import get_embedding_model from src.customer_analysis.file_handler import FileHandler from src.customer_analysis.query_engine import RedisVectorIndex @@ -23,6 +24,21 @@ def run_matching_redis(queries: pd.DataFrame, cache: pd.DataFrame, args): """ text_col = args.sentence_column + # Determine k for retrieval + k = 1 + cross_encoder = None + if getattr(args, "cross_encoder_model", None): + try: + from sentence_transformers import CrossEncoder + cross_encoder = CrossEncoder( + args.cross_encoder_model, + device=getattr(args, "device", None) or ("cuda" if torch.cuda.is_available() else "cpu") + ) + k = getattr(args, "rerank_k", 10) + print(f"Using Cross-Encoder reranking: {args.cross_encoder_model} (top-{k})") + except ImportError: + print("Warning: sentence_transformers not found or CrossEncoder import failed. Skipping reranking.") + rindex = RedisVectorIndex( col_query=text_col, index_name=getattr(args, "redis_index_name", "idx_cache_match"), @@ -38,27 +54,83 @@ def run_matching_redis(queries: pd.DataFrame, cache: pd.DataFrame, args): # 2) embed + load cache cache_texts = cache[text_col].tolist() cache_vecs = rindex._embed_batch(cache_texts) # (M, D) + + # Fix: Ensure vectors are normalized and float32 + norms = np.linalg.norm(cache_vecs, axis=1, keepdims=True) + norms[norms == 0] = 1e-9 + cache_vecs = (cache_vecs / norms).astype(np.float32) + rindex.load_texts_and_vecs(cache_texts, cache_vecs) - # 3) embed queries and search top-1 + # 3) embed queries and search top-k query_texts = queries[text_col].tolist() query_vecs = rindex._embed_batch(query_texts) + + # Normalize queries too and ensure float32 + norms = np.linalg.norm(query_vecs, axis=1, keepdims=True) + norms[norms == 0] = 1e-9 + query_vecs = (query_vecs / norms).astype(np.float32) best_scores: list[float] = [] matches: list[str] = [] - for qv in query_vecs: - resp = rindex.query_vector_topk(qv, k=1) - if not resp: - best_scores.append(0.0) - matches.append("") - continue - - hit = resp[0] - cosine_sim = 1.0 - float(hit["vector_distance"]) # convert to similarity - - best_scores.append(cosine_sim) - matches.append(hit[text_col]) + if cross_encoder and k > 1: + all_pairs = [] + query_candidate_counts = [] + candidates_list = [] # store candidates for each query to retrieve text later + + print("Retrieving candidates from Redis...") + for i, qv in enumerate(query_vecs): + resp = rindex.query_vector_topk(qv, k=k) + if not resp: + query_candidate_counts.append(0) + candidates_list.append([]) + continue + + q_text = query_texts[i] + cands = [r[text_col] for r in resp] + candidates_list.append(cands) + + for c_text in cands: + all_pairs.append([q_text, c_text]) + + query_candidate_counts.append(len(cands)) + + if all_pairs: + print(f"Reranking {len(all_pairs)} pairs with Cross-Encoder...") + all_scores = cross_encoder.predict(all_pairs, batch_size=32, show_progress_bar=True) + + # Reassemble + score_idx = 0 + for i, count in enumerate(query_candidate_counts): + if count == 0: + best_scores.append(0.0) + matches.append("") + continue + + # Get scores for this query + q_scores = all_scores[score_idx : score_idx + count] + score_idx += count + + best_idx = np.argmax(q_scores) + best_scores.append(float(q_scores[best_idx])) + 
matches.append(candidates_list[i][best_idx]) + else: + best_scores = [0.0] * len(queries) + matches = [""] * len(queries) + + else: + for qv in query_vecs: + resp = rindex.query_vector_topk(qv, k=1) + if not resp: + best_scores.append(0.0) + matches.append("") + continue + + hit = resp[0] + cosine_sim = 1.0 - float(hit["vector_distance"]) # convert to similarity + best_scores.append(cosine_sim) + matches.append(hit[text_col]) # 4) attach outputs out = queries.copy() @@ -73,18 +145,64 @@ def run_matching_redis(queries: pd.DataFrame, cache: pd.DataFrame, args): def run_matching(queries, cache, args): - embedding_model = NeuralEmbedding(args.model_name, device="cuda" if torch.cuda.is_available() else "cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" + embedding_model = get_embedding_model(args.model_name, device=device) + + # Determine k for retrieval + k = 1 + cross_encoder = None + if getattr(args, "cross_encoder_model", None): + try: + from sentence_transformers import CrossEncoder + cross_encoder = CrossEncoder( + args.cross_encoder_model, + device="cuda" if torch.cuda.is_available() else "cpu" + ) + k = getattr(args, "rerank_k", 10) + print(f"Using Cross-Encoder reranking: {args.cross_encoder_model} (top-{k})") + except ImportError: + print("Warning: sentence_transformers not found or CrossEncoder import failed. Skipping reranking.") queries["best_scores"] = 0 + query_list = queries[args.sentence_column].to_list() + cache_list = cache[args.sentence_column].to_list() + best_indices, best_scores, decision_methods = embedding_model.calculate_best_matches_with_cache_large_dataset( - queries=queries[args.sentence_column].to_list(), - cache=cache[args.sentence_column].to_list(), + queries=query_list, + cache=cache_list, batch_size=512, + k=k ) - queries["best_scores"] = best_scores - queries["matches"] = cache.iloc[best_indices][args.sentence_column].to_list() + if cross_encoder and k > 1: + print("Reranking results with Cross-Encoder...") + # best_indices is (N, k) + all_pairs = [] + N = len(query_list) + + for i in range(N): + q_text = query_list[i] + for idx in best_indices[i]: + all_pairs.append([q_text, cache_list[idx]]) + + if all_pairs: + all_scores = cross_encoder.predict(all_pairs, batch_size=128, show_progress_bar=True) + all_scores = all_scores.reshape(N, k) + + best_idx_in_k = np.argmax(all_scores, axis=1) # (N,) + + final_scores = all_scores[np.arange(N), best_idx_in_k] + final_cache_indices = best_indices[np.arange(N), best_idx_in_k] + + queries["best_scores"] = final_scores + queries["matches"] = [cache_list[i] for i in final_cache_indices] + else: + queries["best_scores"] = 0.0 + queries["matches"] = "" + else: + queries["best_scores"] = best_scores + queries["matches"] = cache.iloc[best_indices][args.sentence_column].to_list() del embedding_model torch.cuda.empty_cache() diff --git a/src/customer_analysis/embedding_interface.py b/src/customer_analysis/embedding_interface.py index 68e9d9d..d5d5490 100644 --- a/src/customer_analysis/embedding_interface.py +++ b/src/customer_analysis/embedding_interface.py @@ -1,458 +1,89 @@ -import os -import tempfile +""" +Embedding interface for semantic cache evaluation. + +This module provides a unified EmbeddingModel class that works with any embedding provider +(HuggingFace, OpenAI, Gemini) and delegates matching logic to SimilarityMatcher. 
+""" + from typing import Optional import numpy as np -from sentence_transformers import SentenceTransformer -from tqdm import tqdm +from src.customer_analysis.embedding_providers import ( + EmbeddingProvider, + get_embedding_provider, +) +from src.customer_analysis.similarity_matcher import SimilarityMatcher -class NeuralEmbedding: - """ - A placeholder for a neural embedding model that will use a Hugging Face model. + +class EmbeddingModel: """ + Unified embedding model that works with any EmbeddingProvider. - def __init__(self, model_name: str, device: str = "cpu"): - """ - Initialize the NeuralEmbedding model. - """ - self.model = SentenceTransformer(model_name, device=device, local_files_only=False, trust_remote_code=True) - self.embeddings = None + This class provides a simple interface for embedding and similarity matching, + delegating actual work to the underlying provider and SimilarityMatcher. + """ - def encode(self, sentences: list[str], **kwargs) -> np.ndarray: - """ - A placeholder 'encode' method to maintain compatibility with the evaluation - pipeline, which expects a model to have this method. + def __init__(self, provider: EmbeddingProvider): + """Initialize with an embedding provider.""" + self._provider = provider + self._matcher = SimilarityMatcher(provider) + self.embeddings: Optional[dict[str, list[float]]] = None - This method does not generate meaningful embeddings. Instead, it returns - a zero vector for each sentence. The actual similarity logic is handled - by the `calculate_best_matches` method. + @property + def model(self): + """Return underlying model for HuggingFace providers.""" + return getattr(self._provider, "model", None) - Args: - sentences (list[str]): A list of sentences to "encode". - **kwargs: Additional keyword arguments (ignored). + @property + def client(self): + """Return underlying client for API providers.""" + return getattr(self._provider, "client", None) - Returns: - np.ndarray: A numpy array of zero vectors. - """ - return self.model.encode(sentences) + @property + def model_name(self) -> str: + """Return model name.""" + return getattr(self._provider, "model_name", "unknown") - def calculate_best_matches( - self, sentences: list[str], batch_size: int = 32, large_dataset: bool = False, early_stop: int = 0 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Calculates the best similarity match for each sentence against all other - sentences using a neural embedding model. - - Args: - sentences (list[str]): The list of sentences to compare. - batch_size (int): The batch size to use for the similarity search. - - Returns: - tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing: - - best_indices (np.ndarray): An array of indices for the best match of each sentence. - - best_scores (np.ndarray): An array of similarity scores (0-1) for the best matches. - - decision_methods (np.ndarray): An array with the string value "neural" for every sentence. 
- """ - if not large_dataset: - self.embeddings = self.embed_all_sentences(sentences, batch_size) - return self.calculate_best_matches_from_embeddings(self.embeddings, sentences, batch_size) - else: - return self._calculate_best_matches_large_dataset(sentences, batch_size, early_stop=early_stop) + def encode(self, sentences: list[str], **kwargs) -> np.ndarray: + """Encode sentences to embeddings.""" + batch_size = kwargs.get("batch_size", 32) + return self._provider.encode(sentences, batch_size=batch_size, normalize=False) def embed_all_sentences(self, sentences: list[str], batch_size: int) -> dict[str, list[float]]: - """Embed all unique sentences with the provided model.""" - sentence_to_embeddings: dict[str, list[float]] = {} - sentence_list = list(set(sentences)) - total = len(sentence_list) - - print(f"Embedding {total} unique sentences in batches of {batch_size} ...") + """Embed all unique sentences.""" + return self._matcher.embed_all_sentences(sentences, batch_size) - for start in tqdm(range(0, total, batch_size), desc="Embedding sentences..."): - end = min(start + batch_size, total) - batch = sentence_list[start:end] - - batch_embs = self.model.encode(batch) - for sent, emb in zip(batch, batch_embs): - sentence_to_embeddings[sent] = emb.tolist() if hasattr(emb, "tolist") else emb - return sentence_to_embeddings - - def calculate_best_matches_from_embeddings( - self, embeddings: dict[str, list[float]], sentences: list[str], batch_size: int = 1024 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Calculate the best similarity match for each sentence without building a full similarity matrix. - """ - best_indices = np.zeros(len(sentences), dtype=np.int32) - best_scores = np.zeros(len(sentences), dtype=np.float32) - decision_methods = np.full(len(sentences), "neural", dtype=object) - - for sentence_batch_idx in tqdm( - range(0, len(sentences), batch_size), desc="Calculating best matches with cache..." - ): - max_index = min(sentence_batch_idx + batch_size, len(sentences)) - - out = self.calculate_best_matches_from_embeddings_with_cache( - cache_embeddings=embeddings, - sentence_embeddings=embeddings, - sentences=sentences[sentence_batch_idx:max_index], - cache=sentences, - batch_size=batch_size, - sentence_offset=sentence_batch_idx, - mask_self_similarity=True, - ) - best_indices_batch, best_scores_batch, decision_methods_batch = out - - best_indices[sentence_batch_idx:max_index] = best_indices_batch - best_scores[sentence_batch_idx:max_index] = best_scores_batch - decision_methods[sentence_batch_idx:max_index] = decision_methods_batch - - return best_indices, best_scores, decision_methods - - # ------------------------------ - # Large dataset helper methods - # ------------------------------ - def _infer_embedding_dim(self, sentences: list[str]) -> int: - """Return the embedding dimension for the current model.""" - try: - return int(self.model.get_sentence_embedding_dimension()) - except Exception: - probe = self.model.encode([sentences[0]]) - return int(probe.shape[1]) - - def _prepare_memmap_dir(self, memmap_dir: Optional[str]) -> tuple[bool, str, str]: - """Ensure a directory exists for memmap files and return path components. - - Returns (created_tmpdir, directory_path, embeddings_path). 
- """ - created_tmpdir = False - if memmap_dir is None: - memmap_dir = tempfile.mkdtemp(prefix="embedding_eval_memmap_") - created_tmpdir = True - else: - os.makedirs(memmap_dir, exist_ok=True) - emb_path = os.path.join(memmap_dir, "embeddings.dat") - return created_tmpdir, memmap_dir, emb_path - - def _write_embeddings_memmap( + def calculate_best_matches( self, sentences: list[str], - emb_path: str, - num_sentences: int, - embedding_dim: int, - batch_size: int, - dtype: np.dtype, - ) -> None: - """Encode sentences in batches, normalize, and write to a memmap file.""" - embeddings_mm = np.memmap(emb_path, mode="w+", dtype=dtype, shape=(num_sentences, embedding_dim)) - print(f"Encoding and writing {num_sentences} embeddings to memmap at {emb_path} ...") - for start in tqdm(range(0, num_sentences, batch_size), desc="Encoding (memmap)..."): - end = min(start + batch_size, num_sentences) - batch = sentences[start:end] - batch_embs = self.model.encode(batch, normalize_embeddings=True).astype(dtype, copy=False) - embeddings_mm[start:end] = batch_embs - embeddings_mm.flush() - del embeddings_mm - - def _choose_block_sizes(self, batch_size: int) -> tuple[int, int]: - """Pick conservative row/col block sizes to bound peak memory.""" - max_block_bytes = 128 * 1024 * 1024 # ~128MB per similarity block - row_block = min(batch_size, 4096) - col_block = max(512, min(batch_size, int(max_block_bytes / 4 / max(1, row_block)))) - return row_block, col_block - - def _compute_blockwise_best_matches( - self, - emb_path: str, - num_sentences: int, - embedding_dim: int, - row_block: int, - col_block: int, - dtype: np.dtype, - early_stop: int = 0, - ) -> tuple[np.ndarray, np.ndarray]: - """Blockwise exact nearest-neighbour by cosine-similarity via dot-products.""" - n = early_stop if early_stop > 0 else num_sentences - best_scores = np.full(n, -np.inf, dtype=np.float32) - best_indices = np.zeros(n, dtype=np.int32) - - embeddings_mm = np.memmap(emb_path, mode="r", dtype=dtype, shape=(n, embedding_dim)) - for row_start in tqdm(range(0, n, row_block), desc="Row blocks"): - row_end = min(row_start + row_block, n) - row_emb = np.asarray(embeddings_mm[row_start:row_end]) - - chunk_best_scores = np.full(row_end - row_start, -np.inf, dtype=np.float32) - chunk_best_indices = np.zeros(row_end - row_start, dtype=np.int32) - - for col_start in range(0, n, col_block): - col_end = min(col_start + col_block, n) - col_emb = np.asarray(embeddings_mm[col_start:col_end]) - - sim = row_emb @ col_emb.T - - # mask diagonal for overlapping region to avoid self-match - overlap_start = max(row_start, col_start) - overlap_end = min(row_end, col_end) - if overlap_start < overlap_end: - i = np.arange(overlap_start, overlap_end) - sim[i - row_start, i - col_start] = -np.inf - - block_idx = np.argmax(sim, axis=1) - block_val = sim[np.arange(sim.shape[0]), block_idx].astype(np.float32, copy=False) - - better = block_val > chunk_best_scores - if np.any(better): - chunk_best_scores[better] = block_val[better] - chunk_best_indices[better] = col_start + block_idx[better] - - best_scores[row_start:row_end] = chunk_best_scores - best_indices[row_start:row_end] = chunk_best_indices - - del embeddings_mm - return best_indices, best_scores - - def _cleanup_memmap(self, created_tmpdir: bool, memmap_dir: str, emb_path: str) -> None: - """Best-effort cleanup of memmap file and temp directory if created here.""" - if not created_tmpdir: - return - try: - if os.path.exists(emb_path): - os.remove(emb_path) - os.rmdir(memmap_dir) - except Exception: - pass - 
- def _compute_blockwise_best_matches_two_sets( - self, - row_emb_path: str, - num_rows: int, - col_emb_path: str, - num_cols: int, - embedding_dim: int, - row_block: int, - col_block: int, - dtype: np.dtype, - *, - mask_self_similarity: bool = False, - sentence_offset: int = 0, + batch_size: int = 32, + large_dataset: bool = False, early_stop: int = 0, - ) -> tuple[np.ndarray, np.ndarray]: - """Blockwise nearest-neighbour where rows and columns come from two sets. - - If mask_self_similarity is True, rows are assumed to correspond to a - contiguous slice of the columns starting at `sentence_offset`, and the - diagonal entries for that alignment will be masked to -inf. - """ - n_rows = early_stop if early_stop > 0 else num_rows - best_scores = np.full(n_rows, -np.inf, dtype=np.float32) - best_indices = np.zeros(n_rows, dtype=np.int32) - - rows_mm = np.memmap(row_emb_path, mode="r", dtype=dtype, shape=(n_rows, embedding_dim)) - cols_mm = np.memmap(col_emb_path, mode="r", dtype=dtype, shape=(num_cols, embedding_dim)) - - for row_start in tqdm(range(0, n_rows, row_block), desc="Row blocks (two-sets)"): - row_end = min(row_start + row_block, n_rows) - row_emb = np.asarray(rows_mm[row_start:row_end]) - - chunk_best_scores = np.full(row_end - row_start, -np.inf, dtype=np.float32) - chunk_best_indices = np.zeros(row_end - row_start, dtype=np.int32) - - for col_start in range(0, num_cols, col_block): - col_end = min(col_start + col_block, num_cols) - col_emb = np.asarray(cols_mm[col_start:col_end]) - - sim = row_emb @ col_emb.T - - # Mask diagonal if needed to avoid self-similarity - if mask_self_similarity: - # Calculate the overlap between row indices and column indices - row_global_start = row_start + sentence_offset - row_global_end = row_end + sentence_offset - overlap_start = max(row_global_start, col_start) - overlap_end = min(row_global_end, col_end) - - if overlap_start < overlap_end: - # Map global indices to local block indices - row_local_indices = np.arange(overlap_start - row_global_start, overlap_end - row_global_start) - col_local_indices = np.arange(overlap_start - col_start, overlap_end - col_start) - sim[row_local_indices, col_local_indices] = -np.inf - - block_idx = np.argmax(sim, axis=1) - block_val = sim[np.arange(sim.shape[0]), block_idx].astype(np.float32, copy=False) - - for i in range(len(block_val)): - if block_val[i] > chunk_best_scores[i]: - chunk_best_scores[i] = block_val[i] - chunk_best_indices[i] = col_start + block_idx[i] - - best_scores[row_start:row_end] = chunk_best_scores - best_indices[row_start:row_end] = chunk_best_indices - - del rows_mm - del cols_mm - return best_indices, best_scores + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate best similarity match for each sentence against all others.""" + result = self._matcher.calculate_best_matches(sentences, batch_size, large_dataset, early_stop) + self.embeddings = self._matcher.embeddings + return result - def _calculate_best_matches_large_dataset( + def calculate_best_matches_from_embeddings( self, + embeddings: dict[str, list[float]], sentences: list[str], batch_size: int = 1024, - *, - memmap_dir: Optional[str] = None, - dtype: np.dtype = np.float32, - early_stop: int = 0, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Memory-efficient exact similarity search using a disk-backed memmap.""" - num_sentences = len(sentences) - if num_sentences == 0: - return ( - np.zeros(0, dtype=np.int32), - np.zeros(0, dtype=np.float32), - np.zeros(0, dtype=object), - ) + """Calculate best matches 
using pre-computed embeddings.""" + return self._matcher.calculate_best_matches_from_embeddings(embeddings, sentences, batch_size) - # Determine embedding dimension and memmap paths - embedding_dim = self._infer_embedding_dim(sentences) - created_tmpdir = False - if memmap_dir is None: - memmap_dir = tempfile.mkdtemp(prefix="embedding_eval_memmap_") - created_tmpdir = True - os.makedirs(memmap_dir, exist_ok=True) - emb_path = os.path.join(memmap_dir, "embeddings.dat") - - # Phase 1: write normalized embeddings to disk - self._write_embeddings_memmap( - sentences=sentences, - emb_path=emb_path, - num_sentences=num_sentences, - embedding_dim=embedding_dim, - batch_size=batch_size, - dtype=dtype, - ) - - # Phase 2: blockwise nearest neighbour search - print("Finding best matches with blockwise dot-products ...") - row_block, col_block = self._choose_block_sizes(batch_size) - best_indices, best_scores = self._compute_blockwise_best_matches( - emb_path=emb_path, - num_sentences=num_sentences, - embedding_dim=embedding_dim, - row_block=row_block, - col_block=col_block, - dtype=dtype, - early_stop=early_stop, - ) - - decision_methods = np.full(num_sentences, "neural", dtype=object) - self._cleanup_memmap(created_tmpdir, memmap_dir, emb_path) - return best_indices, best_scores, decision_methods - - def calculate_best_matches_with_cache_large_dataset( + def calculate_best_matches_with_cache( self, - queries: list[str], + sentences: list[str], cache: list[str], batch_size: int = 1024, - *, - memmap_dir: Optional[str] = None, - dtype: np.dtype = np.float32, - sentence_offset: int = 0, - early_stop: int = 0, + k: int = 1, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Large-dataset variant: find best cache match for each sentence using memmaps. - - Writes two memmaps (rows for sentences, cols for cache), normalised, and - performs blockwise dot-products. If `sentence_offset` is provided and the - cache corresponds to the same corpus, the self-similarity diagonal is masked. 
- """ - num_sentences = len(queries) - num_cache = len(cache) - if num_sentences == 0 or num_cache == 0: - return ( - np.zeros(num_sentences, dtype=np.int32), - np.zeros(num_sentences, dtype=np.float32), - np.zeros(num_sentences, dtype=object), - ) - - embedding_dim = self._infer_embedding_dim(queries) - - created_tmpdir = False - if memmap_dir is None: - memmap_dir = tempfile.mkdtemp(prefix="embedding_eval_memmap_") - created_tmpdir = True - os.makedirs(memmap_dir, exist_ok=True) - - row_emb_path = os.path.join(memmap_dir, "rows_embeddings.dat") - col_emb_path = os.path.join(memmap_dir, "cols_embeddings.dat") - - # Write sentence and cache embeddings - self._write_embeddings_memmap( - sentences=queries, - emb_path=row_emb_path, - num_sentences=num_sentences, - embedding_dim=embedding_dim, - batch_size=batch_size, - dtype=dtype, - ) - # For cache we might reuse the same model; normalisation happens inside - self._write_embeddings_memmap( - sentences=cache, - emb_path=col_emb_path, - num_sentences=num_cache, - embedding_dim=embedding_dim, - batch_size=batch_size, - dtype=dtype, - ) - - row_block, col_block = self._choose_block_sizes(batch_size) - best_indices, best_scores = self._compute_blockwise_best_matches_two_sets( - row_emb_path=row_emb_path, - num_rows=num_sentences, - col_emb_path=col_emb_path, - num_cols=num_cache, - embedding_dim=embedding_dim, - row_block=row_block, - col_block=col_block, - dtype=dtype, - mask_self_similarity=(queries is cache or queries == cache), - sentence_offset=sentence_offset, - early_stop=early_stop, - ) - - decision_methods = np.full(num_sentences, "neural", dtype=object) - # Cleanup - try: - if os.path.exists(row_emb_path): - os.remove(row_emb_path) - if os.path.exists(col_emb_path): - os.remove(col_emb_path) - if created_tmpdir: - os.rmdir(memmap_dir) - except Exception: - pass - - return best_indices, best_scores, decision_methods - - def calculate_best_matches_with_cache( - self, sentences: list[str], cache: list[str], batch_size: int = 1024 - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Calculate the best similarity match for each sentence against all other - sentences using a neural embedding model. - """ - cache_embeddings = self.embed_all_sentences(cache, batch_size) - sentence_embeddings = self.embed_all_sentences(sentences, batch_size) - - out = self.calculate_best_matches_from_embeddings_with_cache( - cache_embeddings=cache_embeddings, - sentence_embeddings=sentence_embeddings, - sentences=sentences, - cache=cache, - batch_size=batch_size, - sentence_offset=0, - ) - - best_indices, best_scores, decision_methods = out - - return best_indices, best_scores, decision_methods + """Calculate best matches for each sentence against cache entries.""" + return self._matcher.calculate_best_matches_with_cache(sentences, cache, batch_size, k) def calculate_best_matches_from_embeddings_with_cache( self, @@ -463,55 +94,47 @@ def calculate_best_matches_from_embeddings_with_cache( batch_size: int = 1024, sentence_offset: int = 0, mask_self_similarity: bool = False, + k: int = 1, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Calculate the best similarity match for each sentence against all other - sentences using a neural embedding model. 
- """ - cache_embeddings_matrix = np.asarray([cache_embeddings[s] for s in cache], dtype=np.float32) - sentence_embeddings_matrix = np.asarray([sentence_embeddings[s] for s in sentences], dtype=np.float32) - - norms = np.linalg.norm(sentence_embeddings_matrix, axis=1, keepdims=True) - norms[norms == 0] = 1e-9 - sentence_embeddings_matrix /= norms - - norms = np.linalg.norm(cache_embeddings_matrix, axis=1, keepdims=True) - norms[norms == 0] = 1e-9 - cache_embeddings_matrix /= norms - - best_indices = np.zeros(len(sentences), dtype=np.int32) - best_scores = np.zeros(len(sentences), dtype=np.float32) - decision_methods = np.full(len(sentences), "neural", dtype=object) - - for start in tqdm( - range(0, len(sentences), batch_size), - desc="Calculating best matches with cache...", - disable=len(sentences) // batch_size < 10, - ): - end = min(start + batch_size, len(sentences)) - sentence_embedding = sentence_embeddings_matrix[start:end] - - batch_sims = sentence_embedding @ cache_embeddings_matrix.T # (batch_size, cache_size) - row_indices = np.arange(end - start) # (batch_size) - col_indices = np.arange(start, end) + """Calculate best matches using pre-computed embeddings.""" + return self._matcher.calculate_best_matches_from_embeddings_with_cache( + cache_embeddings, sentence_embeddings, sentences, cache, + batch_size, sentence_offset, mask_self_similarity, k, + ) - if ( - sentence_offset - ): # if we are calculating the best matches for a subset of sentences, we need to ignore the self-similarity - col_indices += sentence_offset + def calculate_best_matches_with_cache_large_dataset( + self, + queries: list[str], + cache: list[str], + batch_size: int = 1024, + *, + memmap_dir: Optional[str] = None, + dtype: np.dtype = np.float32, + sentence_offset: int = 0, + early_stop: int = 0, + k: int = 1, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Large-dataset matching using memory-mapped files.""" + return self._matcher.calculate_best_matches_with_cache_large_dataset( + queries, cache, batch_size, + memmap_dir=memmap_dir, dtype=dtype, + sentence_offset=sentence_offset, early_stop=early_stop, k=k, + ) - if mask_self_similarity: - batch_sims[row_indices, col_indices] = -np.inf - best_indices_batch = np.argmax( - batch_sims, axis=1 - ) # we want to find the best match for each sentence in the batch (batch_size) - best_scores_batch = batch_sims[ - row_indices, best_indices_batch - ] # we want to find the best score for each sentence in the batch (batch_size) +def get_embedding_model(model_name: str, device: str = "cpu") -> EmbeddingModel: + """ + Factory function to create an EmbeddingModel based on model name. - best_indices[start:end] = best_indices_batch - best_scores[start:end] = best_scores_batch - decision_methods[start:end] = "neural" + Args: + model_name: Model name with optional prefix: + - 'openai/...' for OpenAI models + - 'gemini/...' for Gemini models + - Otherwise, assumes HuggingFace SentenceTransformer + device: Device to use for local models ('cuda' or 'cpu'). - return best_indices, best_scores, decision_methods + Returns: + An EmbeddingModel instance. + """ + provider = get_embedding_provider(model_name, device) + return EmbeddingModel(provider) diff --git a/src/customer_analysis/embedding_providers.py b/src/customer_analysis/embedding_providers.py new file mode 100644 index 0000000..ffc058a --- /dev/null +++ b/src/customer_analysis/embedding_providers.py @@ -0,0 +1,337 @@ +""" +Embedding provider interface and implementations. 
+ +This module defines an abstract base class for embedding providers and +concrete implementations for HuggingFace (SentenceTransformer), OpenAI, and Gemini. +""" + +import os +import time +from abc import ABC, abstractmethod +from typing import Callable, Optional, TypeVar + +import numpy as np + +T = TypeVar("T") + + +def retry_with_backoff( + func: Callable[[], T], + max_retries: int = 3, + retry_on: Callable[[Exception], bool] = lambda e: False, +) -> T: + """ + Execute a function with exponential backoff retry logic. + + Args: + func: The function to execute. + max_retries: Maximum number of retry attempts. + retry_on: A function that returns True if the exception should trigger a retry. + + Returns: + The result of the function. + + Raises: + The last exception if all retries fail. + """ + for attempt in range(max_retries): + try: + return func() + except Exception as e: + if retry_on(e) and attempt < max_retries - 1: + wait_time = 2**attempt + time.sleep(wait_time) + else: + raise + + +def normalize_embeddings(embeddings: np.ndarray) -> np.ndarray: + """L2-normalize embeddings for cosine similarity.""" + if len(embeddings) == 0: + return embeddings + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + norms[norms == 0] = 1e-9 + return embeddings / norms + + +class EmbeddingProvider(ABC): + """ + Abstract base class for embedding providers. + + All embedding providers must implement the encode() and get_embedding_dim() methods. + The encode() method should return normalized embeddings for cosine similarity. + """ + + @abstractmethod + def encode( + self, + sentences: list[str], + batch_size: int = 32, + normalize: bool = True, + show_progress: bool = False, + ) -> np.ndarray: + """ + Encode sentences to embeddings. + + Args: + sentences: A list of sentences to encode. + batch_size: Batch size for encoding. + normalize: Whether to L2-normalize embeddings (for cosine similarity). + show_progress: Whether to show a progress bar. + + Returns: + np.ndarray: A (N, D) array of embeddings where N is the number of sentences + and D is the embedding dimension. + """ + pass + + @abstractmethod + def get_embedding_dim(self) -> int: + """ + Return the embedding dimension for this provider. + + Returns: + int: The embedding dimension. 
+ """ + pass + + @property + def provider_name(self) -> str: + """Return the name of this provider for decision_methods tracking.""" + return "unknown" + + +class HuggingFaceProvider(EmbeddingProvider): + """Embedding provider using HuggingFace SentenceTransformer models.""" + + def __init__(self, model_name: str, device: str = "cpu"): + from sentence_transformers import SentenceTransformer + + self.model_name = model_name + self.device = device + self.model = SentenceTransformer( + model_name, device=device, local_files_only=False, trust_remote_code=True + ) + self._embedding_dim: Optional[int] = None + + def encode( + self, + sentences: list[str], + batch_size: int = 32, + normalize: bool = True, + show_progress: bool = False, + ) -> np.ndarray: + if not sentences: + return np.array([], dtype=np.float32).reshape(0, self.get_embedding_dim()) + + embeddings = self.model.encode( + sentences, + batch_size=batch_size, + normalize_embeddings=normalize, + show_progress_bar=show_progress, + convert_to_numpy=True, + ) + return embeddings.astype(np.float32, copy=False) + + def get_embedding_dim(self) -> int: + if self._embedding_dim is None: + probe = self.model.encode(["test"], convert_to_numpy=True) + self._embedding_dim = int(probe.shape[1]) + return self._embedding_dim + + @property + def provider_name(self) -> str: + return "neural" + + +class APIProvider(EmbeddingProvider): + """Base class for API-based embedding providers with common functionality.""" + + MODEL_DIMENSIONS: dict[str, int] = {} + + def __init__(self, model_name: str): + self._embedding_dim: Optional[int] = None + self.model_name = self._extract_model_name(model_name) + + def _extract_model_name(self, model_name: str) -> str: + """Extract model name by removing provider prefix.""" + prefixes = ["openai/", "gemini/"] + for prefix in prefixes: + if model_name.startswith(prefix): + return model_name[len(prefix) :] + return model_name + + def _is_rate_limit_error(self, e: Exception) -> bool: + """Check if exception is a rate limit error.""" + error_str = str(e).lower() + return "rate" in error_str or "quota" in error_str or "rate_limit" in error_str + + def _encode_batch(self, batch: list[str]) -> list[list[float]]: + """Encode a single batch. 
Must be implemented by subclasses.""" + raise NotImplementedError + + def _encode_with_batching( + self, sentences: list[str], batch_size: int, show_progress: bool + ) -> list[list[float]]: + """Encode sentences with batching and retry logic.""" + from tqdm import tqdm + + all_embeddings: list[list[float]] = [] + iterator = range(0, len(sentences), batch_size) + if show_progress: + iterator = tqdm(iterator, desc=f"Embedding sentences ({self.provider_name})...") + + for i in iterator: + batch = sentences[i : i + batch_size] + batch_embeddings = retry_with_backoff( + func=lambda b=batch: self._encode_batch(b), + max_retries=3, + retry_on=self._is_rate_limit_error, + ) + all_embeddings.extend(batch_embeddings) + + return all_embeddings + + def encode( + self, + sentences: list[str], + batch_size: int = 100, + normalize: bool = True, + show_progress: bool = False, + ) -> np.ndarray: + if not sentences: + return np.array([], dtype=np.float32).reshape(0, self.get_embedding_dim()) + + all_embeddings = self._encode_with_batching(sentences, batch_size, show_progress) + embeddings = np.array(all_embeddings, dtype=np.float32) + + if normalize: + embeddings = normalize_embeddings(embeddings) + + return embeddings + + def get_embedding_dim(self) -> int: + if self._embedding_dim is None: + if self.model_name in self.MODEL_DIMENSIONS: + self._embedding_dim = self.MODEL_DIMENSIONS[self.model_name] + else: + probe = self.encode(["test"], normalize=False) + self._embedding_dim = probe.shape[1] + return self._embedding_dim + + +class OpenAIProvider(APIProvider): + """Embedding provider using OpenAI's API.""" + + MODEL_DIMENSIONS = { + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + } + + def __init__(self, model_name: str, device: str = "cpu"): + try: + from openai import OpenAI + except ImportError: + raise ImportError( + "openai package is required for OpenAI embeddings. " + "Install it with: pip install openai" + ) + + api_key = os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY environment variable is not set. " + "Please set it to use OpenAI embeddings." + ) + + super().__init__(model_name) + self.client = OpenAI(api_key=api_key) + + def _encode_batch(self, batch: list[str]) -> list[list[float]]: + response = self.client.embeddings.create(input=batch, model=self.model_name) + return [item.embedding for item in response.data] + + @property + def provider_name(self) -> str: + return "openai" + + +class GeminiProvider(APIProvider): + """Embedding provider using Google's Gemini API.""" + + MODEL_DIMENSIONS = { + "text-embedding-004": 768, + "text-embedding-005": 768, + "embedding-001": 768, + } + MAX_BATCH_SIZE = 100 # Gemini API limit + + def __init__(self, model_name: str, device: str = "cpu"): + try: + from google import genai + except ImportError: + raise ImportError( + "google-genai package is required for Gemini embeddings. " + "Install it with: pip install google-genai" + ) + + api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY") + if not api_key: + raise ValueError( + "GEMINI_API_KEY or GOOGLE_API_KEY environment variable is not set. " + "Please set it to use Gemini embeddings." 
+ ) + + super().__init__(model_name) + self.client = genai.Client(api_key=api_key) + self.model_path = self._get_model_path() + + def _get_model_path(self) -> str: + if self.model_name.startswith("models/"): + return self.model_name + return f"models/{self.model_name}" + + def _encode_batch(self, batch: list[str]) -> list[list[float]]: + response = self.client.models.embed_content( + model=self.model_path, + contents=batch, + ) + return [embedding.values for embedding in response.embeddings] + + def encode( + self, + sentences: list[str], + batch_size: int = 100, + normalize: bool = True, + show_progress: bool = False, + ) -> np.ndarray: + # Enforce Gemini's batch size limit + batch_size = min(batch_size, self.MAX_BATCH_SIZE) + return super().encode(sentences, batch_size, normalize, show_progress) + + @property + def provider_name(self) -> str: + return "gemini" + + +def get_embedding_provider(model_name: str, device: str = "cpu") -> EmbeddingProvider: + """ + Factory function to get the appropriate embedding provider based on model name. + + Args: + model_name: Model name with optional prefix: + - 'openai/text-embedding-3-small' for OpenAI models + - 'gemini/text-embedding-004' for Gemini models + - Otherwise, assumes HuggingFace SentenceTransformer model + device: Device to use for local models ('cuda' or 'cpu'). + + Returns: + An EmbeddingProvider instance. + """ + if model_name.startswith("openai/"): + return OpenAIProvider(model_name, device=device) + elif model_name.startswith("gemini/"): + return GeminiProvider(model_name, device=device) + else: + return HuggingFaceProvider(model_name, device=device) diff --git a/src/customer_analysis/metrics_util.py b/src/customer_analysis/metrics_util.py index 64ea008..c603007 100644 --- a/src/customer_analysis/metrics_util.py +++ b/src/customer_analysis/metrics_util.py @@ -64,8 +64,9 @@ def sweep_thresholds_on_results(results_df: pd.DataFrame) -> pd.DataFrame: """Perform threshold sweep and return results.""" print("\nPerforming threshold sweep") min_score = results_df["similarity_score"].min() - steps = max(min(200, len(results_df)), 1) # At least 1 step, cannot be 0, and max length of 200 - thresholds = np.linspace(min_score, 1.0, steps) + max_score = results_df["similarity_score"].max() + steps = 200 + thresholds = np.linspace(min_score, 1.0 if max_score < 1.0 else max_score, steps) results = [] for i, threshold in enumerate(thresholds): diff --git a/src/customer_analysis/query_engine.py b/src/customer_analysis/query_engine.py index c44b2f6..1d4123e 100644 --- a/src/customer_analysis/query_engine.py +++ b/src/customer_analysis/query_engine.py @@ -2,14 +2,13 @@ import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, Dict, List, Union import numpy as np import pandas as pd from redisvl.index import SearchIndex from redisvl.query import VectorQuery from redisvl.schema import IndexSchema -from sentence_transformers import SentenceTransformer try: import torch @@ -19,10 +18,16 @@ _HAS_TORCH = False +def _is_api_model(model_name: str) -> bool: + """Check if model name refers to an API-based model (OpenAI or Gemini).""" + return model_name.startswith("openai/") or model_name.startswith("gemini/") + + @dataclass class RedisVectorIndex: """ - RedisVL vector index backed by a local SentenceTransformer model. + RedisVL vector index backed by either a local SentenceTransformer model, + OpenAI embeddings API, or Gemini embeddings API. 
Parameters ---------- @@ -33,11 +38,14 @@ class RedisVectorIndex: prefix : str Document key prefix (e.g. "cache:"). model_name : str - SentenceTransformer model name or local path (e.g. "all-MiniLM-L6-v2"). + SentenceTransformer model name, local path (e.g. "all-MiniLM-L6-v2"), + OpenAI model with 'openai/' prefix (e.g. "openai/text-embedding-3-small"), + or Gemini model with 'gemini/' prefix (e.g. "gemini/text-embedding-004"). redis_url : str Redis connection URL (default "redis://localhost:6379"). device : str "cuda" or "cpu". If None, auto-selects CUDA when available. + Ignored for API-based models. batch_size : int Batch size for encoder.encode(). additional_fields : list[dict] @@ -54,12 +62,18 @@ class RedisVectorIndex: additional_fields: List[Dict[str, Any]] = field(default_factory=list) def __post_init__(self): - # 0) init local embedding model + self._is_api = _is_api_model(self.model_name) + + # Use the new embedding provider architecture + from src.customer_analysis.embedding_providers import get_embedding_provider + device = self.device or ("cuda" if _HAS_TORCH and torch.cuda.is_available() else "cpu") - self.model = SentenceTransformer(self.model_name, device=device) - self.embed_dim = int(self.model.get_sentence_embedding_dimension()) + self._provider = get_embedding_provider(self.model_name, device=device) + + # Get embedding dimension + self.embed_dim = self._provider.get_embedding_dim() - # 1) ensure Redis index exists (schema dims come from the model) + # Ensure Redis index exists (schema dims come from the model) schema_dict = { "index": {"name": self.index_name, "prefix": self.prefix}, "fields": [ @@ -79,16 +93,16 @@ def __post_init__(self): } schema = IndexSchema.from_dict(schema_dict) self.index: SearchIndex = SearchIndex(schema, redis_url=self.redis_url) - if not self.index.exists(): - self.index.create(overwrite=False) + # Always overwrite to ensure schema matches the current model dimensions + self.index.create(overwrite=True) def _embed_batch(self, texts: List[str]) -> np.ndarray: - vecs = self.model.encode( + """Embed a batch of texts using the configured provider.""" + vecs = self._provider.encode( texts, batch_size=self.batch_size, - convert_to_numpy=True, - normalize_embeddings=False, # leave unnormalized; Redis uses true cosine - show_progress_bar=False, + normalize=False, # leave unnormalized; Redis uses true cosine + show_progress=False, ) # ensure float32 if vecs.dtype != np.float32: @@ -134,9 +148,12 @@ def drop(self): except Exception: pass # best-effort free model memory on CUDA - try: - if _HAS_TORCH and self.model.device.type == "cuda": - del self.model - torch.cuda.empty_cache() - except Exception: - pass + if not self._is_api: + try: + if _HAS_TORCH and hasattr(self._provider, 'model'): + model = self._provider.model + if hasattr(model, 'device') and model.device.type == "cuda": + del self._provider + torch.cuda.empty_cache() + except Exception: + pass diff --git a/src/customer_analysis/similarity_matcher.py b/src/customer_analysis/similarity_matcher.py new file mode 100644 index 0000000..3a4d909 --- /dev/null +++ b/src/customer_analysis/similarity_matcher.py @@ -0,0 +1,477 @@ +""" +Similarity matching logic for semantic cache evaluation. + +This module contains the SimilarityMatcher class which performs similarity-based +matching between queries and cache entries using embedding providers. 
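+
+A minimal usage sketch (in-memory path; the model name is illustrative):
+
+    from src.customer_analysis.embedding_providers import get_embedding_provider
+
+    provider = get_embedding_provider("all-MiniLM-L6-v2", device="cpu")
+    matcher = SimilarityMatcher(provider)
+    best_idx, best_scores, methods = matcher.calculate_best_matches_with_cache(
+        queries, cache, batch_size=256, k=1
+    )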
+""" + +import os +import tempfile +from typing import Optional + +import numpy as np +from tqdm import tqdm + +from src.customer_analysis.embedding_providers import EmbeddingProvider, normalize_embeddings + + +# ------------------------------ +# Top-K Helper Functions +# ------------------------------ + + +def find_top1(sim: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Find the best match (top-1) for each row in similarity matrix.""" + best_idx = np.argmax(sim, axis=1) + best_val = sim[np.arange(sim.shape[0]), best_idx] + return best_idx, best_val.astype(np.float32, copy=False) + + +def find_topk(sim: np.ndarray, k: int) -> tuple[np.ndarray, np.ndarray]: + """Find the top-k matches for each row in similarity matrix.""" + if sim.shape[1] <= k: + # Fewer candidates than k - return all sorted + top_k_idx = np.argsort(-sim, axis=1) + top_k_val = np.take_along_axis(sim, top_k_idx, axis=1) + else: + # Use partial sort for efficiency + part_idx = np.argpartition(-sim, k, axis=1)[:, :k] + top_k_val = np.take_along_axis(sim, part_idx, axis=1) + sorted_sub_idx = np.argsort(-top_k_val, axis=1) + top_k_val = np.take_along_axis(top_k_val, sorted_sub_idx, axis=1) + top_k_idx = np.take_along_axis(part_idx, sorted_sub_idx, axis=1) + return top_k_idx, top_k_val + + +def merge_topk( + current_idx: np.ndarray, + current_val: np.ndarray, + new_idx: np.ndarray, + new_val: np.ndarray, + k: int, +) -> tuple[np.ndarray, np.ndarray]: + """Merge two sets of top-k results and keep the best k.""" + combined_vals = np.concatenate([current_val, new_val], axis=1) + combined_idxs = np.concatenate([current_idx, new_idx], axis=1) + best_args = np.argsort(-combined_vals, axis=1)[:, :k] + return ( + np.take_along_axis(combined_idxs, best_args, axis=1), + np.take_along_axis(combined_vals, best_args, axis=1), + ) + + +def update_best_k1( + chunk_idx: np.ndarray, + chunk_val: np.ndarray, + block_idx: np.ndarray, + block_val: np.ndarray, + col_offset: int, +) -> None: + """Update best scores in-place for k=1 case.""" + better = block_val > chunk_val + chunk_val[better] = block_val[better] + chunk_idx[better] = col_offset + block_idx[better] + + +# ------------------------------ +# Self-Similarity Masking +# ------------------------------ + + +def mask_self_similarity_block( + sim: np.ndarray, + row_start: int, + row_end: int, + col_start: int, + col_end: int, + sentence_offset: int, +) -> None: + """Mask diagonal entries in similarity matrix to avoid self-matching.""" + row_global_start = row_start + sentence_offset + row_global_end = row_end + sentence_offset + overlap_start = max(row_global_start, col_start) + overlap_end = min(row_global_end, col_end) + + if overlap_start < overlap_end: + row_local = np.arange(overlap_start - row_global_start, overlap_end - row_global_start) + col_local = np.arange(overlap_start - col_start, overlap_end - col_start) + sim[row_local, col_local] = -np.inf + + +# ------------------------------ +# Result Initialization +# ------------------------------ + + +def init_results(n: int, k: int) -> tuple[np.ndarray, np.ndarray]: + """Initialize result arrays for indices and scores.""" + if k == 1: + return np.zeros(n, dtype=np.int32), np.full(n, -np.inf, dtype=np.float32) + return np.zeros((n, k), dtype=np.int32), np.full((n, k), -np.inf, dtype=np.float32) + + +# ------------------------------ +# SimilarityMatcher Class +# ------------------------------ + + +class SimilarityMatcher: + """ + Similarity matcher for finding best matches between queries and cache entries. 
+ + This class handles all the matching logic, including: + - Embedding all sentences + - Computing cosine similarity + - Finding top-k matches + - Memory-efficient large dataset handling with memmaps + """ + + def __init__(self, provider: EmbeddingProvider): + """Initialize the SimilarityMatcher with an embedding provider.""" + self.provider = provider + self.embeddings: Optional[dict[str, list[float]]] = None + + def embed_all_sentences( + self, sentences: list[str], batch_size: int = 32 + ) -> dict[str, list[float]]: + """Embed all unique sentences and return a dictionary mapping sentences to embeddings.""" + sentence_to_embeddings: dict[str, list[float]] = {} + sentence_list = list(set(sentences)) + total = len(sentence_list) + + print(f"Embedding {total} unique sentences in batches of {batch_size} ...") + + for start in tqdm(range(0, total, batch_size), desc="Embedding sentences..."): + end = min(start + batch_size, total) + batch = sentence_list[start:end] + batch_embs = self.provider.encode(batch, batch_size=batch_size, normalize=False) + for sent, emb in zip(batch, batch_embs): + sentence_to_embeddings[sent] = emb.tolist() if hasattr(emb, "tolist") else list(emb) + + return sentence_to_embeddings + + def calculate_best_matches( + self, + sentences: list[str], + batch_size: int = 32, + large_dataset: bool = False, + early_stop: int = 0, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate the best similarity match for each sentence against all other sentences.""" + if not large_dataset: + self.embeddings = self.embed_all_sentences(sentences, batch_size) + return self.calculate_best_matches_from_embeddings(self.embeddings, sentences, batch_size) + return self._calculate_best_matches_large_dataset(sentences, batch_size, early_stop=early_stop) + + def calculate_best_matches_from_embeddings( + self, + embeddings: dict[str, list[float]], + sentences: list[str], + batch_size: int = 1024, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate best similarity matches using pre-computed embeddings (self-matching).""" + best_indices, best_scores = init_results(len(sentences), k=1) + decision_methods = np.full(len(sentences), self.provider.provider_name, dtype=object) + + for batch_start in tqdm(range(0, len(sentences), batch_size), desc="Calculating best matches..."): + batch_end = min(batch_start + batch_size, len(sentences)) + out = self.calculate_best_matches_from_embeddings_with_cache( + cache_embeddings=embeddings, + sentence_embeddings=embeddings, + sentences=sentences[batch_start:batch_end], + cache=sentences, + batch_size=batch_size, + sentence_offset=batch_start, + mask_self_similarity=True, + ) + best_indices[batch_start:batch_end] = out[0] + best_scores[batch_start:batch_end] = out[1] + + return best_indices, best_scores, decision_methods + + def calculate_best_matches_with_cache( + self, + sentences: list[str], + cache: list[str], + batch_size: int = 1024, + k: int = 1, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate the best similarity match for each sentence against cache entries.""" + cache_embeddings = self.embed_all_sentences(cache, batch_size) + sentence_embeddings = self.embed_all_sentences(sentences, batch_size) + return self.calculate_best_matches_from_embeddings_with_cache( + cache_embeddings=cache_embeddings, + sentence_embeddings=sentence_embeddings, + sentences=sentences, + cache=cache, + batch_size=batch_size, + sentence_offset=0, + k=k, + ) + + def calculate_best_matches_from_embeddings_with_cache( + self, + cache_embeddings: 
dict[str, list[float]], + sentence_embeddings: dict[str, list[float]], + sentences: list[str], + cache: list[str], + batch_size: int = 1024, + sentence_offset: int = 0, + mask_self_similarity: bool = False, + k: int = 1, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Calculate the best similarity match using pre-computed embeddings.""" + # Build and normalize embedding matrices + cache_matrix = normalize_embeddings( + np.asarray([cache_embeddings[s] for s in cache], dtype=np.float32) + ) + sentence_matrix = normalize_embeddings( + np.asarray([sentence_embeddings[s] for s in sentences], dtype=np.float32) + ) + + best_indices, best_scores = init_results(len(sentences), k) + decision_methods = np.full(len(sentences), self.provider.provider_name, dtype=object) + + for start in tqdm( + range(0, len(sentences), batch_size), + desc="Calculating best matches with cache...", + disable=len(sentences) // batch_size < 10, + ): + end = min(start + batch_size, len(sentences)) + batch_sims = sentence_matrix[start:end] @ cache_matrix.T + + if mask_self_similarity: + self._mask_batch_self_similarity(batch_sims, start, end, sentence_offset, len(cache)) + + if k == 1: + idx, val = find_top1(batch_sims) + else: + idx, val = find_topk(batch_sims, k) + + best_indices[start:end] = idx + best_scores[start:end] = val + + return best_indices, best_scores, decision_methods + + def _mask_batch_self_similarity( + self, batch_sims: np.ndarray, start: int, end: int, offset: int, cache_len: int + ) -> None: + """Mask self-similarity in a batch similarity matrix.""" + row_indices = np.arange(end - start) + col_indices = np.arange(start, end) + offset + valid = col_indices < cache_len + if np.any(valid): + batch_sims[row_indices[valid], col_indices[valid]] = -np.inf + + # ------------------------------ + # Large dataset methods + # ------------------------------ + + def _infer_embedding_dim(self, sentences: list[str]) -> int: + return self.provider.get_embedding_dim() + + def _prepare_memmap_dir(self, memmap_dir: Optional[str]) -> tuple[bool, str, str]: + """Ensure a directory exists for memmap files.""" + created = memmap_dir is None + if created: + memmap_dir = tempfile.mkdtemp(prefix="embedding_eval_memmap_") + else: + os.makedirs(memmap_dir, exist_ok=True) + return created, memmap_dir, os.path.join(memmap_dir, "embeddings.dat") + + def _write_embeddings_memmap( + self, + sentences: list[str], + emb_path: str, + num_sentences: int, + embedding_dim: int, + batch_size: int, + dtype: np.dtype, + ) -> None: + """Encode sentences and write normalized embeddings to memmap.""" + mm = np.memmap(emb_path, mode="w+", dtype=dtype, shape=(num_sentences, embedding_dim)) + print(f"Encoding and writing {num_sentences} embeddings to memmap at {emb_path} ...") + + for start in tqdm(range(0, num_sentences, batch_size), desc="Encoding (memmap)..."): + end = min(start + batch_size, num_sentences) + batch_embs = self.provider.encode(sentences[start:end], batch_size=batch_size, normalize=True) + mm[start:end] = batch_embs.astype(dtype, copy=False) + + mm.flush() + del mm + + def _choose_block_sizes(self, batch_size: int) -> tuple[int, int]: + """Pick conservative row/col block sizes to bound peak memory.""" + max_block_bytes = 128 * 1024 * 1024 + row_block = min(batch_size, 4096) + col_block = max(512, min(batch_size, int(max_block_bytes / 4 / max(1, row_block)))) + return row_block, col_block + + def _compute_blockwise_best_matches( + self, + emb_path: str, + num_sentences: int, + embedding_dim: int, + row_block: int, + col_block: 
int, + dtype: np.dtype, + early_stop: int = 0, + ) -> tuple[np.ndarray, np.ndarray]: + """Blockwise exact nearest-neighbour (k=1) with self-similarity masking.""" + n = early_stop if early_stop > 0 else num_sentences + best_indices, best_scores = init_results(n, k=1) + mm = np.memmap(emb_path, mode="r", dtype=dtype, shape=(n, embedding_dim)) + + for row_start in tqdm(range(0, n, row_block), desc="Row blocks"): + row_end = min(row_start + row_block, n) + row_emb = np.asarray(mm[row_start:row_end]) + chunk_idx, chunk_val = init_results(row_end - row_start, k=1) + + for col_start in range(0, n, col_block): + col_end = min(col_start + col_block, n) + sim = row_emb @ np.asarray(mm[col_start:col_end]).T + mask_self_similarity_block(sim, row_start, row_end, col_start, col_end, 0) + block_idx, block_val = find_top1(sim) + update_best_k1(chunk_idx, chunk_val, block_idx, block_val, col_start) + + best_indices[row_start:row_end] = chunk_idx + best_scores[row_start:row_end] = chunk_val + + del mm + return best_indices, best_scores + + def _cleanup_memmap(self, created: bool, memmap_dir: str, emb_path: str) -> None: + """Best-effort cleanup of memmap files.""" + if not created: + return + try: + if os.path.exists(emb_path): + os.remove(emb_path) + os.rmdir(memmap_dir) + except Exception: + pass + + def _compute_blockwise_best_matches_two_sets( + self, + row_emb_path: str, + num_rows: int, + col_emb_path: str, + num_cols: int, + embedding_dim: int, + row_block: int, + col_block: int, + dtype: np.dtype, + *, + mask_self_similarity: bool = False, + sentence_offset: int = 0, + early_stop: int = 0, + k: int = 1, + ) -> tuple[np.ndarray, np.ndarray]: + """Blockwise nearest-neighbour where rows and columns come from two sets.""" + n_rows = early_stop if early_stop > 0 else num_rows + best_indices, best_scores = init_results(n_rows, k) + + rows_mm = np.memmap(row_emb_path, mode="r", dtype=dtype, shape=(n_rows, embedding_dim)) + cols_mm = np.memmap(col_emb_path, mode="r", dtype=dtype, shape=(num_cols, embedding_dim)) + + for row_start in tqdm(range(0, n_rows, row_block), desc="Row blocks (two-sets)"): + row_end = min(row_start + row_block, n_rows) + row_emb = np.asarray(rows_mm[row_start:row_end]) + chunk_idx, chunk_val = init_results(row_end - row_start, k) + + for col_start in range(0, num_cols, col_block): + col_end = min(col_start + col_block, num_cols) + sim = row_emb @ np.asarray(cols_mm[col_start:col_end]).T + + if mask_self_similarity: + mask_self_similarity_block(sim, row_start, row_end, col_start, col_end, sentence_offset) + + if k == 1: + block_idx, block_val = find_top1(sim) + update_best_k1(chunk_idx, chunk_val, block_idx, block_val, col_start) + else: + block_idx, block_val = find_topk(sim, k) + chunk_idx, chunk_val = merge_topk(chunk_idx, chunk_val, block_idx + col_start, block_val, k) + + best_indices[row_start:row_end] = chunk_idx + best_scores[row_start:row_end] = chunk_val + + del rows_mm, cols_mm + return best_indices, best_scores + + def _calculate_best_matches_large_dataset( + self, + sentences: list[str], + batch_size: int = 1024, + *, + memmap_dir: Optional[str] = None, + dtype: np.dtype = np.float32, + early_stop: int = 0, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Memory-efficient exact similarity search using disk-backed memmap.""" + if len(sentences) == 0: + return np.zeros(0, dtype=np.int32), np.zeros(0, dtype=np.float32), np.zeros(0, dtype=object) + + embedding_dim = self._infer_embedding_dim(sentences) + created, memmap_dir, emb_path = 
self._prepare_memmap_dir(memmap_dir) + + self._write_embeddings_memmap(sentences, emb_path, len(sentences), embedding_dim, batch_size, dtype) + + print("Finding best matches with blockwise dot-products ...") + row_block, col_block = self._choose_block_sizes(batch_size) + best_indices, best_scores = self._compute_blockwise_best_matches( + emb_path, len(sentences), embedding_dim, row_block, col_block, dtype, early_stop + ) + + self._cleanup_memmap(created, memmap_dir, emb_path) + return best_indices, best_scores, np.full(len(sentences), self.provider.provider_name, dtype=object) + + def calculate_best_matches_with_cache_large_dataset( + self, + queries: list[str], + cache: list[str], + batch_size: int = 1024, + *, + memmap_dir: Optional[str] = None, + dtype: np.dtype = np.float32, + sentence_offset: int = 0, + early_stop: int = 0, + k: int = 1, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Large-dataset variant: find best cache match for each query using memmaps.""" + num_queries, num_cache = len(queries), len(cache) + if num_queries == 0 or num_cache == 0: + idx, scores = init_results(num_queries, k) + return idx, scores, np.zeros(num_queries, dtype=object) + + embedding_dim = self._infer_embedding_dim(queries) + created, memmap_dir, _ = self._prepare_memmap_dir(memmap_dir) + row_path = os.path.join(memmap_dir, "rows_embeddings.dat") + col_path = os.path.join(memmap_dir, "cols_embeddings.dat") + + self._write_embeddings_memmap(queries, row_path, num_queries, embedding_dim, batch_size, dtype) + self._write_embeddings_memmap(cache, col_path, num_cache, embedding_dim, batch_size, dtype) + + row_block, col_block = self._choose_block_sizes(batch_size) + best_indices, best_scores = self._compute_blockwise_best_matches_two_sets( + row_path, num_queries, col_path, num_cache, embedding_dim, + row_block, col_block, dtype, + mask_self_similarity=(queries is cache or queries == cache), + sentence_offset=sentence_offset, + early_stop=early_stop, + k=k, + ) + + # Cleanup + for path in [row_path, col_path]: + try: + os.remove(path) + except Exception: + pass + if created: + try: + os.rmdir(memmap_dir) + except Exception: + pass + + return best_indices, best_scores, np.full(num_queries, self.provider.provider_name, dtype=object)
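
Taken together, the provider layer supplies embeddings and SimilarityMatcher owns the
matching logic. A rough end-to-end sketch of the disk-backed path for large datasets
(illustrative only: the model name and input lists are placeholders, and memmap files
are written to a temporary directory by default):

    from src.customer_analysis.embedding_providers import get_embedding_provider
    from src.customer_analysis.similarity_matcher import SimilarityMatcher

    cache = ["how do I track my order", "update my billing address"]
    queries = ["where is my package?", "change my payment method"]

    provider = get_embedding_provider("all-MiniLM-L6-v2", device="cpu")
    matcher = SimilarityMatcher(provider)

    # Embeddings are streamed to memory-mapped files and compared blockwise,
    # so peak memory stays bounded even for very large query/cache sets.
    idx, scores, methods = matcher.calculate_best_matches_with_cache_large_dataset(
        queries, cache, batch_size=1024, k=1
    )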