Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions rdmo_chatbot/chatbot/stores/mysql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json

import MySQLdb
from MySQLdb import OperationalError

from ..utils import dicts_to_messages, get_config, messages_to_dicts
from . import BaseStore
Expand All @@ -10,12 +11,34 @@
class MysqlStore(BaseStore):

def __init__(self):
self.connection = None
self.cursor = None
self._connect()
self.create_table()

def _connect(self):
self.connection = MySQLdb.connect(**config.STORE_CONNECTION)
self.connection.autocommit(True)
self.cursor = self.connection.cursor()
self.create_table()

def _ensure_connection(self):
try:
self.connection.ping(True)
except Exception:
self._connect()

def _execute(self, sql, params=None):
self._ensure_connection()
try:
return self.cursor.execute(sql, params or ())
except OperationalError as exc:
if exc.args and exc.args[0] in (2006, 2013):
self._connect()
return self.cursor.execute(sql, params or ())
raise

def create_table(self):
self.cursor.execute("""
self._execute("""
CREATE TABLE IF NOT EXISTS history (
id INT AUTO_INCREMENT PRIMARY KEY,
user_identifier VARCHAR(150),
Expand All @@ -26,36 +49,33 @@ def create_table(self):
UNIQUE KEY unique_user_project (user_identifier, project_id)
);
""")
self.connection.commit()

def has_history(self, user_identifier, project_id):
self.cursor.execute("""
self._execute("""
SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s;
""", (user_identifier, project_id)
)
result = self.cursor.fetchone()
return result[0] > 0 if result else False

def get_history(self, user_identifier, project_id):
self.cursor.execute("""
self._execute("""
SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s;
""", (user_identifier, project_id)
)
result = self.cursor.fetchone()
return dicts_to_messages(json.loads(result[0])) if result else []

def set_history(self, user_identifier, project_id, messages):
self.cursor.execute("""
self._execute("""
INSERT INTO history (user_identifier, project_id, messages, created) VALUES (%s, %s, %s, CURRENT_TIMESTAMP)
ON DUPLICATE KEY UPDATE
messages = VALUES(messages),
updated = CURRENT_TIMESTAMP;
""", (user_identifier, project_id, json.dumps(messages_to_dicts(messages))))
self.connection.commit()

def reset_history(self, user_identifier, project_id):
self.cursor.execute("""
self._execute("""
DELETE FROM history WHERE user_identifier = %s AND project_id = %s;
""", [user_identifier, project_id]
)
self.connection.commit()