Skip to content

Commit cad188e

Browse files
abrookinsclaude
andcommitted
Fix authentication event loop corruption by converting get_current_user to async
The get_current_user() function was using asyncio.run() within FastAPI's async context, which creates a new event loop and causes "Event loop is closed" errors. This led to intermittent authentication failures where requests would alternate between success (200) and failure (500). Changes: - Convert get_current_user() from sync to async function - Replace asyncio.run(verify_token()) with await verify_token() - Update require_scope() and require_role() dependency functions to be async - Fix all related test cases to use await when calling these functions This resolves the issue where tool calls to the memory server would intermittently fail with 500 errors during token verification. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent a2fcd95 commit cad188e

File tree

3 files changed

+33
-32
lines changed

3 files changed

+33
-32
lines changed

agent_memory_server/auth.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ async def verify_token(token: str) -> UserInfo:
346346
) from e
347347

348348

349-
def get_current_user(
349+
async def get_current_user(
350350
credentials: HTTPAuthorizationCredentials | None = Depends(oauth2_scheme),
351351
) -> UserInfo:
352352
if settings.disable_auth or settings.auth_mode == "disabled":
@@ -371,17 +371,15 @@ def get_current_user(
371371

372372
# Determine authentication mode
373373
if settings.auth_mode == "token" or settings.token_auth_enabled:
374-
import asyncio
375-
376-
return asyncio.run(verify_token(credentials.credentials))
374+
return await verify_token(credentials.credentials)
377375
if settings.auth_mode == "oauth2":
378376
return verify_jwt(credentials.credentials)
379377
# Default to OAuth2 for backward compatibility
380378
return verify_jwt(credentials.credentials)
381379

382380

383381
def require_scope(required_scope: str):
384-
def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
382+
async def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
385383
if settings.disable_auth:
386384
return user
387385

@@ -397,7 +395,7 @@ def scope_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
397395

398396

399397
def require_role(required_role: str):
400-
def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
398+
async def role_dependency(user: UserInfo = Depends(get_current_user)) -> UserInfo:
401399
if settings.disable_auth:
402400
return user
403401

tests/test_auth.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ async def test_get_current_user_disabled_auth(self, mock_settings):
685685
"""Test get_current_user when authentication is disabled"""
686686
mock_settings.disable_auth = True
687687

688-
result = get_current_user(None)
688+
result = await get_current_user(None)
689689

690690
assert isinstance(result, UserInfo)
691691
assert result.sub == "local-dev-user"
@@ -700,7 +700,7 @@ async def test_get_current_user_missing_credentials(self, mock_settings):
700700
mock_settings.auth_mode = "oauth2"
701701

702702
with pytest.raises(HTTPException) as exc_info:
703-
get_current_user(None)
703+
await get_current_user(None)
704704

705705
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
706706
assert "Missing authorization header" in str(exc_info.value.detail)
@@ -717,7 +717,7 @@ async def test_get_current_user_empty_credentials(self, mock_settings):
717717
empty_creds = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")
718718

719719
with pytest.raises(HTTPException) as exc_info:
720-
get_current_user(empty_creds)
720+
await get_current_user(empty_creds)
721721

722722
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
723723
assert "Missing bearer token" in str(exc_info.value.detail)
@@ -736,7 +736,7 @@ async def test_get_current_user_valid_token(self, mock_settings, valid_token):
736736
expected_user = UserInfo(sub="test-user", email="[email protected]")
737737
mock_verify.return_value = expected_user
738738

739-
result = get_current_user(creds)
739+
result = await get_current_user(creds)
740740

741741
assert result == expected_user
742742
mock_verify.assert_called_once_with(valid_token)
@@ -753,7 +753,7 @@ async def test_require_scope_success(self, mock_settings):
753753
user = UserInfo(sub="test-user", scope="read write admin")
754754
scope_dependency = require_scope("read")
755755

756-
result = scope_dependency(user)
756+
result = await scope_dependency(user)
757757
assert result == user
758758

759759
@pytest.mark.asyncio
@@ -765,7 +765,7 @@ async def test_require_scope_failure(self, mock_settings):
765765
scope_dependency = require_scope("admin")
766766

767767
with pytest.raises(HTTPException) as exc_info:
768-
scope_dependency(user)
768+
await scope_dependency(user)
769769

770770
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
771771
assert "Insufficient permissions" in str(exc_info.value.detail)
@@ -780,7 +780,7 @@ async def test_require_scope_no_scope(self, mock_settings):
780780
scope_dependency = require_scope("read")
781781

782782
with pytest.raises(HTTPException) as exc_info:
783-
scope_dependency(user)
783+
await scope_dependency(user)
784784

785785
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
786786

@@ -792,7 +792,7 @@ async def test_require_scope_disabled_auth(self, mock_settings):
792792
user = UserInfo(sub="test-user", scope=None)
793793
scope_dependency = require_scope("admin")
794794

795-
result = scope_dependency(user)
795+
result = await scope_dependency(user)
796796
assert result == user
797797

798798
@pytest.mark.asyncio
@@ -803,7 +803,7 @@ async def test_require_role_success(self, mock_settings):
803803
user = UserInfo(sub="test-user", roles=["user", "admin"])
804804
role_dependency = require_role("admin")
805805

806-
result = role_dependency(user)
806+
result = await role_dependency(user)
807807
assert result == user
808808

809809
@pytest.mark.asyncio
@@ -815,7 +815,7 @@ async def test_require_role_failure(self, mock_settings):
815815
role_dependency = require_role("admin")
816816

817817
with pytest.raises(HTTPException) as exc_info:
818-
role_dependency(user)
818+
await role_dependency(user)
819819

820820
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
821821
assert "Insufficient permissions" in str(exc_info.value.detail)
@@ -830,7 +830,7 @@ async def test_require_role_no_roles(self, mock_settings):
830830
role_dependency = require_role("admin")
831831

832832
with pytest.raises(HTTPException) as exc_info:
833-
role_dependency(user)
833+
await role_dependency(user)
834834

835835
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
836836

@@ -842,7 +842,7 @@ async def test_require_role_disabled_auth(self, mock_settings):
842842
user = UserInfo(sub="test-user", roles=None)
843843
role_dependency = require_role("admin")
844844

845-
result = role_dependency(user)
845+
result = await role_dependency(user)
846846
assert result == user
847847

848848

tests/test_token_auth.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -189,59 +189,62 @@ async def test_verify_token_wrong_token(self, mock_redis, sample_token_info):
189189
class TestGetCurrentUser:
190190
"""Test get_current_user with token authentication."""
191191

192-
def test_get_current_user_disabled_auth(self, mock_settings):
192+
@pytest.mark.asyncio
193+
async def test_get_current_user_disabled_auth(self, mock_settings):
193194
"""Test get_current_user with disabled authentication."""
194195
mock_settings.disable_auth = True
195196
mock_settings.auth_mode = "disabled"
196197

197-
user_info = get_current_user(None)
198+
user_info = await get_current_user(None)
198199

199200
assert user_info.sub == "local-dev-user"
200201
assert user_info.aud == "local-dev"
201202

202-
def test_get_current_user_missing_credentials(self, mock_settings):
203+
@pytest.mark.asyncio
204+
async def test_get_current_user_missing_credentials(self, mock_settings):
203205
"""Test get_current_user with missing credentials."""
204206
mock_settings.disable_auth = False
205207
mock_settings.auth_mode = "token"
206208

207209
with pytest.raises(HTTPException) as exc_info:
208-
get_current_user(None)
210+
await get_current_user(None)
209211

210212
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
211213
assert "Missing authorization header" in exc_info.value.detail
212214

213-
def test_get_current_user_missing_token(self, mock_settings):
215+
@pytest.mark.asyncio
216+
async def test_get_current_user_missing_token(self, mock_settings):
214217
"""Test get_current_user with missing token."""
215218
mock_settings.disable_auth = False
216219
mock_settings.auth_mode = "token"
217220

218221
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials="")
219222

220223
with pytest.raises(HTTPException) as exc_info:
221-
get_current_user(credentials)
224+
await get_current_user(credentials)
222225

223226
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
224227
assert "Missing bearer token" in exc_info.value.detail
225228

226229
@patch("agent_memory_server.auth.verify_token")
227-
def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
230+
@pytest.mark.asyncio
231+
async def test_get_current_user_token_auth(self, mock_verify_token, mock_settings):
228232
"""Test get_current_user with token authentication."""
229233
mock_settings.disable_auth = False
230234
mock_settings.auth_mode = "token"
231235

232236
# Mock verify_token to return a user
233237
mock_user = Mock()
234238
mock_user.sub = "token-user"
239+
mock_verify_token.return_value = mock_user
235240

236-
# Mock asyncio.run to return the user directly
237-
with patch("asyncio.run", return_value=mock_user):
238-
credentials = HTTPAuthorizationCredentials(
239-
scheme="Bearer", credentials="test_token"
240-
)
241+
credentials = HTTPAuthorizationCredentials(
242+
scheme="Bearer", credentials="test_token"
243+
)
241244

242-
user_info = get_current_user(credentials)
245+
user_info = await get_current_user(credentials)
243246

244-
assert user_info.sub == "token-user"
247+
assert user_info.sub == "token-user"
245248

246249

247250
class TestAuthConfig:

0 commit comments

Comments
 (0)