diff --git a/rdmo_chatbot/chatbot/stores/mysql.py b/rdmo_chatbot/chatbot/stores/mysql.py index 28ae075..e4f73fd 100644 --- a/rdmo_chatbot/chatbot/stores/mysql.py +++ b/rdmo_chatbot/chatbot/stores/mysql.py @@ -1,6 +1,7 @@ import json import MySQLdb +from MySQLdb import OperationalError from ..utils import dicts_to_messages, get_config, messages_to_dicts from . import BaseStore @@ -10,12 +11,34 @@ class MysqlStore(BaseStore): def __init__(self): + self.connection = None + self.cursor = None + self._connect() + self.create_table() + + def _connect(self): self.connection = MySQLdb.connect(**config.STORE_CONNECTION) + self.connection.autocommit(True) self.cursor = self.connection.cursor() - self.create_table() + + def _ensure_connection(self): + try: + self.connection.ping(True) + except Exception: + self._connect() + + def _execute(self, sql, params=None): + self._ensure_connection() + try: + return self.cursor.execute(sql, params or ()) + except OperationalError as exc: + if exc.args and exc.args[0] in (2006, 2013): + self._connect() + return self.cursor.execute(sql, params or ()) + raise def create_table(self): - self.cursor.execute(""" + self._execute(""" CREATE TABLE IF NOT EXISTS history ( id INT AUTO_INCREMENT PRIMARY KEY, user_identifier VARCHAR(150), @@ -26,10 +49,9 @@ def create_table(self): UNIQUE KEY unique_user_project (user_identifier, project_id) ); """) - self.connection.commit() def has_history(self, user_identifier, project_id): - self.cursor.execute(""" + self._execute(""" SELECT count(*) FROM history WHERE user_identifier = %s AND project_id = %s; """, (user_identifier, project_id) ) @@ -37,7 +59,7 @@ def has_history(self, user_identifier, project_id): return result[0] > 0 if result else False def get_history(self, user_identifier, project_id): - self.cursor.execute(""" + self._execute(""" SELECT messages FROM history WHERE user_identifier = %s AND project_id = %s; """, (user_identifier, project_id) ) @@ -45,17 +67,15 @@ def get_history(self, user_identifier, project_id): return dicts_to_messages(json.loads(result[0])) if result else [] def set_history(self, user_identifier, project_id, messages): - self.cursor.execute(""" + self._execute(""" INSERT INTO history (user_identifier, project_id, messages, created) VALUES (%s, %s, %s, CURRENT_TIMESTAMP) ON DUPLICATE KEY UPDATE messages = VALUES(messages), updated = CURRENT_TIMESTAMP; """, (user_identifier, project_id, json.dumps(messages_to_dicts(messages)))) - self.connection.commit() def reset_history(self, user_identifier, project_id): - self.cursor.execute(""" + self._execute(""" DELETE FROM history WHERE user_identifier = %s AND project_id = %s; """, [user_identifier, project_id] ) - self.connection.commit()