Skip to content

Commit c181037

Browse files
authored
RSDK-10208 - create app client from env vars (#874)
1 parent 40a4361 commit c181037

File tree

2 files changed

+65
-2
lines changed

2 files changed

+65
-2
lines changed

src/viam/app/viam_client.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import Mapping, Optional
23

34
from grpclib.client import Channel
@@ -10,7 +11,7 @@
1011
from viam.app.ml_training_client import MLTrainingClient
1112
from viam.app.provisioning_client import ProvisioningClient
1213
from viam.robot.client import RobotClient
13-
from viam.rpc.dial import DialOptions, _dial_app, _get_access_token
14+
from viam.rpc.dial import Credentials, DialOptions, _dial_app, _get_access_token
1415

1516
LOGGER = logging.getLogger(__name__)
1617

@@ -24,14 +25,46 @@ class ViamClient:
2425
ViamClient.create_from_dial_options(...)
2526
"""
2627

28+
@classmethod
29+
async def create_from_env_vars(cls, dial_options: Optional[DialOptions] = None, app_url: Optional[str] = None) -> Self:
30+
"""Create `ViamClient` using credentials set in the environment as `VIAM_API_KEY` and `VIAM_API_KEY_ID`.
31+
32+
::
33+
34+
client = await ViamClient.create_from_env_vars()
35+
36+
Args:
37+
dial_options (Optional[viam.rpc.dial.DialOptions]): Options for authorization and connection to app.
38+
If not provided, default options will be selected. Note that `creds` and `auth_entity`
39+
fields will be overwritten by the values set by a module.
40+
app_url: (Optional[str]): URL of app. Uses app.viam.com if not specified.
41+
42+
Raises:
43+
ValueError: If there are no env vars set by the module, or if they are set improperly
44+
45+
"""
46+
dial_options = dial_options if dial_options else DialOptions()
47+
api_key = os.environ.get("VIAM_API_KEY")
48+
if api_key is None:
49+
raise ValueError("api key cannot be None")
50+
api_key_id = os.environ.get("VIAM_API_KEY_ID")
51+
if api_key_id is None:
52+
raise ValueError("api key ID cannot be None")
53+
credentials = Credentials(type="api-key", payload=api_key)
54+
dial_options.credentials = credentials
55+
dial_options.auth_entity = api_key_id
56+
57+
return await cls.create_from_dial_options(dial_options, app_url)
58+
59+
2760
@classmethod
2861
async def create_from_dial_options(cls, dial_options: DialOptions, app_url: Optional[str] = None) -> Self:
2962
"""Create `ViamClient` that establishes a connection to the Viam app.
3063
3164
::
3265
3366
dial_options = DialOptions.with_api_key("<API-KEY>", "<API-KEY-ID>")
34-
ViamClient.create_from_dial_options(dial_options)
67+
client = await ViamClient.create_from_dial_options(dial_options)
3568
3669
Args:
3770
dial_options (viam.rpc.dial.DialOptions): Required information for authorization and connection to app.

tests/test_viam_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
import os
23
from unittest.mock import patch
34
from uuid import uuid4
45

@@ -73,6 +74,35 @@ async def test_clients(self):
7374
assert client.provisioning_client._channel == channel
7475
assert client.provisioning_client._metadata == METADATA
7576

77+
async def test_client_from_env_vars(self):
78+
async with ChannelFor([]) as channel:
79+
with patch("viam.app.viam_client._dial_app") as patched_dial:
80+
patched_dial.return_value = channel
81+
with patch("viam.app.viam_client._get_access_token") as patched_auth:
82+
ACCESS_TOKEN = "MY_ACCESS_TOKEN"
83+
METADATA = {"authorization": f"Bearer {ACCESS_TOKEN}"}
84+
patched_auth.return_value = ACCESS_TOKEN
85+
86+
os.environ["VIAM_API_KEY"] = "MY_API_KEY"
87+
os.environ["VIAM_API_KEY_ID"] = str(uuid4())
88+
89+
client = await ViamClient.create_from_env_vars()
90+
91+
assert client.data_client._channel == channel
92+
assert client.data_client._metadata == METADATA
93+
94+
assert client.app_client._channel == channel
95+
assert client.app_client._metadata == METADATA
96+
97+
assert client.ml_training_client._channel == channel
98+
assert client.ml_training_client._metadata == METADATA
99+
100+
assert client.billing_client._channel == channel
101+
assert client.billing_client._metadata == METADATA
102+
103+
assert client.provisioning_client._channel == channel
104+
assert client.provisioning_client._metadata == METADATA
105+
76106
async def test_closes(self):
77107
async with ChannelFor([]) as channel:
78108
with patch.object(channel, "close"):

0 commit comments

Comments
 (0)