Skip to content

Commit cbe5ce2

Browse files
author
Rohan Jadvani
authored
Allow optional provider parameter when building authorization URL (#10)
* Allow optional provider parameter when building authorization URL * Update README * Exception -> ValueError * Add connection type enum * Fix formatting * Send provider and domain * Update documentation
1 parent 348155a commit cbe5ce2

File tree

4 files changed

+84
-7
lines changed

4 files changed

+84
-7
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ from workos import client
3030

3131
# URL to redirect a User to to initiate the WorkOS OAuth 2.0 workflow
3232
client.sso.get_authorization_url(
33-
'customer-domain.com',
34-
'my-domain.com/auth/callback',
33+
domain='customer-domain.com',
34+
redirect_uri='my-domain.com/auth/callback',
3535
state={
3636
'stuff': 'from_the_original_request',
3737
'more_things': 'ill_get_it_all_back_when_oauth_is_complete',

tests/test_sso.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
import workos
77
from workos.sso import SSO
8+
from workos.utils.connection_types import ConnectionType
89
from workos.utils.request import RESPONSE_TYPE_CODE
910

1011

1112
class TestSSO(object):
1213
@pytest.fixture(autouse=True)
1314
def setup(self, set_api_key_and_project_id):
15+
self.provider = ConnectionType.GoogleOAuth
1416
self.customer_domain = "workos.com"
1517
self.redirect_uri = "https://localhost/auth/callback"
1618
self.state = {
@@ -30,15 +32,67 @@ def mock_profile(self):
3032
"idp_id": "00u1klkowm8EGah2H357",
3133
}
3234

33-
def test_authorization_url_has_expected_query_params(self):
35+
def test_authorization_url_throws_value_error_with_missing_domain_and_provider(
36+
self,
37+
):
38+
with pytest.raises(ValueError, match=r"Incomplete arguments.*"):
39+
self.sso.get_authorization_url(
40+
redirect_uri=self.redirect_uri, state=self.state
41+
)
42+
43+
def test_authorization_url_throws_value_error_with_incorrect_provider_type(self):
44+
with pytest.raises(
45+
ValueError, match="'provider' must be of type ConnectionType"
46+
):
47+
self.sso.get_authorization_url(
48+
provider="foo", redirect_uri=self.redirect_uri, state=self.state
49+
)
50+
51+
def test_authorization_url_has_expected_query_params_with_provider(self):
52+
authorization_url = self.sso.get_authorization_url(
53+
provider=self.provider, redirect_uri=self.redirect_uri, state=self.state
54+
)
55+
56+
parsed_url = urlparse(authorization_url)
57+
58+
assert dict(parse_qsl(parsed_url.query)) == {
59+
"provider": str(self.provider),
60+
"client_id": workos.project_id,
61+
"redirect_uri": self.redirect_uri,
62+
"response_type": RESPONSE_TYPE_CODE,
63+
"state": json.dumps(self.state),
64+
}
65+
66+
def test_authorization_url_has_expected_query_params_with_domain(self):
67+
authorization_url = self.sso.get_authorization_url(
68+
domain=self.customer_domain,
69+
redirect_uri=self.redirect_uri,
70+
state=self.state,
71+
)
72+
73+
parsed_url = urlparse(authorization_url)
74+
75+
assert dict(parse_qsl(parsed_url.query)) == {
76+
"domain": self.customer_domain,
77+
"client_id": workos.project_id,
78+
"redirect_uri": self.redirect_uri,
79+
"response_type": RESPONSE_TYPE_CODE,
80+
"state": json.dumps(self.state),
81+
}
82+
83+
def test_authorization_url_has_expected_query_params_with_domain_and_provider(self):
3484
authorization_url = self.sso.get_authorization_url(
35-
self.customer_domain, self.redirect_uri, state=self.state
85+
domain=self.customer_domain,
86+
provider=self.provider,
87+
redirect_uri=self.redirect_uri,
88+
state=self.state,
3689
)
3790

3891
parsed_url = urlparse(authorization_url)
3992

4093
assert dict(parse_qsl(parsed_url.query)) == {
4194
"domain": self.customer_domain,
95+
"provider": str(self.provider),
4296
"client_id": workos.project_id,
4397
"redirect_uri": self.redirect_uri,
4498
"response_type": RESPONSE_TYPE_CODE,

workos/sso.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import workos
66
from workos.exceptions import ConfigurationException
77
from workos.resources.sso import WorkOSProfile
8+
from workos.utils.connection_types import ConnectionType
89
from workos.utils.request import RequestHelper, RESPONSE_TYPE_CODE, REQUEST_METHOD_POST
910
from workos.utils.validation import validate_api_key_and_project_id
1011

@@ -27,27 +28,41 @@ def request_helper(self):
2728
self._request_helper = RequestHelper()
2829
return self._request_helper
2930

30-
def get_authorization_url(self, domain, redirect_uri, state=None):
31+
def get_authorization_url(
32+
self, domain=None, redirect_uri=None, state=None, provider=None
33+
):
3134
"""Generate an OAuth 2.0 authorization URL.
3235
3336
The URL generated will redirect a User to the Identity Provider configured through
3437
WorkOS.
3538
36-
Args:
39+
Kwargs:
3740
domain (str) - The domain a user is associated with, as configured on WorkOS
3841
redirect_uri (str) - A valid redirect URI, as specified on WorkOS
3942
state (dict) - A dict passed to WorkOS, that'd be preserved through the authentication workflow, passed
4043
back as a query parameter
44+
provider (str) - Authentication service provider descriptor
4145
4246
Returns:
4347
str: URL to redirect a User to to begin the OAuth workflow with WorkOS
4448
"""
4549
params = {
46-
"domain": domain,
4750
"client_id": workos.project_id,
4851
"redirect_uri": redirect_uri,
4952
"response_type": RESPONSE_TYPE_CODE,
5053
}
54+
55+
if domain is None and provider is None:
56+
raise ValueError(
57+
"Incomplete arguments. Need to specify either a 'domain' or 'provider'"
58+
)
59+
if provider is not None:
60+
if not isinstance(provider, ConnectionType):
61+
raise ValueError("'provider' must be of type ConnectionType")
62+
params["provider"] = str(provider)
63+
if domain is not None:
64+
params["domain"] = domain
65+
5166
if state is not None:
5267
params["state"] = json.dumps(state)
5368

workos/utils/connection_types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from enum import Enum
2+
3+
4+
class ConnectionType(Enum):
5+
ADFSSAML = "ADFSSAML"
6+
AzureSAML = "AzureSAML"
7+
GoogleOAuth = "GoogleOAuth"
8+
OktaSAML = "OktaSAML"

0 commit comments

Comments
 (0)