Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/praisonai/praisonai/bots/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ async def process_inbound_telegram_message(

# 2. User allowlist and pairing check
user_id = message.sender.user_id if message.sender else ""
is_explicitly_allowed = bot.config.is_user_allowed(user_id)
is_explicitly_allowed = bool(bot.config.allowed_users) and bot.config.is_user_allowed(user_id)

if not is_explicitly_allowed:
# Check if bot context is available for pairing system
Expand Down
206 changes: 162 additions & 44 deletions src/praisonai/praisonai/cli/session/unified.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ class UnifiedSession:
total_cost: float = 0.0
request_count: int = 0

# Track baseline values for proper delta merging (not persisted)
_baseline_input_tokens: int = field(default=0, init=False, repr=False)
_baseline_output_tokens: int = field(default=0, init=False, repr=False)
_baseline_cost: float = field(default=0.0, init=False, repr=False)
_baseline_request_count: int = field(default=0, init=False, repr=False)

# Model info
current_model: str = "gpt-4o-mini"

Expand Down Expand Up @@ -89,19 +95,45 @@ def update_stats(self, input_tokens: int, output_tokens: int, cost: float = 0.0)
self.request_count += 1
self.updated_at = datetime.now().isoformat()

def set_baseline_stats(self) -> None:
    """Snapshot the current usage totals as the baseline for delta tracking.

    ``get_stat_deltas`` reports how far the totals have moved since the
    most recent call to this method.
    """
    (
        self._baseline_input_tokens,
        self._baseline_output_tokens,
        self._baseline_cost,
        self._baseline_request_count,
    ) = (
        self.total_input_tokens,
        self.total_output_tokens,
        self.total_cost,
        self.request_count,
    )

def get_stat_deltas(self) -> Dict[str, int | float]:
"""Get deltas from baseline for proper merge."""
return {
"input_tokens": self.total_input_tokens - self._baseline_input_tokens,
"output_tokens": self.total_output_tokens - self._baseline_output_tokens,
"cost": self.total_cost - self._baseline_cost,
"request_count": self.request_count - self._baseline_request_count,
}

def clear_messages(self) -> None:
    """Empty the message list in place and refresh ``updated_at``."""
    # Slice deletion empties the existing list object rather than rebinding it.
    del self.messages[:]
    self.updated_at = datetime.now().isoformat()

def to_dict(self) -> Dict[str, Any]:
    """Convert session to dictionary, excluding internal baseline fields.

    The ``_baseline_*`` attributes only track deltas inside the running
    process and must not be persisted, so they are removed from the result
    of ``dataclasses.asdict``.

    Returns:
        A dict of the session's fields without any key that starts with
        ``_baseline_``.
    """
    # A stale early ``return asdict(self)`` used to sit here, which made the
    # filtering unreachable and leaked the baseline fields into saved data.
    return {
        key: value
        for key, value in asdict(self).items()
        if not key.startswith('_baseline_')
    }

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "UnifiedSession":
    """Create session from dictionary.

    Args:
        data: Mapping of field names to values, typically produced by
            ``to_dict``. Keys starting with ``_baseline_`` are ignored.

    Returns:
        A new instance whose delta baseline equals its loaded totals.
    """
    # A stale early ``return cls(**data)`` used to sit here. It passed any
    # leaked ``_baseline_*`` keys straight to the constructor and skipped
    # baseline initialization, so strip those keys before constructing.
    persisted = {
        key: value
        for key, value in data.items()
        if not key.startswith('_baseline_')
    }
    session = cls(**persisted)
    # Start delta tracking from the values that were just loaded.
    session.set_baseline_stats()
    return session

@property
def message_count(self) -> int:
Expand Down Expand Up @@ -131,6 +163,7 @@ def __init__(self, session_dir: Optional[Path] = None):
self.session_dir = Path(session_dir) if session_dir else DEFAULT_SESSION_DIR
self.session_dir.mkdir(parents=True, exist_ok=True)
self._cache: Dict[str, UnifiedSession] = {}
self._cache_mtime: Dict[str, float] = {}
self._last_session_id: Optional[str] = None

def _get_session_path(self, session_id: str) -> Path:
Expand All @@ -140,6 +173,93 @@ def _get_session_path(self, session_id: str) -> Path:
def _get_last_session_path(self) -> Path:
"""Get the path to the last session marker file."""
return self.session_dir / ".last_session"

@staticmethod
def _messages_common_prefix(
left: List[Dict[str, str]], right: List[Dict[str, str]]
) -> int:
"""Return shared message prefix length for safe concurrent merge."""
prefix = 0
for left_msg, right_msg in zip(left, right, strict=False):
if left_msg.get("role") != right_msg.get("role"):
break
if left_msg.get("content") != right_msg.get("content"):
break
prefix += 1
return prefix

def _parse_session_file(self, f) -> Optional[UnifiedSession]:
    """Decode a session from an already-open binary file handle.

    The handle is rewound before reading. An empty file yields ``None``.
    Any failure while reading, decoding or constructing the session is
    logged and also yields ``None``.
    """
    try:
        f.seek(0)
        payload = f.read()
        if payload:
            return UnifiedSession.from_dict(json.loads(payload.decode('utf-8')))
        return None
    except Exception as e:
        logger.error(f"Failed to parse session file: {e}")
        return None

def _read_session_from_file(self, path: Path) -> Optional[UnifiedSession]:
    """Load a session straight from ``path``, bypassing the in-process cache.

    The file is read under a shared/read lock where the platform offers one
    (``msvcrt`` on Windows, ``fcntl`` when available) and without locking
    otherwise. Returns ``None`` when the file is missing or cannot be read.
    """
    if not path.exists():
        return None

    try:
        with open(path, 'rb') as handle:
            if sys.platform == "win32":
                import msvcrt
                handle.seek(0)
                msvcrt.locking(handle.fileno(), msvcrt.LK_RLCK, 1)
                try:
                    return self._parse_session_file(handle)
                finally:
                    # Rewind so the unlock covers the same byte that was locked.
                    handle.seek(0)
                    msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)

            if _HAS_FCNTL:
                fcntl.flock(handle.fileno(), fcntl.LOCK_SH)
                try:
                    return self._parse_session_file(handle)
                finally:
                    fcntl.flock(handle.fileno(), fcntl.LOCK_UN)

            # No locking primitive available: plain read.
            return self._parse_session_file(handle)
    except Exception as e:
        logger.error(f"Failed to read session file {path}: {e}")
        return None

def _merge_sessions(
self, disk_session: Optional[UnifiedSession], incoming: UnifiedSession
) -> UnifiedSession:
"""Merge incoming session updates without clobbering concurrent writes."""
if disk_session is None:
return incoming

merged = UnifiedSession.from_dict(disk_session.to_dict())

# Use prefix-based merge for append-only scenarios (original design)
prefix = self._messages_common_prefix(disk_session.messages, incoming.messages)
merged.messages = disk_session.messages + incoming.messages[prefix:]

# Merge stats using deltas instead of max()
incoming_deltas = incoming.get_stat_deltas()
merged.total_input_tokens += max(0, incoming_deltas["input_tokens"])
merged.total_output_tokens += max(0, incoming_deltas["output_tokens"])
merged.total_cost += max(0.0, incoming_deltas["cost"])
merged.request_count += max(0, incoming_deltas["request_count"])

# Update other fields with incoming values if present
if incoming.current_model:
merged.current_model = incoming.current_model
if incoming.metadata:
merged.metadata.update(incoming.metadata)
if incoming.workspace:
merged.workspace = incoming.workspace

return merged

def save(self, session: UnifiedSession) -> None:
"""
Expand All @@ -149,7 +269,6 @@ def save(self, session: UnifiedSession) -> None:
session: Session to save
"""
path = self._get_session_path(session.session_id)
session.updated_at = datetime.now().isoformat()

try:
# Open in r+b mode to avoid truncation before locking
Expand All @@ -165,25 +284,33 @@ def save(self, session: UnifiedSession) -> None:
f.seek(0)
msvcrt.locking(f.fileno(), msvcrt.LK_LOCK, 1) # Use blocking lock
try:
disk_session = self._parse_session_file(f)
merged = self._merge_sessions(disk_session, session)
merged.updated_at = datetime.now().isoformat()
f.seek(0)
f.truncate() # Clear file after acquiring lock
json_data = json.dumps(session.to_dict(), indent=2).encode('utf-8')
json_data = json.dumps(merged.to_dict(), indent=2).encode('utf-8')
f.write(json_data)
f.flush()
os.fsync(f.fileno()) # Force data to disk before unlock
session = merged
finally:
f.seek(0) # Return to start position for unlock
msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
elif _HAS_FCNTL:
# Unix locking
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
try:
disk_session = self._parse_session_file(f)
merged = self._merge_sessions(disk_session, session)
merged.updated_at = datetime.now().isoformat()
f.seek(0)
f.truncate() # Clear file after acquiring lock
json_data = json.dumps(session.to_dict(), indent=2).encode('utf-8')
json_data = json.dumps(merged.to_dict(), indent=2).encode('utf-8')
f.write(json_data)
f.flush()
os.fsync(f.fileno()) # Force data to disk before unlock
session = merged
finally:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
else:
Expand All @@ -195,13 +322,24 @@ def save(self, session: UnifiedSession) -> None:
"concurrent writers may corrupt session files."
)
_WARNED_NO_FCNTL = True
disk_session = self._parse_session_file(f)
merged = self._merge_sessions(disk_session, session)
merged.updated_at = datetime.now().isoformat()
f.seek(0)
f.truncate()
json_data = json.dumps(session.to_dict(), indent=2).encode('utf-8')
json_data = json.dumps(merged.to_dict(), indent=2).encode('utf-8')
f.write(json_data)
session = merged

# Update cache
self._cache[session.session_id] = session
# Safely update mtime cache with error handling
try:
if path.exists():
self._cache_mtime[session.session_id] = path.stat().st_mtime
except (FileNotFoundError, OSError):
# File was deleted/moved between write and stat, skip mtime update
pass

# Update last session marker
self._update_last_session(session.session_id)
Expand All @@ -221,48 +359,25 @@ def load(self, session_id: str) -> Optional[UnifiedSession]:
Returns:
Session if found, None otherwise
"""
# Check cache first
if session_id in self._cache:
return self._cache[session_id]

path = self._get_session_path(session_id)
if not path.exists():
self._cache.pop(session_id, None)
self._cache_mtime.pop(session_id, None)
return None

try:
with open(path, 'rb') as f:
# Cross-platform file locking (shared lock for reading)
if sys.platform == "win32":
# Windows shared locking (read-only) - use blocking read lock
import msvcrt
f.seek(0)
msvcrt.locking(f.fileno(), msvcrt.LK_RLCK, 1) # Use shared/read lock
try:
json_data = f.read().decode('utf-8')
data = json.loads(json_data)
finally:
f.seek(0) # Return to start position for unlock
msvcrt.locking(f.fileno(), msvcrt.LK_UNLCK, 1)
elif _HAS_FCNTL:
# Unix shared locking
fcntl.flock(f.fileno(), fcntl.LOCK_SH)
try:
json_data = f.read().decode('utf-8')
data = json.loads(json_data)
finally:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
else:
# No locking available - just read
json_data = f.read().decode('utf-8')
data = json.loads(json_data)

session = UnifiedSession.from_dict(data)

session = self._read_session_from_file(path)
if session is not None:
# Set baseline stats for proper delta tracking
session.set_baseline_stats()
self._cache[session_id] = session
# Safely update mtime cache
try:
self._cache_mtime[session_id] = path.stat().st_mtime
except (FileNotFoundError, OSError):
# File was deleted/moved after read, skip mtime update
pass
logger.debug(f"Loaded session: {session_id}")
return session
except Exception as e:
logger.error(f"Failed to load session {session_id}: {e}")
return None
return session

def get_or_create(self, session_id: Optional[str] = None) -> UnifiedSession:
"""
Expand All @@ -282,6 +397,8 @@ def get_or_create(self, session_id: Optional[str] = None) -> UnifiedSession:
# Create new session
new_id = session_id or str(uuid.uuid4())[:8]
session = UnifiedSession(session_id=new_id)
# Set baseline stats for new session
session.set_baseline_stats()
self.save(session)
return session

Expand All @@ -299,6 +416,7 @@ def delete(self, session_id: str) -> bool:
if path.exists():
path.unlink()
self._cache.pop(session_id, None)
self._cache_mtime.pop(session_id, None)
logger.debug(f"Deleted session: {session_id}")
return True
return False
Expand Down
24 changes: 24 additions & 0 deletions src/praisonai/tests/unit/cli/test_unified_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,30 @@ def test_load_nonexistent(self, temp_session_dir):

assert session is None

def test_stale_cache_write_preserves_concurrent_updates(self, temp_session_dir):
    """Saving a stale copy must keep messages another store wrote meanwhile."""
    # Two independent stores share one directory, standing in for two processes.
    first_store = UnifiedSessionStore(session_dir=temp_session_dir)
    second_store = UnifiedSessionStore(session_dir=temp_session_dir)

    seed = UnifiedSession(session_id="shared")
    seed.add_user_message("warm cache")
    first_store.save(seed)
    # The second store takes its copy before the first store writes again.
    stale_copy = second_store.load("shared")

    fresh_copy = first_store.load("shared")
    fresh_copy.add_user_message("from writer")
    fresh_copy.add_assistant_message("writer reply")
    first_store.save(fresh_copy)

    # Saving the stale copy must append to, not replace, what is on disk.
    stale_copy.add_user_message("from reader")
    stale_copy.add_assistant_message("reader reply")
    second_store.save(stale_copy)

    result = first_store.load("shared")
    assert len(result.messages) == 5
    assert result.messages[1]["content"] == "from writer"
    assert result.messages[3]["content"] == "from reader"


class TestGlobalSessionStore:
"""Tests for global session store."""
Expand Down
Loading
Loading