diff --git a/python-recipes/recommendation-systems/01_collaborative_filtering.ipynb b/python-recipes/recommendation-systems/01_collaborative_filtering.ipynb index bdc31da4..c249a039 100644 --- a/python-recipes/recommendation-systems/01_collaborative_filtering.ipynb +++ b/python-recipes/recommendation-systems/01_collaborative_filtering.ipynb @@ -6,7 +6,7 @@ "source": [ "![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)\n", "\n", - "# Collaborative Filtering in RedisVL\n", + "# Recommendation Systems: Collaborative Filtering in RedisVL\n", "\n", "\"Open" ] diff --git a/python-recipes/recommendation-systems/02_two_towers.ipynb b/python-recipes/recommendation-systems/02_two_towers.ipynb new file mode 100644 index 00000000..953f8a40 --- /dev/null +++ b/python-recipes/recommendation-systems/02_two_towers.ipynb @@ -0,0 +1,5114 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Redis](https://redis.io/wp-content/uploads/2024/04/Logotype.svg?auto=webp&quality=85,75&width=120)\n", + "\n", + "# Recommendation Systems: Two Tower Deep Learning Models with RedisVL\n", + "\n", + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Recommendation systems are a common application of machine learning and serve many industries from e-commerce to music streaming platforms.\n", + "\n", + "There are many different architectures that can be followed to build a recommendation system. In previous example notebooks we demonstrated two common approaches that leverage different methods. Our first showed how to do [content filtering with RedisVL](content_filtering.ipynb) where an item's underlying features determine what gets recommended.\n", + "Next, we showcased how RedisVL can be used to build a [collaborative filtering recommender](collaborative_filtering.ipynb), which leverages users' ratings of items to create personalized recommendations. Before continuing with this notebook we encourage you to start with the previous two.\n", + "\n", + "In this notebook we'll demonstrate how to build a [two tower recommendation system](https://cloud.google.com/blog/products/ai-machine-learning/scaling-deep-retrieval-tensorflow-two-towers-architecture)\n", + "and compare it to architectures we've seen before.\n", + "\n", + "Let's begin!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To mix things up a bit, instead of using our movies dataset like the previous two examples, we'll look at brick & mortar restaurants in San Francisco as our items to recommend." 
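Before we get to the data, here is the core idea of the two tower architecture in miniature: two independent encoders (the "towers") map raw user features and raw item features into the same embedding space, and a similarity score between the two embeddings is what ranks items for a user. This is only a conceptual sketch; the input and hidden sizes below are placeholders, not the ones we'll build later in this notebook.

```python
import torch
import torch.nn as nn

# Two independent "towers": one encodes user features, the other item features.
# The input and hidden sizes here are placeholders purely for illustration.
user_tower = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64))
item_tower = nn.Sequential(nn.Linear(30, 64), nn.ReLU(), nn.Linear(64, 64))

user_vec = nn.functional.normalize(user_tower(torch.randn(1, 32)), dim=1)
item_vec = nn.functional.normalize(item_tower(torch.randn(1, 30)), dim=1)

# Both vectors are L2-normalized, so their dot product is a cosine similarity score.
score = (user_vec @ item_vec.T).item()
print(f"similarity score: {score:.3f}")
```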
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import requests\n", + "import pandas as pd\n", + "import json\n", + "\n", + "# Replace values below with your own if using Redis Cloud instance\n", + "REDIS_HOST = os.getenv(\"REDIS_HOST\", \"localhost\") # ex: \"redis-18374.c253.us-central1-1.gce.cloud.redislabs.com\"\n", + "REDIS_PORT = os.getenv(\"REDIS_PORT\", \"6379\") # ex: 18374\n", + "REDIS_PASSWORD = os.getenv(\"REDIS_PASSWORD\", \"\") # ex: \"1TNxTEdYRDgIDKM2gDfasupCADXXXX\"\n", + "\n", + "# If SSL is enabled on the endpoint, use rediss:// as the URL prefix\n", + "REDIS_URL = f\"redis://:{REDIS_PASSWORD}@{REDIS_HOST}:{REDIS_PORT}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "def fetch_data(file_name):\n", + " dataset_path = 'datasets/two_towers/'\n", + " try:\n", + " with open(dataset_path + file_name, 'r') as f:\n", + " return json.load(f)\n", + " except:\n", + " url = 'https://redis-ai-resources.s3.us-east-2.amazonaws.com/recommenders/datasets/two-towers/'\n", + " r = requests.get(url + file_name)\n", + " if not os.path.exists(dataset_path):\n", + " os.makedirs(dataset_path)\n", + " with open(dataset_path + file_name, 'wb') as f:\n", + " f.write(r.content)\n", + " return json.loads(r.content.decode('utf-8'))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "we have 147 restaurants in our dataset, with 14700 total reviews\n" + ] + } + ], + "source": [ + "# the original dataset can be found here: https://www.kaggle.com/datasets/jkgatt/restaurant-data-with-100-trip-advisor-reviews-each\n", + "\n", + "restaurant_data = fetch_data('factual_tripadvisor_restaurant_data_all_100_reviews.json')\n", + "\n", + "print(f\"we have {restaurant_data['restaurant_count']} restaurants in our dataset, with {restaurant_data['total_review_count']} total reviews\")\n", + "\n", + "restaurant_data = restaurant_data[\"restaurants\"] # ignore the count fields" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nameaddresslocalitylatitudelongitudecuisinepriceratinghoursparking...meal_takeoutmeal_cateroptions_healthyoptions_organicoptions_vegetarianoptions_veganoptions_glutenfreeoptions_lowfatreviewsunique_name
021st Amendment Brewery & Restaurant563 2nd StSan Francisco37.782448-122.392576[Cafe, Pub Food, American, Burgers, Pizza]24.0{'monday': [['11:30', '23:59']], 'tuesday': [[...True...TrueFalseTrueFalseTrueFalseFalseFalse[{'review_website': 'TripAdvisor', 'review_url...21st Amendment Brewery & Restaurant 563 2nd St
1Absinthe Brasserie & Bar398 Hayes StSan Francisco37.777083-122.422882[French, Californian, Mediterranean, Cafe, Ame...34.0{'tuesday': [['11:30', '23:59']], 'wednesday':...True...TrueTrueTrueFalseTrueFalseFalseFalse[{'review_website': 'TripAdvisor', 'review_url...Absinthe Brasserie & Bar 398 Hayes St
2Amber India Restaurant25 Yerba Buena LnSan Francisco37.785772-122.404401[Indian, Chinese, Vegetarian, Asian, Pakistani]24.5{'monday': [['11:30', '14:30'], ['17:00', '22:...True...TrueTrueTrueFalseTrueTrueTrueFalse[{'review_website': 'TripAdvisor', 'review_url...Amber India Restaurant 25 Yerba Buena Ln
3Americano8 Mission StSan Francisco37.793620-122.392915[Italian, American, Californian, Pub Food, Cafe]33.5{'monday': [['6:30', '10:30'], ['11:30', '14:3...True...TrueTrueTrueFalseTrueFalseFalseFalse[{'review_website': 'TripAdvisor', 'review_url...Americano 8 Mission St
4Anchor & Hope83 Minna StSan Francisco37.787848-122.398812[Seafood, American, Cafe, Chowder, Californian]34.0{'monday': [['11:30', '14:00'], ['17:30', '22:...True...TrueTrueTrueFalseTrueTrueTrueFalse[{'review_website': 'TripAdvisor', 'review_url...Anchor & Hope 83 Minna St
\n", + "

5 rows × 33 columns

\n", + "
" + ], + "text/plain": [ + " name address locality \\\n", + "0 21st Amendment Brewery & Restaurant 563 2nd St San Francisco \n", + "1 Absinthe Brasserie & Bar 398 Hayes St San Francisco \n", + "2 Amber India Restaurant 25 Yerba Buena Ln San Francisco \n", + "3 Americano 8 Mission St San Francisco \n", + "4 Anchor & Hope 83 Minna St San Francisco \n", + "\n", + " latitude longitude cuisine \\\n", + "0 37.782448 -122.392576 [Cafe, Pub Food, American, Burgers, Pizza] \n", + "1 37.777083 -122.422882 [French, Californian, Mediterranean, Cafe, Ame... \n", + "2 37.785772 -122.404401 [Indian, Chinese, Vegetarian, Asian, Pakistani] \n", + "3 37.793620 -122.392915 [Italian, American, Californian, Pub Food, Cafe] \n", + "4 37.787848 -122.398812 [Seafood, American, Cafe, Chowder, Californian] \n", + "\n", + " price rating hours parking \\\n", + "0 2 4.0 {'monday': [['11:30', '23:59']], 'tuesday': [[... True \n", + "1 3 4.0 {'tuesday': [['11:30', '23:59']], 'wednesday':... True \n", + "2 2 4.5 {'monday': [['11:30', '14:30'], ['17:00', '22:... True \n", + "3 3 3.5 {'monday': [['6:30', '10:30'], ['11:30', '14:3... True \n", + "4 3 4.0 {'monday': [['11:30', '14:00'], ['17:30', '22:... True \n", + "\n", + " ... meal_takeout meal_cater options_healthy options_organic \\\n", + "0 ... True False True False \n", + "1 ... True True True False \n", + "2 ... True True True False \n", + "3 ... True True True False \n", + "4 ... True True True False \n", + "\n", + " options_vegetarian options_vegan options_glutenfree options_lowfat \\\n", + "0 True False False False \n", + "1 True False False False \n", + "2 True True True False \n", + "3 True False False False \n", + "4 True True True False \n", + "\n", + " reviews \\\n", + "0 [{'review_website': 'TripAdvisor', 'review_url... \n", + "1 [{'review_website': 'TripAdvisor', 'review_url... \n", + "2 [{'review_website': 'TripAdvisor', 'review_url... \n", + "3 [{'review_website': 'TripAdvisor', 'review_url... \n", + "4 [{'review_website': 'TripAdvisor', 'review_url... 
\n", + "\n", + " unique_name \n", + "0 21st Amendment Brewery & Restaurant 563 2nd St \n", + "1 Absinthe Brasserie & Bar 398 Hayes St \n", + "2 Amber India Restaurant 25 Yerba Buena Ln \n", + "3 Americano 8 Mission St \n", + "4 Anchor & Hope 83 Minna St \n", + "\n", + "[5 rows x 33 columns]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.DataFrame(restaurant_data)\n", + "\n", + "df.fillna('', inplace=True)\n", + "\n", + "df.drop(columns=['region', 'country', 'tel','fax', 'email', 'website', 'address_extended', 'chain_name','trip_advisor_url'], inplace=True)\n", + "df['unique_name'] = df['name'] +' ' + df['address'] # some restaurants are chains or have more than one location\n", + "df.head()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you wanted to build a content filtering system now would be a good time to extract the text from the reviews, join them together and generate semantic embeddings from them like we did in our previous notebook.\n", + "\n", + "This would be a great approach, but to demonstrate the two tower architecture we won't use a pre-trained embedding model, and instead use the other columns as our raw features - but we will at least extract the numerical ratings from the reviews.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "df['min_rating'] = df['reviews'].apply(lambda x: np.min([r[\"review_rating\"] for r in x]))\n", + "df['max_rating'] = df['reviews'].apply(lambda x: np.max([r[\"review_rating\"] for r in x]))\n", + "df['avg_rating'] = df['reviews'].apply(lambda x: np.mean([r[\"review_rating\"] for r in x]))\n", + "df['stddev_rating'] = df['reviews'].apply(lambda x: np.std([r[\"review_rating\"] for r in x]))\n", + "df['price'] = df['price'].astype(int)\n", + "\n", + "# now take all the features we have and build a raw feature vector for each restaurant\n", + "numerical_cols = df.select_dtypes(include=['float64', 'int64']).columns\n", + "boolean_cols = df.select_dtypes(include=['bool']).columns\n", + "\n", + "# convert boolean columns to integers\n", + "df[boolean_cols] = df[boolean_cols].astype(int)\n", + "\n", + "# combine numerical and boolean columns into a single vector\n", + "df['feature_vector'] = df[numerical_cols.tolist() + boolean_cols.tolist()].values.tolist()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now have feature vectors with 30 features for each restaurant. The next step is to construct our raw feature vectors for our users.\n", + "\n", + "We don't have publicly available user data to correspond with this list of restaurants, so instead we'll generate some using the popular testing tool Faker." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "#NBEVAL_SKIP\n", + "!pip install Faker --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idnameusernameemailaddressphone_numberbirthdatelikesaccount_created_onprice_bracketnewsletternotificationsprofile_visibilitydata_sharing
08c49957f-be61-4083-94f5-de720d139124Kirsten Jensenashleewoodardblakejoshua@example.net865 Nancy Parkways\\nNew Christopherland, AS 31173(954)652-5823x047071934-06-21[breweries, cocktails, pizza, pasta, mexican, ...1973-07-05lowFalseFalsepublicFalse
1d98d91c3-7bcf-4d83-8a6b-cc847dd6214bShane Kerrjasoncantusmoore@example.org31892 Melissa Land Apt. 644\\nJulianhaven, VT 5...+1-768-369-9785x76621930-05-01[brunch, shakes, italian, bbq, ethiopian]1988-03-05middleTrueFalsepublicTrue
2f6344885-7388-4ee2-9686-009b14d8ee21Jon Powersvanessamurraykimberly64@example.net4592 Walton Prairie\\nPort Sharon, NC 56434251-346-1235x688381920-04-06[bbq, fast food, ethiopian, mexican, brunch, f...2009-03-20lowFalseTruefriends-onlyTrue
3156a5384-9285-4aa0-b361-b43855caf29dMelanie Riveraerica01xmorrison@example.org30240 Riley Glen\\nSouth Laura, NH 70744(767)958-22651916-03-12[fast food, fine dining, breweries, pasta, ita...1995-04-07highTrueFalsefriends-onlyFalse
4ae30ca78-f7b3-4637-992c-3e3a8894ebf9Thomas Hillrose77davistina@example.net707 Thompson Club Apt. 907\\nNorth Alexisview, ...+1-498-967-43181960-11-14[ethiopian, mexican]1993-05-26highTrueFalseprivateTrue
\n", + "
" + ], + "text/plain": [ + " user_id name username \\\n", + "0 8c49957f-be61-4083-94f5-de720d139124 Kirsten Jensen ashleewoodard \n", + "1 d98d91c3-7bcf-4d83-8a6b-cc847dd6214b Shane Kerr jasoncantu \n", + "2 f6344885-7388-4ee2-9686-009b14d8ee21 Jon Powers vanessamurray \n", + "3 156a5384-9285-4aa0-b361-b43855caf29d Melanie Rivera erica01 \n", + "4 ae30ca78-f7b3-4637-992c-3e3a8894ebf9 Thomas Hill rose77 \n", + "\n", + " email address \\\n", + "0 blakejoshua@example.net 865 Nancy Parkways\\nNew Christopherland, AS 31173 \n", + "1 smoore@example.org 31892 Melissa Land Apt. 644\\nJulianhaven, VT 5... \n", + "2 kimberly64@example.net 4592 Walton Prairie\\nPort Sharon, NC 56434 \n", + "3 xmorrison@example.org 30240 Riley Glen\\nSouth Laura, NH 70744 \n", + "4 davistina@example.net 707 Thompson Club Apt. 907\\nNorth Alexisview, ... \n", + "\n", + " phone_number birthdate \\\n", + "0 (954)652-5823x04707 1934-06-21 \n", + "1 +1-768-369-9785x7662 1930-05-01 \n", + "2 251-346-1235x68838 1920-04-06 \n", + "3 (767)958-2265 1916-03-12 \n", + "4 +1-498-967-4318 1960-11-14 \n", + "\n", + " likes account_created_on \\\n", + "0 [breweries, cocktails, pizza, pasta, mexican, ... 1973-07-05 \n", + "1 [brunch, shakes, italian, bbq, ethiopian] 1988-03-05 \n", + "2 [bbq, fast food, ethiopian, mexican, brunch, f... 2009-03-20 \n", + "3 [fast food, fine dining, breweries, pasta, ita... 1995-04-07 \n", + "4 [ethiopian, mexican] 1993-05-26 \n", + "\n", + " price_bracket newsletter notifications profile_visibility data_sharing \n", + "0 low False False public False \n", + "1 middle True False public True \n", + "2 low False True friends-only True \n", + "3 high True False friends-only False \n", + "4 high True False private True " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from faker import Faker\n", + "from uuid import uuid4\n", + "\n", + "fake = Faker()\n", + "\n", + "def generate_user():\n", + " return {\n", + " \"user_id\": str(uuid4()),\n", + " \"name\": fake.name(),\n", + " \"username\": fake.user_name(),\n", + " \"email\": fake.email(),\n", + " \"address\": fake.address(),\n", + " \"phone_number\": fake.phone_number(),\n", + " \"birthdate\": fake.date_of_birth().isoformat(),\n", + " \"likes\": fake.random_elements(elements=['burgers', 'shakes', 'pizza', 'italian', 'mexican', 'fine dining', 'bbq', 'cocktails', 'breweries', 'ethiopian', 'pasta', 'brunch','fast food'], unique=True),\n", + " \"account_created_on\": fake.date() ,\n", + " \"price_bracket\": fake.random_element(elements=(\"low\", \"middle\", \"high\")),\n", + " \"newsletter\": fake.boolean(),\n", + " \"notifications\": fake.boolean(),\n", + " \"profile_visibility\": fake.random_element(elements=(\"public\", \"private\", \"friends-only\")),\n", + " \"data_sharing\": fake.boolean()\n", + " }\n", + "\n", + "users = [generate_user() for _ in range(1000)]\n", + "\n", + "users_df = pd.DataFrame(users)\n", + "users_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_idnameusernameemailaddressphone_numberbirthdatelikesaccount_created_onnewsletter...pastapizzashakesprice_bracket_highprice_bracket_lowprice_bracket_middleprofile_visibility_friends-onlyprofile_visibility_privateprofile_visibility_publicfeature_vector
08c49957f-be61-4083-94f5-de720d139124Kirsten Jensenashleewoodardblakejoshua@example.net865 Nancy Parkways\\nNew Christopherland, AS 31173(954)652-5823x047071934-06-21[breweries, cocktails, pizza, pasta, mexican, ...1973-07-050...111010001[0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, ...
1d98d91c3-7bcf-4d83-8a6b-cc847dd6214bShane Kerrjasoncantusmoore@example.org31892 Melissa Land Apt. 644\\nJulianhaven, VT 5...+1-768-369-9785x76621930-05-01[brunch, shakes, italian, bbq, ethiopian]1988-03-051...001001001[1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, ...
2f6344885-7388-4ee2-9686-009b14d8ee21Jon Powersvanessamurraykimberly64@example.net4592 Walton Prairie\\nPort Sharon, NC 56434251-346-1235x688381920-04-06[bbq, fast food, ethiopian, mexican, brunch, f...2009-03-200...011010100[0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, ...
3156a5384-9285-4aa0-b361-b43855caf29dMelanie Riveraerica01xmorrison@example.org30240 Riley Glen\\nSouth Laura, NH 70744(767)958-22651916-03-12[fast food, fine dining, breweries, pasta, ita...1995-04-071...110100100[1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, ...
4ae30ca78-f7b3-4637-992c-3e3a8894ebf9Thomas Hillrose77davistina@example.net707 Thompson Club Apt. 907\\nNorth Alexisview, ...+1-498-967-43181960-11-14[ethiopian, mexican]1993-05-261...000100010[1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, ...
\n", + "

5 rows × 32 columns

\n", + "
" + ], + "text/plain": [ + " user_id name username \\\n", + "0 8c49957f-be61-4083-94f5-de720d139124 Kirsten Jensen ashleewoodard \n", + "1 d98d91c3-7bcf-4d83-8a6b-cc847dd6214b Shane Kerr jasoncantu \n", + "2 f6344885-7388-4ee2-9686-009b14d8ee21 Jon Powers vanessamurray \n", + "3 156a5384-9285-4aa0-b361-b43855caf29d Melanie Rivera erica01 \n", + "4 ae30ca78-f7b3-4637-992c-3e3a8894ebf9 Thomas Hill rose77 \n", + "\n", + " email address \\\n", + "0 blakejoshua@example.net 865 Nancy Parkways\\nNew Christopherland, AS 31173 \n", + "1 smoore@example.org 31892 Melissa Land Apt. 644\\nJulianhaven, VT 5... \n", + "2 kimberly64@example.net 4592 Walton Prairie\\nPort Sharon, NC 56434 \n", + "3 xmorrison@example.org 30240 Riley Glen\\nSouth Laura, NH 70744 \n", + "4 davistina@example.net 707 Thompson Club Apt. 907\\nNorth Alexisview, ... \n", + "\n", + " phone_number birthdate \\\n", + "0 (954)652-5823x04707 1934-06-21 \n", + "1 +1-768-369-9785x7662 1930-05-01 \n", + "2 251-346-1235x68838 1920-04-06 \n", + "3 (767)958-2265 1916-03-12 \n", + "4 +1-498-967-4318 1960-11-14 \n", + "\n", + " likes account_created_on \\\n", + "0 [breweries, cocktails, pizza, pasta, mexican, ... 1973-07-05 \n", + "1 [brunch, shakes, italian, bbq, ethiopian] 1988-03-05 \n", + "2 [bbq, fast food, ethiopian, mexican, brunch, f... 2009-03-20 \n", + "3 [fast food, fine dining, breweries, pasta, ita... 1995-04-07 \n", + "4 [ethiopian, mexican] 1993-05-26 \n", + "\n", + " newsletter ... pasta pizza shakes price_bracket_high \\\n", + "0 0 ... 1 1 1 0 \n", + "1 1 ... 0 0 1 0 \n", + "2 0 ... 0 1 1 0 \n", + "3 1 ... 1 1 0 1 \n", + "4 1 ... 0 0 0 1 \n", + "\n", + " price_bracket_low price_bracket_middle profile_visibility_friends-only \\\n", + "0 1 0 0 \n", + "1 0 1 0 \n", + "2 1 0 1 \n", + "3 0 0 1 \n", + "4 0 0 0 \n", + "\n", + " profile_visibility_private profile_visibility_public \\\n", + "0 0 1 \n", + "1 0 1 \n", + "2 0 0 \n", + "3 0 0 \n", + "4 1 0 \n", + "\n", + " feature_vector \n", + "0 [0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, ... \n", + "1 [1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, ... \n", + "2 [0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, ... \n", + "3 [1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, ... \n", + "4 [1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, ... 
\n", + "\n", + "[5 rows x 32 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "from sklearn.preprocessing import MultiLabelBinarizer\n", + "\n", + "# use a MultiLabelBinarizer to one-hot encode our user's 'likes' column, which has a list of users' food preferences\n", + "mlb = MultiLabelBinarizer()\n", + "\n", + "likes_encoded = mlb.fit_transform(users_df['likes'])\n", + "likes_df = pd.DataFrame(likes_encoded, columns=mlb.classes_)\n", + "\n", + "# concatenate the original users_df with the new one-hot encoded likes_df\n", + "users_df = pd.concat([users_df, likes_df], axis=1)\n", + "\n", + "# one-hot encode categorical columns\n", + "categorical_cols = ['price_bracket', 'profile_visibility']\n", + "users_df = pd.get_dummies(users_df, columns=categorical_cols)\n", + "\n", + "# convert boolean columns to integers\n", + "boolean_cols = users_df.select_dtypes(include=['boolean']).columns\n", + "users_df[boolean_cols] = users_df[boolean_cols].astype(int)\n", + "\n", + "# combine all numerical columns into a single feature vector\n", + "numerical_cols = users_df.select_dtypes(include=['int64', 'uint8']).columns\n", + "users_df['feature_vector'] = users_df[numerical_cols].values.tolist()\n", + "users_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because two tower models are also trained on interaction data like our SVD collaborative filtering model we need to generate some purchases.\n", + "\n", + "This will be a 1 or -1 to indicate if a user has eaten at this restaurant before.\n", + "\n", + "Once again we're generating random labels for this example to go along with our random users." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "user_ids = users_df['user_id'].tolist()\n", + "restaurant_names = df[\"unique_name\"].tolist()\n", + "\n", + "# generate purchases by randomly selecting users and businesses\n", + "purchases = [\n", + " (user_ids[random.randrange(0, len(user_ids))],\n", + " restaurant_names[random.randrange(0, len(restaurant_names))]\n", + " )\n", + " for _ in range(200)\n", + "]\n", + "\n", + "positive_labels = []\n", + "for i in range(len(purchases)):\n", + " user_index = users_df[users_df['user_id'] == purchases[i][0]].index.item()\n", + " restaurant_index = df[df['unique_name'] == purchases[i][1]].index.item()\n", + " positive_labels.append((user_index, restaurant_index, 1.))\n", + "\n", + "# generate an equal number of negative examples\n", + "negative_labels = []\n", + "for i in range(len(purchases)):\n", + " user_index = random.randint(0, len(user_ids)-1)\n", + " restaurant_index = random.randint(0, len(restaurant_names)-1)\n", + " negative_labels.append((user_index, restaurant_index, -1.))\n", + "\n", + "labels = positive_labels + negative_labels" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now we have all of our data. The next steps are to define a the model and train it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "\n", + "# define a custom dataset\n", + "class PurchaseDataset(Dataset):\n", + " def __init__(self, user_features, restaurant_features, labels):\n", + " self.user_features = user_features\n", + " self.restaurant_features = restaurant_features\n", + " self.labels = labels\n", + "\n", + " def __len__(self):\n", + " return len(self.labels)\n", + "\n", + " def __getitem__(self, idx):\n", + " user_index, restaurant_index, label = self.labels[idx]\n", + " return self.user_features[user_index], self.restaurant_features[restaurant_index], torch.tensor(label, dtype=torch.float32)\n", + "\n", + "# define the two tower model\n", + "class TwoTowerModel(nn.Module):\n", + " def __init__(self, user_input_dim, restaurant_input_dim, hidden_dim):\n", + " super(TwoTowerModel, self).__init__()\n", + " self.user_tower = nn.Sequential(\n", + " nn.Linear(user_input_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Dropout(p=0.5),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Dropout(p=0.5),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " )\n", + " self.restaurant_tower = nn.Sequential(\n", + " nn.Linear(restaurant_input_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Dropout(p=0.5),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " nn.ReLU(),\n", + " nn.Dropout(p=0.5),\n", + " nn.Linear(hidden_dim, hidden_dim),\n", + " )\n", + "\n", + " def get_user_embeddings(self, user_features):\n", + " return nn.functional.normalize(self.user_tower(user_features), dim=1)\n", + "\n", + " def get_restaurant_embeddings(self, restaurant_features):\n", + " return nn.functional.normalize(self.restaurant_tower(restaurant_features), dim=1)\n", + "\n", + " def forward(self, user_features, restaurant_features):\n", + " user_embedding = self.get_user_embeddings(user_features)\n", + " restaurant_embedding = self.get_restaurant_embeddings(restaurant_features)\n", + " return user_embedding, restaurant_embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare the data and data loader\n", + "user_features = torch.tensor(users_df['feature_vector'].tolist(), dtype=torch.float32)\n", + "restaurant_features = torch.tensor(df['feature_vector'].tolist(), dtype=torch.float32)\n", + "\n", + "dataset = PurchaseDataset(user_features, restaurant_features, labels)\n", + "dataloader = DataLoader(dataset, batch_size=64, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch [1/200], loss: 0.5038220286369324\n", + "epoch [11/200], loss: 0.5040968060493469\n", + "epoch [21/200], loss: 0.4700140357017517\n", + "epoch [31/200], loss: 0.40341031551361084\n", + "epoch [41/200], loss: 0.30302390456199646\n", + "epoch [51/200], loss: 0.511837363243103\n", + "epoch [61/200], loss: 0.37064486742019653\n", + "epoch [71/200], loss: 0.32029417157173157\n", + "epoch [81/200], loss: 0.41176533699035645\n", + "epoch [91/200], loss: 0.2636029124259949\n", + "epoch [101/200], loss: 0.5995010137557983\n", + "epoch [111/200], loss: 0.26358169317245483\n", + "epoch [121/200], loss: 0.15962541103363037\n", + "epoch [131/200], loss: 0.04395533353090286\n", + "epoch [141/200], loss: 
0.2500947117805481\n", + "epoch [151/200], loss: 0.40849578380584717\n", + "epoch [161/200], loss: 0.14111760258674622\n", + "epoch [171/200], loss: 0.29329919815063477\n", + "epoch [181/200], loss: 0.13339880108833313\n", + "epoch [191/200], loss: 0.33117613196372986\n" + ] + } + ], + "source": [ + "# initialize the model, loss function and optimizer\n", + "model = TwoTowerModel(user_input_dim=user_features.shape[1], restaurant_input_dim=restaurant_features.shape[1], hidden_dim=128)\n", + "cosine_criterion = nn.CosineEmbeddingLoss()\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", + "\n", + "# train model\n", + "num_epochs = 200\n", + "losses = []\n", + "for epoch in range(num_epochs):\n", + "    for user_batch, restaurant_batch, label_batch in dataloader:\n", + "        optimizer.zero_grad()\n", + "        user_embeddings, restaurant_embeddings = model(user_batch, restaurant_batch)\n", + "        loss = cosine_criterion(user_embeddings, restaurant_embeddings, label_batch)\n", + "        loss.backward()\n", + "        optimizer.step()\n", + "    if epoch % 10 == 0:\n", + "        print(f'epoch [{epoch+1}/{num_epochs}], loss: {loss.item()}')\n", + "    losses.append(loss.item())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Why use two towers instead of content or collaborative filtering?\n", + "This seems rather complicated compared to other recommender system architectures, so why go to all of this effort? The best way to answer this is to compare it with other recommendation system approaches.\n", + "\n", + "### Shortcomings of content filtering\n", + "The simplest machine learning approach to recommendations is content filtering. It's also an approach that doesn't take user behavior into account beyond finding similar content. This may not sound too bad, but it can quickly lead to users getting trapped in content bubbles, where once they interact with a certain item - even if it was just randomly - they only see similar items.\n", + "\n", + "### Shortcomings of collaborative filtering\n", + "Collaborative filtering approaches like Singular Value Decomposition (SVD) take the opposite approach and _only_ consider user behaviors to make recommendations. This has clear advantages, but one major drawback: SVD can't handle brand new users or brand new content. Each time a new user joins or new content is added to your library, there won't be an associated vector for it. There also won't be meaningful new interaction data to re-train a model and generate vectors. It can be bad enough that a model needs frequent re-training; it can be an even bigger issue if you can't make recommendations for new users and content.\n", + "\n", + "Two tower models overcome these obstacles. To better understand how, let's dive into what this type of architecture is really doing." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Two Towers Separate Embedding Vector Creation from Model Training\n", + "Now that we have a trained model, we can use each tower to generate embeddings for our users and items.\n", + "Unlike SVD, we don't have to retrain our model to get these vectors. We also don't need new interaction data for our users or content."
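To make the "no retraining" point concrete, here is a small sketch of how a brand new user with no purchase history at all would get an embedding: all that's needed is a raw feature vector of the size the user tower was trained on. The random vector below is a stand-in for a real user's features, and calling `model.eval()` disables the dropout layers, which is what you want when generating embeddings for serving.

```python
import torch

# Stand-in feature vector for a brand new user (same dimensionality the user tower expects).
new_user_features = torch.rand(1, user_features.shape[1])

model.eval()  # turn off dropout for inference
with torch.no_grad():
    new_user_embedding = model.get_user_embeddings(new_user_features)

print(new_user_embedding.shape)  # torch.Size([1, 128])
```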
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "user_embeddings = model.get_user_embeddings(user_features=torch.tensor(users_df['feature_vector'].tolist(), dtype=torch.float32))\n", + "restaurant_embeddings = model.get_restaurant_embeddings(restaurant_features=torch.tensor(df['feature_vector'].tolist(), dtype=torch.float32))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The Best of Both Worlds\n", + "Two tower models are a triple whammy when it comes to solving the above problems:\n", + "- They are trained on interaction data, aka our labels, so learn not to fall into content bubbles\n", + "- They directly consider the user features _and_ content features\n", + "- they can handle brand new users and content that don't yet have interaction data. No retraining necessary\n", + "\n", + "While we need some interaction data to train our model initially, it's totally fine if not all users or restaurants are included in our labelled data. Only a sample is needed.\n", + "This is why we can handle new users and content without retraining. Only their raw features are needed to generate embeddings\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading into Redis\n", + "With two sets of vectors we'll load the restaurant data into a Redis vector store to search over, and the user vectors into a regular key look up for quick access.\n", + "We'll handle our restaurants opening and closing hours, as well as their location in longitude and latitude. We'll want these for later." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'name': '21st Amendment Brewery & Restaurant', 'address': '563 2nd St', 'locality': 'San Francisco', 'location': '-122.392576,37.782448', 'cuisine': ['Cafe', 'Pub Food', 'American', 'Burgers', 'Pizza'], 'price': 2, 'rating': 4.0, 'sunday_open': 1000, 'sunday_close': 2359, 'monday_open': 1130, 'monday_close': 2359, 'tuesday_open': 1130, 'tuesday_close': 2359, 'wednesday_open': 1130, 'wednesday_close': 2359, 'thursday_open': 1130, 'thursday_close': 2359, 'friday_open': 1130, 'friday_close': 2359, 'saturday_open': 1130, 'saturday_close': 2359, 'embedding': [0.028717460110783577, -0.010432110168039799, -0.15579310059547424, -0.09007972478866577, -0.08904828876256943, -0.06460892409086227, 0.06160367652773857, 0.08463871479034424, -0.11541824042797089, -0.15989527106285095, 0.15000291168689728, -0.0801846981048584, -0.06959360092878342, -0.06584298610687256, 0.08495321869850159, -0.05949929729104042, -0.01901606284081936, 0.044410597532987595, -0.06874579191207886, -0.06110486015677452, 0.08664961159229279, 0.10069684684276581, 0.00703305983915925, -0.1213110089302063, 0.07221467792987823, 0.08289125561714172, -0.020599940791726112, 0.08590658009052277, -0.05167737603187561, -0.034252531826496124, -0.032192397862672806, -0.013088015839457512, 0.051425255835056305, 0.10542334616184235, 0.11928749829530716, -0.043923888355493546, -0.03416838124394417, -0.09220845252275467, 0.008960519917309284, -0.03631928935647011, -0.009584952145814896, 0.02850543148815632, 0.041595641523599625, 0.008507505059242249, -0.023945361375808716, -0.029285553842782974, 0.1319705694913864, -0.000728881626855582, 0.17677032947540283, -0.11249547451734543, 0.08006928116083145, -0.02895255759358406, 0.08162163943052292, -0.001116430852562189, 0.12038654834032059, 
0.08053001761436462, 0.05644155293703079, -0.13006895780563354, -0.09181211143732071, 0.047289978712797165, 0.15262317657470703, -0.15985533595085144, 0.15764641761779785, -0.06318134069442749, -0.0019673688802868128, -0.1305117905139923, -0.036956388503313065, 0.07861090451478958, -0.07377144694328308, -0.029948312789201736, -0.0180054884403944, 0.16489765048027039, -0.09569795429706573, 0.11060012876987457, -0.06405990570783615, 0.1352132111787796, -0.13919328153133392, -0.04358096793293953, 0.034535180777311325, -0.01949075236916542, 0.0075964899733662605, -0.014157578349113464, 0.008092504926025867, 0.031047292053699493, 0.0617695152759552, 0.014792567119002342, 0.181081160902977, 0.09052076935768127, 0.014595440588891506, -0.1563986986875534, -0.027365028858184814, -0.03568330034613609, -0.07662227004766464, -0.04418030381202698, -0.004757001996040344, -0.036813508719205856, 0.12870939075946808, -0.07621592283248901, -0.12695179879665375, 0.11817491799592972, 0.037855129688978195, -0.06083338335156441, 0.14988923072814941, 0.09214119613170624, -0.008170578628778458, -0.09872289001941681, 0.07875774055719376, -0.10017523169517517, 0.00689289253205061, 0.12864772975444794, -0.19386278092861176, 0.12486553192138672, -0.05002424493432045, 0.10140490531921387, -0.025215785950422287, 0.08094171434640884, 0.11039657890796661, -0.004418433643877506, 0.09027619659900665, -0.052607160061597824, 0.1531265676021576, 0.07230117917060852, -0.06995918601751328, 0.1714017540216446, 0.09504596143960953, -0.1021222248673439, 0.005956844426691532, 0.14401885867118835]}\n" + ] + } + ], + "source": [ + "# extract opening and closing times from the 'hours' column\n", + "def extract_opening_closing_times(hours, day):\n", + " # convert to a simple numeric representation of times\n", + " if day in hours:\n", + " return int(hours[day][0][0].replace(':','')), int(hours[day][0][1].replace(':',''))\n", + " else:\n", + " # we don't know their hours, assume a reasonable default of 9:00am to 8:00pm\n", + " return 900, 2000\n", + "\n", + "# create new columns for opening and closing times for each day of the week\n", + "for day in ['sunday', 'monday', 'tuesday', 'wednesday', 'thursday', 'friday', 'saturday']:\n", + " df[f'{day}_open'], df[f'{day}_close'] = zip(*df['hours'].apply(lambda x: extract_opening_closing_times(x, day)))\n", + "\n", + "# combine 'longitude' and 'latitude' into a single 'location' column\n", + "df['location'] = df.apply(lambda row: f\"{row['longitude']},{row['latitude']}\", axis=1)\n", + "\n", + "# drop the original 'hours' separate 'latitude' and 'longitude' columns as we don't need them anymore\n", + "df.drop(columns=['hours', 'latitude', 'longitude'], inplace=True)\n", + "\n", + "# ensure the 'embedding' column is in the correct format (list of floats)\n", + "df['embedding'] = restaurant_embeddings.detach().numpy().tolist()\n", + "\n", + "# ensure all columns are in the correct order as defined in the schema\n", + "df = df[['name', 'address', 'locality', 'location', 'cuisine', 'price', 'rating', 'sunday_open', 'sunday_close', 'monday_open', 'monday_close', 'tuesday_open', 'tuesday_close', 'wednesday_open', 'wednesday_close', 'thursday_open', 'thursday_close', 'friday_open', 'friday_close', 'saturday_open', 'saturday_close', 'embedding']]\n", + "\n", + "# print the first record to verify the format\n", + "print(df.to_dict(orient='records')[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "11:30:23 redisvl.index.index INFO Index already exists, overwriting.\n" + ] + } + ], + "source": [ + "from redis import Redis\n", + "from redisvl.schema import IndexSchema\n", + "from redisvl.index import SearchIndex\n", + "\n", + "client = Redis.from_url(REDIS_URL)\n", + "\n", + "restaurant_schema = IndexSchema.from_dict({\n", + " 'index': {\n", + " 'name': 'restaurants',\n", + " 'prefix': 'restaurant',\n", + " 'storage_type': 'json'\n", + " },\n", + " 'fields': [\n", + " {'name': 'name', 'type': 'text'},\n", + " {'name': 'address', 'type': 'text'},\n", + " {'name': 'locality', 'type': 'tag'},\n", + " {'name': 'location', 'type': 'geo'},\n", + " {'name': 'cuisine', 'type': 'tag'},\n", + " {'name': 'price', 'type': 'numeric'},\n", + " {'name': 'rating', 'type': 'numeric'},\n", + " {'name': 'sunday_open', 'type': 'numeric'},\n", + " {'name': 'sunday_close', 'type': 'numeric'},\n", + " {'name': 'monday_open', 'type': 'numeric'},\n", + " {'name': 'monday_close', 'type': 'numeric'},\n", + " {'name': 'tuesday_open', 'type': 'numeric'},\n", + " {'name': 'tuesday_close', 'type': 'numeric'},\n", + " {'name': 'wednesday_open', 'type': 'numeric'},\n", + " {'name': 'wednesday_close', 'type': 'numeric'},\n", + " {'name': 'thursday_open', 'type': 'numeric'},\n", + " {'name': 'thursday_close', 'type': 'numeric'},\n", + " {'name': 'friday_open', 'type': 'numeric'},\n", + " {'name': 'friday_close', 'type': 'numeric'},\n", + " {'name': 'saturday_open', 'type': 'numeric'},\n", + " {'name': 'saturday_close', 'type': 'numeric'},\n", + " {\n", + " 'name': 'embedding',\n", + " 'type': 'vector',\n", + " 'attrs': {\n", + " 'dims': 128,\n", + " 'algorithm': 'flat',\n", + " 'datatype': 'float32',\n", + " 'distance_metric': 'cosine'\n", + " }\n", + " }\n", + " ]\n", + "})\n", + "\n", + "restaurant_index = SearchIndex(restaurant_schema, redis_client=client)\n", + "restaurant_index.create(overwrite=True, drop=True)\n", + "\n", + "restaurant_keys = restaurant_index.load(df.to_dict(orient='records'))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# load the user vectors into a regular redis space\n", + "from redis.commands.json.path import Path\n", + "\n", + "with client.pipeline() as pipe:\n", + " for user_id, embedding in zip(users_df['user_id'], user_embeddings):\n", + " user_key = f\"user:{user_id}\"\n", + "\n", + " user_data = {\n", + " \"user_embedding\": embedding.tolist(),\n", + " }\n", + " pipe.json().set(user_key, Path.root_path(), user_data)\n", + " pipe.execute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The power of deep learning with the speed of Redis\n", + "\n", + "I can hear you say it, \"deep learning is cool and all, but I need my system to be fast. I don't want to call a deep neural network to get recommendations.\"\n", + "\n", + "Well not to fear my friend, you won't have to! While training our model may take a while, you won't need to do this often.\n", + "And if you look closely you'll see that both the user and content embedding vectors can be generated once and reused again and again.\n", + "Only the vector search is happening when generating recommendations.\n", + "These embeddings will only change if your user or content features change and if you select your features wisely this won't be often." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Location Aware Recommendations\n", + "\n", + "We've shown how Redis can apply filters on top of vector similarity search to further refine results, but did you know it can also refine search results by location?\n", + "Using the `Geo` field type on our index definition we can apply a `GeoRadius` filter to find only places nearby, which seems mighty useful for a restaurant recommendation system.\n", + "\n", + "Combining `GeoRadius` with `Num` tags we can find places that are personally relevant to us, nearby _and_ open for business right now.\n", + "\n", + "We have all our data and vectors ready to go. Now let's put it all together with query logic." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "from redisvl.query.filter import Tag, Num, Geo, GeoRadius\n", + "import datetime\n", + "\n", + "def get_filter(user_long,\n", + " user_lat,\n", + " current_date_time,\n", + " radius=1000,\n", + " low_price=0.0,\n", + " high_price=5.0,\n", + " rating=0.0,\n", + " cuisines=[]):\n", + "\n", + " geo_filter = Geo(\"location\") == GeoRadius(user_long, user_lat, radius, unit=\"m\") # use a distance unit of meters\n", + "\n", + " open_filter = Num(f\"{current_date_time.strftime('%A').lower()}_open\") < current_date_time.hour*100 + current_date_time.minute\n", + " close_filter = Num(f\"{current_date_time.strftime('%A').lower()}_close\") > current_date_time.hour*100 + current_date_time.minute\n", + " time_filter = open_filter & close_filter\n", + "\n", + " price_filter = (Num('price') >= low_price) & (Num('price') <= high_price)\n", + "\n", + " rating_filter = Num('rating') >= rating\n", + "\n", + " cuisine_filter = Tag('cuisine') == cuisines\n", + "\n", + " return geo_filter & time_filter & price_filter & rating_filter & cuisine_filter\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "found 8 results from our query\n", + "{'id': 'restaurant:7e440f66c21d42f4974726aa7f7f5947', 'name': 'Nanis Coffee', 'address': '2739 Geary Blvd', 'location': '-122.448613,37.782187'}\n", + "{'id': 'restaurant:feca10fd72264cca9622f8f1acec41aa', 'name': 'Double Decker', 'address': '465 Grove St', 'location': '-122.424033,37.777531'}\n", + "{'id': 'restaurant:eb4d101c6ffa4ed2954a03c94e303568', 'name': 'Cafe Du Soleil', 'address': '200 Fillmore St', 'location': '-122.430158,37.771303'}\n", + "{'id': 'restaurant:ae692ddd39184a129aa04d9072022560', 'name': 'La Boulange', 'address': '2043 Fillmore St', 'location': '-122.43386,37.788408'}\n", + "{'id': 'restaurant:50f751f9d5ee45c2b017e2d77b61e124', 'name': 'Burgermeister', 'address': '138 Church St', 'location': '-122.42914,37.768755'}\n", + "{'id': 'restaurant:d74be4aa7ad04890884d01595502a100', 'name': 'Magnolia Pub and Brewery', 'address': '1398 Haight St', 'location': '-122.445238,37.770276'}\n", + "{'id': 'restaurant:6fdd9b24129f447383c8c7edb28308c7', 'name': 'Memphis Minnies BBQ Joint', 'address': '576 Haight St', 'location': '-122.431702,37.772058'}\n", + "{'id': 'restaurant:f949c6a0282e44eb8244ba3046cff7f7', 'name': 'Panini', 'address': '1457 Haight St', 'location': '-122.44629,37.770036'}\n" + ] + } + ], + "source": [ + "from redisvl.query import VectorQuery\n", + "\n", + "random_user = random.choice(users_df['user_id'].tolist())\n", + "user_vector = client.json().get(f\"user:{random_user}\")[\"user_embedding\"]\n", + "\n", + 
"# get a location for this user. Your app may call an API, here we'll set one randomly to within San Francisco\n", + "# San Francisco is within the longitude and latitude bounding box of:\n", + "# Lower corner: (-122.5137, 37.7099) in (longitude, latitude) format\n", + "# Upper corner: (-122.3785, 37.8101)\n", + "\n", + "longitude = random.uniform(-122.5137, -122.3785)\n", + "latitude = random.uniform(37.7099, 37.8101)\n", + "longitude, latitude = -122.439, 37.779\n", + "radius = 1500\n", + "\n", + "full_filter = get_filter(user_long=longitude,\n", + " user_lat=latitude,\n", + " radius=radius,\n", + " current_date_time=datetime.datetime.today())\n", + "\n", + "query = VectorQuery(vector=user_vector,\n", + " vector_field_name='embedding',\n", + " num_results=10,\n", + " return_score=False,\n", + " return_fields=['name', 'address', 'location', 'distance'],\n", + " filter_expression=full_filter,\n", + " )\n", + "\n", + "results = restaurant_index.query(query)\n", + "print(f\"found {len(results)} results from our query\")\n", + "for r in results:\n", + " print(r)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Seeing Is Believing\n", + "\n", + "With our vectors loaded and helper functions defined we can get some nearby recommendations. That's all well and good, but don't you wish you could see these recommendations? I sure do. So let's visualize them on an interactive map." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.3.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "#NBEVAL_SKIP\n", + "!pip install folium clipboard --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import folium\n", + "import clipboard\n", + "from IPython.display import display\n", + "\n", + "# create a map centered around San Francisco\n", + "figure = folium.Figure(width=700, height=600)\n", + "sf_map = folium.Map(location=[37.7749, -122.4194],\n", + " zoom_start=13,\n", + " max_bounds=True,\n", + " min_lat= 37.709 - 0.1,\n", + " max_lat= 37.8101 + 0.1,\n", + " min_lon= -122.3785 - 0.3,\n", + " max_lon= -122.5137 + 0.3,\n", + " )\n", + "\n", + "sf_map.add_to(figure)\n", + "\n", + "# add markers for each restaurant in blue\n", + "for idx, row in df.iterrows():\n", + " lat, lon = map(float, row['location'].split(','))\n", + " folium.Marker([lon, lat], popup=row['name']).add_to(sf_map)\n", + "\n", + "\n", + "# get personalized recommendations\n", + "user = users_df['user_id'].tolist()[42]\n", + "user_vector = client.json().get(f\"user:{user}\")[\"user_embedding\"]\n", + "\n", + "# get a location for this user. 
Your app may call an API, here we'll set one randomly to within San Francisco\n", + "# lower corner: (-122.5137, 37.7099) in (longitude, latitude) format\n", + "# upper corner: (-122.3785, 37.8101)\n", + "\n", + "longitude, latitude = -122.439, 37.779\n", + "num_results = 25\n", + "radius = 2000\n", + "\n", + "# draw a circle centered on our user\n", + "folium.Circle(\n", + " location=[latitude, longitude],\n", + " radius=radius,\n", + " color=\"green\",\n", + " weight=3,\n", + " fill=True,\n", + " fill_opacity=0.3,\n", + " opacity=1,\n", + ").add_to(sf_map)\n", + "\n", + "\n", + "full_filter = get_filter(user_long=longitude,\n", + " user_lat=latitude,\n", + " radius=radius,\n", + " current_date_time=datetime.datetime.today()\n", + " )\n", + "\n", + "query = VectorQuery(vector=user_vector,\n", + " vector_field_name='embedding',\n", + " num_results=num_results,\n", + " return_score=False,\n", + " return_fields=['name', 'address', 'location', 'rating'],\n", + " filter_expression=full_filter,\n", + " )\n", + "\n", + "results = restaurant_index.query(query)\n", + "\n", + "# now show our recommended places in red\n", + "for restaurant in results:\n", + " lat, lon = map(float, restaurant['location'].split(','))\n", + " folium.Marker([lon, lat], popup=restaurant['name'] + ' ' + restaurant['rating'] + ' stars', icon=folium.Icon(color='red')).add_to(sf_map)\n", + "\n", + "display(sf_map)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion\n", + "That's it! You've built a deep learning restaurant recommendation system with Redis. It's personalized, location aware, adaptable, and fast." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deleted 147 keys\n" + ] + }, + { + "data": { + "text/plain": [ + "1000" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# clean up your index\n", + "while remaining := restaurant_index.clear():\n", + " print(f\"Deleted {remaining} keys\")\n", + "\n", + "client.delete(*[f\"user:{user_id}\" for user_id in users_df['user_id'].tolist()])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "redis-ai-res", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}