1111
1212
1313class TestSSO (object ):
14- @pytest .fixture (autouse = True )
15- def setup (self , set_api_key_and_project_id ):
14+ @pytest .fixture
15+ def setup_with_client_id (self , set_api_key_and_client_id ):
16+ self .provider = ConnectionType .GoogleOAuth
17+ self .customer_domain = "workos.com"
18+ self .redirect_uri = "https://localhost/auth/callback"
19+ self .state = json .dumps ({"things" : "with_stuff" ,})
20+
21+ self .sso = SSO ()
22+
23+ @pytest .fixture
24+ def setup_with_project_id (self , set_api_key_and_project_id ):
1625 self .provider = ConnectionType .GoogleOAuth
1726 self .customer_domain = "workos.com"
1827 self .redirect_uri = "https://localhost/auth/callback"
@@ -91,22 +100,26 @@ def mock_connections(self):
91100 }
92101
93102 def test_authorization_url_throws_value_error_with_missing_domain_and_provider (
94- self ,
103+ self , setup_with_client_id
95104 ):
96105 with pytest .raises (ValueError , match = r"Incomplete arguments.*" ):
97106 self .sso .get_authorization_url (
98107 redirect_uri = self .redirect_uri , state = self .state
99108 )
100109
101- def test_authorization_url_throws_value_error_with_incorrect_provider_type (self ):
110+ def test_authorization_url_throws_value_error_with_incorrect_provider_type (
111+ self , setup_with_client_id
112+ ):
102113 with pytest .raises (
103114 ValueError , match = "'provider' must be of type ConnectionType"
104115 ):
105116 self .sso .get_authorization_url (
106117 provider = "foo" , redirect_uri = self .redirect_uri , state = self .state
107118 )
108119
109- def test_authorization_url_has_expected_query_params_with_provider (self ):
120+ def test_authorization_url_has_expected_query_params_with_provider (
121+ self , setup_with_client_id
122+ ):
110123 authorization_url = self .sso .get_authorization_url (
111124 provider = self .provider , redirect_uri = self .redirect_uri , state = self .state
112125 )
@@ -115,13 +128,15 @@ def test_authorization_url_has_expected_query_params_with_provider(self):
115128
116129 assert dict (parse_qsl (parsed_url .query )) == {
117130 "provider" : str (self .provider .value ),
118- "client_id" : workos .project_id ,
131+ "client_id" : workos .client_id ,
119132 "redirect_uri" : self .redirect_uri ,
120133 "response_type" : RESPONSE_TYPE_CODE ,
121134 "state" : self .state ,
122135 }
123136
124- def test_authorization_url_has_expected_query_params_with_domain (self ):
137+ def test_authorization_url_has_expected_query_params_with_domain (
138+ self , setup_with_client_id
139+ ):
125140 authorization_url = self .sso .get_authorization_url (
126141 domain = self .customer_domain ,
127142 redirect_uri = self .redirect_uri ,
@@ -132,13 +147,15 @@ def test_authorization_url_has_expected_query_params_with_domain(self):
132147
133148 assert dict (parse_qsl (parsed_url .query )) == {
134149 "domain" : self .customer_domain ,
135- "client_id" : workos .project_id ,
150+ "client_id" : workos .client_id ,
136151 "redirect_uri" : self .redirect_uri ,
137152 "response_type" : RESPONSE_TYPE_CODE ,
138153 "state" : self .state ,
139154 }
140155
141- def test_authorization_url_has_expected_query_params_with_domain_and_provider (self ):
156+ def test_authorization_url_has_expected_query_params_with_domain_and_provider (
157+ self , setup_with_client_id
158+ ):
142159 authorization_url = self .sso .get_authorization_url (
143160 domain = self .customer_domain ,
144161 provider = self .provider ,
@@ -151,14 +168,36 @@ def test_authorization_url_has_expected_query_params_with_domain_and_provider(se
151168 assert dict (parse_qsl (parsed_url .query )) == {
152169 "domain" : self .customer_domain ,
153170 "provider" : str (self .provider .value ),
154- "client_id" : workos .project_id ,
171+ "client_id" : workos .client_id ,
155172 "redirect_uri" : self .redirect_uri ,
156173 "response_type" : RESPONSE_TYPE_CODE ,
157174 "state" : self .state ,
158175 }
159176
177+ def test_authorization_url_supports_project_id_with_deprecation_warning (
178+ self , setup_with_project_id
179+ ):
180+ with pytest .deprecated_call ():
181+ authorization_url = self .sso .get_authorization_url (
182+ domain = self .customer_domain ,
183+ provider = self .provider ,
184+ redirect_uri = self .redirect_uri ,
185+ state = self .state ,
186+ )
187+
188+ parsed_url = urlparse (authorization_url )
189+
190+ assert dict (parse_qsl (parsed_url .query )) == {
191+ "domain" : self .customer_domain ,
192+ "provider" : str (self .provider .value ),
193+ "client_id" : workos .project_id ,
194+ "redirect_uri" : self .redirect_uri ,
195+ "response_type" : RESPONSE_TYPE_CODE ,
196+ "state" : self .state ,
197+ }
198+
160199 def test_get_profile_returns_expected_workosprofile_object (
161- self , mock_profile , mock_request_method
200+ self , setup_with_client_id , mock_profile , mock_request_method
162201 ):
163202 response_dict = {
164203 "profile" : {
@@ -185,7 +224,9 @@ def test_get_profile_returns_expected_workosprofile_object(
185224
186225 assert profile .to_dict () == mock_profile
187226
188- def test_create_connection (self , mock_request_method , mock_connection ):
227+ def test_create_connection (
228+ self , setup_with_client_id , mock_request_method , mock_connection
229+ ):
189230 response_dict = {
190231 "object" : "connection" ,
191232 "id" : mock_connection ["id" ],
@@ -208,7 +249,9 @@ def test_create_connection(self, mock_request_method, mock_connection):
208249 connection = self .sso .create_connection ("draft_conn_id" )
209250 assert connection == response_dict
210251
211- def test_get_connection (self , mock_connection , mock_request_method ):
252+ def test_get_connection (
253+ self , setup_with_client_id , mock_connection , mock_request_method
254+ ):
212255 mock_response = Response ()
213256 mock_response .status_code = 200
214257 mock_response .response_dict = mock_connection
@@ -217,7 +260,9 @@ def test_get_connection(self, mock_connection, mock_request_method):
217260 assert response .status_code == 200
218261 assert response .response_dict == mock_connection
219262
220- def test_list_connections (self , mock_connections , mock_request_method ):
263+ def test_list_connections (
264+ self , setup_with_client_id , mock_connections , mock_request_method
265+ ):
221266 mock_response = Response ()
222267 mock_response .status_code = 200
223268 mock_response .response_dict = mock_connections
@@ -226,7 +271,7 @@ def test_list_connections(self, mock_connections, mock_request_method):
226271 assert response .status_code == 200
227272 assert response .response_dict == mock_connections
228273
229- def test_delete_connection (self , mock_request_method ):
274+ def test_delete_connection (self , setup_with_client_id , mock_request_method ):
230275 mock_response = Response ()
231276 mock_response .status_code = 200
232277 mock_request_method ("delete" , mock_response , 200 )
0 commit comments