Skip to content
This repository was archived by the owner on Sep 8, 2025. It is now read-only.

Commit df3f69e

Browse files
feat: change .update method to also allow dictionaries (#130)
* Allow users to send a dict instead of UserAttributes model * Add tests * Check Python version before trying to import TypedDict * Remove `TypedDict` from the main `typing` import * Change format * Change format of types * 'Refactored by Sourcery' Co-authored-by: odiseo0 <[email protected]> Co-authored-by: Sourcery AI <>
1 parent 76fad56 commit df3f69e

File tree

5 files changed

+83
-12
lines changed

5 files changed

+83
-12
lines changed

gotrue/_async/client.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Subscription,
1919
User,
2020
UserAttributes,
21+
UserAttributesDict,
2122
)
2223
from .api import AsyncGoTrueAPI
2324
from .storage import AsyncMemoryStorage, AsyncSupportedStorage
@@ -301,13 +302,15 @@ async def refresh_session(self) -> Session:
301302
raise ValueError("Not logged in.")
302303
return await self._call_refresh_token()
303304

304-
async def update(self, *, attributes: UserAttributes) -> User:
305+
async def update(
306+
self, *, attributes: Union[UserAttributesDict, UserAttributes]
307+
) -> User:
305308
"""Updates user data, if there is a logged in user.
306309
307310
Parameters
308311
----------
309-
attributes : UserAttributes
310-
The attributes to update.
312+
attributes : UserAttributesDict | UserAttributes
313+
Attributes to update, could be: email, password, email_change_token, data
311314
312315
Returns
313316
-------
@@ -321,9 +324,15 @@ async def update(self, *, attributes: UserAttributes) -> User:
321324
"""
322325
if not self.current_session:
323326
raise ValueError("Not logged in.")
327+
328+
if isinstance(attributes, dict):
329+
attributes_to_update = UserAttributes(**attributes)
330+
else:
331+
attributes_to_update = attributes
332+
324333
response = await self.api.update_user(
325334
jwt=self.current_session.access_token,
326-
attributes=attributes,
335+
attributes=attributes_to_update,
327336
)
328337
self.current_session.user = response
329338
await self._save_session(session=self.current_session)

gotrue/_sync/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Subscription,
1919
User,
2020
UserAttributes,
21+
UserAttributesDict,
2122
)
2223
from .api import SyncGoTrueAPI
2324
from .storage import SyncMemoryStorage, SyncSupportedStorage
@@ -299,7 +300,7 @@ def refresh_session(self) -> Session:
299300
raise ValueError("Not logged in.")
300301
return self._call_refresh_token()
301302

302-
def update(self, *, attributes: UserAttributes) -> User:
303+
def update(self, *, attributes: Union[UserAttributesDict, UserAttributes]) -> User:
303304
"""Updates user data, if there is a logged in user.
304305
305306
Parameters
@@ -319,9 +320,15 @@ def update(self, *, attributes: UserAttributes) -> User:
319320
"""
320321
if not self.current_session:
321322
raise ValueError("Not logged in.")
323+
324+
if isinstance(attributes, dict):
325+
attributes_to_update = UserAttributes(**attributes)
326+
else:
327+
attributes_to_update = attributes
328+
322329
response = self.api.update_user(
323330
jwt=self.current_session.access_token,
324-
attributes=attributes,
331+
attributes=attributes_to_update,
325332
)
326333
self.current_session.user = response
327334
self._save_session(session=self.current_session)

gotrue/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
from __future__ import annotations
22

3+
import sys
34
from datetime import datetime
45
from enum import Enum
56
from time import time
67
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
78
from uuid import UUID
89

10+
if sys.version_info >= (3, 8):
11+
from typing import TypedDict
12+
else:
13+
from typing_extensions import TypedDict
14+
915
from httpx import Response
1016
from pydantic import BaseModel, root_validator
1117

@@ -150,3 +156,12 @@ class LinkType(str, Enum):
150156
magiclink = "magiclink"
151157
recovery = "recovery"
152158
invite = "invite"
159+
160+
161+
class UserAttributesDict(TypedDict):
162+
"""Dict version of `UserAttributes`"""
163+
164+
email: Optional[str]
165+
password: Optional[str]
166+
email_change_token: Optional[str]
167+
data: Optional[Any]

tests/_async/test_client_with_auto_confirm_enabled.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,25 @@ async def test_update_user(client: AsyncGoTrueClient):
259259
assert False, str(e)
260260

261261

262+
@pytest.mark.asyncio
263+
@pytest.mark.depends(on=[test_sign_in.__name__])
264+
async def test_update_user_dict(client: AsyncGoTrueClient):
265+
try:
266+
await client.init_recover()
267+
response = await client.update(attributes={"data": {"hello": "world"}})
268+
assert isinstance(response, User)
269+
assert response.id
270+
assert response.email == email
271+
assert response.email_confirmed_at
272+
assert response.last_sign_in_at
273+
assert response.created_at
274+
assert response.updated_at
275+
assert response.user_metadata
276+
assert response.user_metadata.get("hello") == "world"
277+
except Exception as e:
278+
assert False, str(e)
279+
280+
262281
@pytest.mark.asyncio
263282
@pytest.mark.depends(on=[test_update_user.__name__])
264283
async def test_get_user_after_update(client: AsyncGoTrueClient):
@@ -319,10 +338,10 @@ async def test_get_update_user_after_sign_out(client: AsyncGoTrueClient):
319338
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
320339
async def test_sign_in_with_the_wrong_password(client: AsyncGoTrueClient):
321340
try:
322-
await client.sign_in(email=email, password=password + "2")
341+
await client.sign_in(email=email, password=f"{password}2")
323342
assert False
324343
except APIError:
325-
assert True
344+
pass
326345
except Exception as e:
327346
assert False, str(e)
328347

@@ -401,8 +420,9 @@ async def test_get_session_from_url_errors(client: AsyncGoTrueClient):
401420
error_description = fake.email()
402421
try:
403422
await client.get_session_from_url(
404-
url=dummy_url + f"?error_description={error_description}"
423+
url=f"{dummy_url}?error_description={error_description}"
405424
)
425+
406426
assert False
407427
except APIError as e:
408428
assert e.code == 400

tests/_sync/test_client_with_auto_confirm_enabled.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,25 @@ def test_update_user(client: SyncGoTrueClient):
253253
assert False, str(e)
254254

255255

256+
@pytest.mark.asyncio
257+
@pytest.mark.depends(on=[test_sign_in.__name__])
258+
def test_update_user(client: SyncGoTrueClient):
259+
try:
260+
client.init_recover()
261+
response = client.update(attributes={"data": {"hello": "world"}})
262+
assert isinstance(response, User)
263+
assert response.id
264+
assert response.email == email
265+
assert response.email_confirmed_at
266+
assert response.last_sign_in_at
267+
assert response.created_at
268+
assert response.updated_at
269+
assert response.user_metadata
270+
assert response.user_metadata.get("hello") == "world"
271+
except Exception as e:
272+
assert False, str(e)
273+
274+
256275
@pytest.mark.asyncio
257276
@pytest.mark.depends(on=[test_update_user.__name__])
258277
def test_get_user_after_update(client: SyncGoTrueClient):
@@ -313,10 +332,10 @@ def test_get_update_user_after_sign_out(client: SyncGoTrueClient):
313332
@pytest.mark.depends(on=[test_get_user_after_sign_out.__name__])
314333
def test_sign_in_with_the_wrong_password(client: SyncGoTrueClient):
315334
try:
316-
client.sign_in(email=email, password=password + "2")
335+
client.sign_in(email=email, password=f"{password}2")
317336
assert False
318337
except APIError:
319-
assert True
338+
pass
320339
except Exception as e:
321340
assert False, str(e)
322341

@@ -395,8 +414,9 @@ def test_get_session_from_url_errors(client: SyncGoTrueClient):
395414
error_description = fake.email()
396415
try:
397416
client.get_session_from_url(
398-
url=dummy_url + f"?error_description={error_description}"
417+
url=f"{dummy_url}?error_description={error_description}"
399418
)
419+
400420
assert False
401421
except APIError as e:
402422
assert e.code == 400

0 commit comments

Comments
 (0)