diff --git a/rdmo_chatbot/chatbot/stores/mysql.py b/rdmo_chatbot/chatbot/stores/mysql.py index 28ae075..994cf5a 100644 --- a/rdmo_chatbot/chatbot/stores/mysql.py +++ b/rdmo_chatbot/chatbot/stores/mysql.py @@ -10,52 +10,61 @@ class MysqlStore(BaseStore): def __init__(self): - self.connection = MySQLdb.connect(**config.STORE_CONNECTION) - self.cursor = self.connection.cursor() self.create_table() + def connect(self): + return MySQLdb.connect(**config.STORE_CONNECTION) + def create_table(self): - self.cursor.execute(""" - CREATE TABLE IF NOT EXISTS history ( - id INT AUTO_INCREMENT PRIMARY KEY, - user_identifier VARCHAR(150), - project_id INT, - messages JSON, - created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, - UNIQUE KEY unique_user_project (user_identifier, project_id) - ); - """) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS history ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_identifier VARCHAR(150), + project_id INT, + messages JSON, + created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + UNIQUE KEY unique_user_project (user_identifier, project_id) + ); + """) def has_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s; - """, (user_identifier, project_id) - ) - result = self.cursor.fetchone() - return result[0] > 0 if result else False + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s; + """, (user_identifier, project_id) + ) + result = cursor.fetchone() + return result[0] > 0 if result else False def get_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s; - """, (user_identifier, project_id) - ) - result = self.cursor.fetchone() - return dicts_to_messages(json.loads(result[0])) if result else [] + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s; + """, (user_identifier, project_id) + ) + result = cursor.fetchone() + return dicts_to_messages(json.loads(result[0])) if result else [] def set_history(self, user_identifier, project_id, messages): - self.cursor.execute(""" - INSERT INTO history (user_identifier, project_id, messages, created) VALUES (%s, %s, %s, CURRENT_TIMESTAMP) - ON DUPLICATE KEY UPDATE - messages = VALUES(messages), - updated = CURRENT_TIMESTAMP; - """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + INSERT INTO history (user_identifier, project_id, messages, created) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON DUPLICATE KEY UPDATE + messages = VALUES(messages), + updated = CURRENT_TIMESTAMP; + """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) def reset_history(self, user_identifier, project_id): - self.cursor.execute(""" - DELETE FROM history WHERE user_identifier = %s AND project_id = %s; - """, [user_identifier, project_id] - ) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + DELETE FROM history WHERE user_identifier = %s AND project_id = %s; + """, [user_identifier, project_id] + ) diff --git a/rdmo_chatbot/chatbot/stores/postgres.py b/rdmo_chatbot/chatbot/stores/postgres.py index 933b9dd..de4deec 100644 --- a/rdmo_chatbot/chatbot/stores/postgres.py +++ b/rdmo_chatbot/chatbot/stores/postgres.py @@ -10,50 +10,61 @@ class PostgresStore(BaseStore): def __init__(self): - - self.connection = psycopg.connect(**config.STORE_CONNECTION) - self.cursor = self.connection.cursor() self.create_table() + def connect(self): + return psycopg.connect(**config.STORE_CONNECTION) + def create_table(self): - self.cursor.execute(""" - CREATE TABLE IF NOT EXISTS history ( - id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, - user_identifier VARCHAR(150), - project_id INT, - messages JSONB, - created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE (user_identifier, project_id) - ); - """) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS history ( + id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY, + user_identifier VARCHAR(150), + project_id INT, + messages JSONB, + created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (user_identifier, project_id) + ); + """) + connection.commit() def has_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s; - """, (user_identifier, project_id)) - result = self.cursor.fetchone() - return result[0] > 0 if result else False + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s; + """, (user_identifier, project_id)) + result = cursor.fetchone() + return result[0] > 0 if result else False def get_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s; - """, (user_identifier, project_id)) - result = self.cursor.fetchone() - return dicts_to_messages(result[0]) if result else [] + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s; + """, (user_identifier, project_id)) + result = cursor.fetchone() + return dicts_to_messages(result[0]) if result else [] def set_history(self, user_identifier, project_id, messages): - self.cursor.execute(""" - INSERT INTO history (user_identifier, project_id, messages, created) VALUES (%s, %s, %s, CURRENT_TIMESTAMP) - ON CONFLICT (user_identifier, project_id) DO UPDATE SET - messages = EXCLUDED.messages, - updated = CURRENT_TIMESTAMP; - """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + INSERT INTO history (user_identifier, project_id, messages, created) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (user_identifier, project_id) DO UPDATE SET + messages = EXCLUDED.messages, + updated = CURRENT_TIMESTAMP; + """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) + connection.commit() def reset_history(self, user_identifier, project_id): - self.cursor.execute(""" - DELETE FROM history WHERE user_identifier = %s AND project_id = %s; - """, (user_identifier, project_id)) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + DELETE FROM history WHERE user_identifier = %s AND project_id = %s; + """, (user_identifier, project_id)) + connection.commit() diff --git a/rdmo_chatbot/chatbot/stores/sqlite3.py b/rdmo_chatbot/chatbot/stores/sqlite3.py index a38f70c..ccf8c3a 100644 --- a/rdmo_chatbot/chatbot/stores/sqlite3.py +++ b/rdmo_chatbot/chatbot/stores/sqlite3.py @@ -1,5 +1,6 @@ import json import sqlite3 +import sys from ..utils import dicts_to_messages, get_config, messages_to_dicts from . import BaseStore @@ -10,51 +11,65 @@ class Sqlite3Store(BaseStore): def __init__(self): - self.connection = sqlite3.connect(config.STORE_CONNECTION) - self.cursor = self.connection.cursor() + if sys.version_info < (3, 11): + raise RuntimeError("Sqlite3Store requires Python 3.11 or higher.") + self.create_table() + def connect(self): + return sqlite3.connect(config.STORE_CONNECTION) + def create_table(self): - self.cursor.execute(""" - CREATE TABLE IF NOT EXISTS history ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_identifier TEXT, - project_id INTEGER, - messages JSON, - created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - UNIQUE(user_identifier, project_id) - ); - """) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS history ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_identifier TEXT, + project_id INTEGER, + messages JSON, + created TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE(user_identifier, project_id) + ); + """) + connection.commit() def has_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT count(*) FROM history WHERE user_identifier = ? AND project_id = ?; - """, (user_identifier, project_id)) - result = self.cursor.fetchone() - return result[0] > 0 if result else False + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT count(*) FROM history WHERE user_identifier = ? AND project_id = ?; + """, (user_identifier, project_id)) + result = cursor.fetchone() + return result[0] > 0 if result else False def get_history(self, user_identifier, project_id): - self.cursor.execute(""" - SELECT messages FROM history WHERE user_identifier = ? AND project_id = ?; - """, (user_identifier, project_id)) - result = self.cursor.fetchone() - if not result or result[0] is None: - return [] - return dicts_to_messages(json.loads(result[0])) + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + SELECT messages FROM history WHERE user_identifier = ? AND project_id = ?; + """, (user_identifier, project_id)) + result = cursor.fetchone() + if not result or result[0] is None: + return [] + return dicts_to_messages(json.loads(result[0])) def set_history(self, user_identifier, project_id, messages): - self.cursor.execute(""" - INSERT INTO history (user_identifier, project_id, messages) VALUES (?, ?, ?) - ON CONFLICT (user_identifier, project_id) DO UPDATE SET - messages = EXCLUDED.messages, - updated = CURRENT_TIMESTAMP; - """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + INSERT INTO history (user_identifier, project_id, messages) VALUES (?, ?, ?) + ON CONFLICT (user_identifier, project_id) DO UPDATE SET + messages = EXCLUDED.messages, + updated = CURRENT_TIMESTAMP; + """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) + connection.commit() def reset_history(self, user_identifier, project_id): - self.cursor.execute(""" - DELETE FROM history WHERE user_identifier = ? AND project_id = ?; - """, (user_identifier, project_id)) - self.connection.commit() + with self.connect() as connection: + with connection.cursor() as cursor: + cursor.execute(""" + DELETE FROM history WHERE user_identifier = ? AND project_id = ?; + """, (user_identifier, project_id)) + connection.commit()