|
3 | 3 | import os |
4 | 4 | import re |
5 | 5 | import sys |
6 | | -from typing import Optional |
| 6 | +from typing import Dict, Optional, Tuple |
7 | 7 |
|
8 | 8 | import requests |
9 | 9 | import structlog |
@@ -172,54 +172,56 @@ async def run_test(self, test: dict, test_headers: dict) -> bool: |
172 | 172 | self.failed_tests.append(test_name) |
173 | 173 | return False |
174 | 174 |
|
175 | | - async def run_tests( |
176 | | - self, |
177 | | - testcases_file: str, |
178 | | - providers: Optional[list[str]] = None, |
179 | | - test_names: Optional[list[str]] = None, |
180 | | - ) -> bool: |
181 | | - with open(testcases_file, "r") as f: |
182 | | - tests = yaml.safe_load(f) |
| 175 | + async def _get_testcases( |
| 176 | + self, testcases_dict: Dict, test_names: Optional[list[str]] = None |
| 177 | + ) -> Dict: |
| 178 | + testcases: Dict[str, Dict[str, str]] = testcases_dict["testcases"] |
183 | 179 |
|
184 | | - headers = tests["headers"] |
185 | | - testcases = tests["testcases"] |
186 | | - |
187 | | - if providers or test_names: |
| 180 | + # Filter testcases by provider and test names |
| 181 | + if test_names: |
188 | 182 | filtered_testcases = {} |
189 | 183 |
|
| 184 | + # Iterate over the original testcases and only keep the ones that match the |
| 185 | + # specified test names |
190 | 186 | for test_id, test_data in testcases.items(): |
191 | | - if providers: |
192 | | - if test_data.get("provider", "").lower() not in [p.lower() for p in providers]: |
193 | | - continue |
194 | | - |
195 | | - if test_names: |
196 | | - if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: |
197 | | - continue |
| 187 | + if test_data.get("name", "").lower() not in [t.lower() for t in test_names]: |
| 188 | + continue |
198 | 189 |
|
199 | 190 | filtered_testcases[test_id] = test_data |
200 | 191 |
|
201 | 192 | testcases = filtered_testcases |
| 193 | + return testcases |
202 | 194 |
|
203 | | - if not testcases: |
204 | | - filter_msg = [] |
205 | | - if providers: |
206 | | - filter_msg.append(f"providers: {', '.join(providers)}") |
207 | | - if test_names: |
208 | | - filter_msg.append(f"test names: {', '.join(test_names)}") |
209 | | - logger.warning(f"No tests found for {' and '.join(filter_msg)}") |
210 | | - return True # No tests is not a failure |
| 195 | + async def _setup( |
| 196 | + self, testcases_file: str, test_names: Optional[list[str]] = None |
| 197 | + ) -> Tuple[Dict, Dict]: |
| 198 | + with open(testcases_file, "r") as f: |
| 199 | + testcases_dict = yaml.safe_load(f) |
| 200 | + |
| 201 | + headers = testcases_dict["headers"] |
| 202 | + testcases = await self._get_testcases(testcases_dict, test_names) |
| 203 | + return headers, testcases |
| 204 | + |
| 205 | + async def run_tests( |
| 206 | + self, |
| 207 | + testcases_file: str, |
| 208 | + provider: str, |
| 209 | + test_names: Optional[list[str]] = None, |
| 210 | + ) -> bool: |
| 211 | + headers, testcases = await self._setup(testcases_file, test_names) |
| 212 | + |
| 213 | + if not testcases: |
| 214 | + logger.warning( |
| 215 | + f"No tests found for provider {provider} in file: {testcases_file} " |
| 216 | + f"and specific testcases: {test_names}" |
| 217 | + ) |
| 218 | + return True # No tests is not a failure |
211 | 219 |
|
212 | 220 | test_count = len(testcases) |
213 | | - filter_msg = [] |
214 | | - if providers: |
215 | | - filter_msg.append(f"providers: {', '.join(providers)}") |
| 221 | + logging_msg = f"Running {test_count} tests for provider {provider}" |
216 | 222 | if test_names: |
217 | | - filter_msg.append(f"test names: {', '.join(test_names)}") |
218 | | - |
219 | | - logger.info( |
220 | | - f"Running {test_count} tests" |
221 | | - + (f" for {' and '.join(filter_msg)}" if filter_msg else "") |
222 | | - ) |
| 223 | + logging_msg += f" and test names: {', '.join(test_names)}" |
| 224 | + logger.info(logging_msg) |
223 | 225 |
|
224 | 226 | all_tests_passed = True |
225 | 227 | for test_id, test_data in testcases.items(): |
@@ -285,10 +287,12 @@ async def main(): |
285 | 287 | logger.warning(f"No testcases.yaml found for provider {provider}") |
286 | 288 | continue |
287 | 289 |
|
| 290 | + # Run tests for the provider. The provider has already been selected when |
| 291 | + # reading the testcases.yaml file. |
288 | 292 | logger.info(f"Running tests for provider: {provider}") |
289 | 293 | provider_tests_passed = await test_runner.run_tests( |
290 | 294 | provider_test_file, |
291 | | - providers=[provider], # Only run tests for current provider |
| 295 | + provider=provider, |
292 | 296 | test_names=test_names, |
293 | 297 | ) |
294 | 298 | all_tests_passed = all_tests_passed and provider_tests_passed |
|
0 commit comments