Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions nicegui/persistence/redis_persistent_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@

class RedisPersistentDict(PersistentDict):

def __init__(self, *, url: str, id: str, key_prefix: str = 'nicegui:') -> None: # pylint: disable=redefined-builtin
def __init__(self, *,
url: str,
id: str, # pylint: disable=redefined-builtin
key_prefix: str = 'nicegui:',
ttl: int | None = None
) -> None:
if not optional_features.has('redis'):
raise ImportError('Redis is not installed. Please run "pip install nicegui[redis]".')
self.url = url
Expand All @@ -28,6 +33,7 @@ def __init__(self, *, url: str, id: str, key_prefix: str = 'nicegui:') -> None:
self.redis_client = redis.from_url(self.url, **self._redis_client_params)
self.pubsub = self.redis_client.pubsub()
self.key = key_prefix + id
self.ttl = ttl
self._listener_task: asyncio.Task | None = None
super().__init__(data={}, on_change=self.publish)

Expand Down Expand Up @@ -84,8 +90,9 @@ async def backup() -> None:
if not await self.redis_client.exists(self.key) and not self:
return
pipeline = self.redis_client.pipeline()
pipeline.set(self.key, json.dumps(self))
pipeline.publish(self.key + 'changes', json.dumps(self))
data = json.dumps(self)
pipeline.set(self.key, data, ex=self.ttl)
pipeline.publish(self.key + 'changes', data)
await pipeline.execute()
if core.loop:
background_tasks.create_lazy(backup(), name=f'redis-{self.key}')
Expand Down
16 changes: 11 additions & 5 deletions nicegui/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@

request_contextvar: contextvars.ContextVar[Request | None] = contextvars.ContextVar('request_var', default=None)

GENERAL_ID = 'general'
USER_PREFIX = 'user-'
TAB_PREFIX = 'tab-'
TTL_BUFFER_SECONDS = 20 # Buffer to avoid race with prune_tab_storage polling


class RequestTrackingMiddleware(BaseHTTPMiddleware):

Expand Down Expand Up @@ -64,14 +69,15 @@ class Storage:
'''Maximum age in seconds before tab storage is automatically purged. Defaults to 30 days.'''

def __init__(self) -> None:
self._general = Storage._create_persistent_dict('general')
self._general = Storage._create_persistent_dict(GENERAL_ID)
self._users: dict[str, PersistentDict] = {}
self._tabs: dict[str, ObservableDict] = {}

@staticmethod
def _create_persistent_dict(id: str) -> PersistentDict: # pylint: disable=redefined-builtin
if Storage.redis_url:
return RedisPersistentDict(url=Storage.redis_url, id=id, key_prefix=Storage.redis_key_prefix)
ttl = int(core.app.storage.max_tab_storage_age + TTL_BUFFER_SECONDS) if id.startswith(TAB_PREFIX) else None
return RedisPersistentDict(url=Storage.redis_url, id=id, key_prefix=Storage.redis_key_prefix, ttl=ttl)
else:
return FilePersistentDict(Storage.path / f'storage-{id}.json', encoding='utf-8')

Expand Down Expand Up @@ -116,7 +122,7 @@ def user(self) -> PersistentDict:
return self._users[session_id]

async def _create_user_storage(self, session_id: str) -> None:
self._users[session_id] = Storage._create_persistent_dict(f'user-{session_id}')
self._users[session_id] = Storage._create_persistent_dict(f'{USER_PREFIX}{session_id}')
await self._users[session_id].initialize()

@property
Expand Down Expand Up @@ -148,7 +154,7 @@ async def _create_tab_storage(self, tab_id: str) -> None:
"""Create tab storage for the given tab ID."""
if tab_id not in self._tabs:
if Storage.redis_url:
self._tabs[tab_id] = Storage._create_persistent_dict(f'tab-{tab_id}')
self._tabs[tab_id] = Storage._create_persistent_dict(f'{TAB_PREFIX}{tab_id}')
tab = self._tabs[tab_id]
assert isinstance(tab, PersistentDict)
await tab.initialize()
Expand All @@ -159,7 +165,7 @@ def copy_tab(self, old_tab_id: str, tab_id: str) -> None:
"""Copy the tab storage to a new tab. (For internal use only.)"""
if old_tab_id in self._tabs:
if Storage.redis_url:
self._tabs[tab_id] = Storage._create_persistent_dict(f'tab-{tab_id}')
self._tabs[tab_id] = Storage._create_persistent_dict(f'{TAB_PREFIX}{tab_id}')
else:
self._tabs[tab_id] = ObservableDict()
self._tabs[tab_id].update(self._tabs[old_tab_id])
Expand Down