|
6 | 6 | from typing import Any, Dict, Optional, Union |
7 | 7 |
|
8 | 8 | from ip3country import CountryLookup |
9 | | -from ua_parser import user_agent_parser |
10 | 9 |
|
11 | 10 | from .client_initialize_formatter import ClientInitializeResponseFormatter |
12 | 11 | from .config_evaluation import _ConfigEvaluation |
|
17 | 16 | from .utils import HashingAlgorithm, JSONValue, sha256_hash |
18 | 17 |
|
19 | 18 |
|
| 19 | +def load_ua_parser(): |
| 20 | + try: |
| 21 | + from ua_parser import user_agent_parser # pylint: disable=import-outside-toplevel |
| 22 | + return user_agent_parser |
| 23 | + except ImportError: |
| 24 | + logger.warning("ua_parser module not available") |
| 25 | + return None |
| 26 | + |
| 27 | + |
20 | 28 | class _Evaluator: |
21 | | - def __init__(self, spec_store: _SpecStore, global_custom_fields: Optional[Dict[str, JSONValue]]): |
| 29 | + def __init__(self, spec_store: _SpecStore, global_custom_fields: Optional[Dict[str, JSONValue]], |
| 30 | + disable_ua_parser: bool = False, disable_country_lookup: bool = False): |
22 | 31 | self._spec_store = spec_store |
23 | 32 | self._global_custom_fields = global_custom_fields |
| 33 | + self._disable_ua_parser = disable_ua_parser |
| 34 | + self._disable_country_lookup = disable_country_lookup |
24 | 35 |
|
25 | 36 | self._country_lookup: Optional[CountryLookup] = None |
| 37 | + self._ua_parser: Optional[Any] = None # Will be the ua_parser.user_agent_parser module |
26 | 38 | self._gate_overrides: Dict[str, dict] = {} |
27 | 39 | self._config_overrides: Dict[str, dict] = {} |
28 | 40 | self._layer_overrides: Dict[str, dict] = {} |
29 | 41 |
|
30 | 42 | def initialize(self): |
31 | | - self._country_lookup = CountryLookup() |
| 43 | + if not self._disable_country_lookup: |
| 44 | + self._country_lookup = CountryLookup() |
| 45 | + if not self._disable_ua_parser: |
| 46 | + self._ua_parser = load_ua_parser() |
32 | 47 |
|
33 | 48 | def override_gate(self, gate, value, user_id=None): |
34 | 49 | gate_overrides = self._gate_overrides.get(gate) |
@@ -120,6 +135,19 @@ def _create_evaluation_details(self, |
120 | 135 | return EvaluationDetails( |
121 | 136 | self._spec_store.last_update_time(), self._spec_store.initial_update_time, source, reason) |
122 | 137 |
|
| 138 | + |
| 139 | + def _update_evaluation_details_if_needed(self, end_result: _ConfigEvaluation) -> None: |
| 140 | + current_details = end_result.evaluation_details |
| 141 | + |
| 142 | + if current_details is None: |
| 143 | + end_result.evaluation_details = self._create_evaluation_details() |
| 144 | + return |
| 145 | + |
| 146 | + if current_details.source in [DataSource.UA_NOT_LOADED, DataSource.COUNTRY_NOT_LOADED]: |
| 147 | + return |
| 148 | + |
| 149 | + end_result.evaluation_details = self._create_evaluation_details() |
| 150 | + |
123 | 151 | def __lookup_gate_override(self, user, gate): |
124 | 152 | gate_overrides = self._gate_overrides.get(gate) |
125 | 153 | if gate_overrides is None: |
@@ -336,7 +364,11 @@ def __finalize_eval_result(self, config, end_result, did_pass, rule, is_nested=F |
336 | 364 | if config.get("version", None) is not None: |
337 | 365 | end_result.version = config.get("version") |
338 | 366 |
|
339 | | - end_result.evaluation_details = self._create_evaluation_details() |
| 367 | + if end_result.evaluation_details is not None and end_result.evaluation_details.source not in ( |
| 368 | + DataSource.UA_NOT_LOADED, DataSource.COUNTRY_NOT_LOADED): |
| 369 | + end_result.evaluation_details = self._create_evaluation_details() |
| 370 | + |
| 371 | + self._update_evaluation_details_if_needed(end_result) |
340 | 372 |
|
341 | 373 | if rule is None: |
342 | 374 | end_result.json_value = config.get("defaultValue", {}) |
@@ -436,16 +468,22 @@ def __evaluate_condition(self, user, condition, end_result, sampling_rate=None): |
436 | 468 | if value is None: |
437 | 469 | ip = self.__get_from_user(user, "ip") |
438 | 470 | if ip is not None and field == "country": |
439 | | - if not self._country_lookup: |
440 | | - self._country_lookup = CountryLookup() |
441 | | - value = self._country_lookup.lookupStr(ip) |
| 471 | + if self._disable_country_lookup: |
| 472 | + logger.warning("Country lookup is disabled but was attempted during evaluation") |
| 473 | + end_result.evaluation_details = self._create_evaluation_details( |
| 474 | + EvaluationReason.none, DataSource.COUNTRY_NOT_LOADED) |
| 475 | + value = None |
| 476 | + else: |
| 477 | + if not self._country_lookup: |
| 478 | + self._country_lookup = CountryLookup() |
| 479 | + value = self._country_lookup.lookupStr(ip) |
442 | 480 | if value is None: |
443 | 481 | end_result.analytical_condition = sampling_rate is None |
444 | 482 | return False |
445 | 483 | elif type == "UA_BASED": |
446 | 484 | value = self.__get_from_user(user, field) |
447 | 485 | if value is None: |
448 | | - value = self.__get_from_user_agent(user, field) |
| 486 | + value = self.__get_from_user_agent(user, field, end_result) |
449 | 487 | elif type == "USER_FIELD": |
450 | 488 | value = self.__get_from_user(user, field) |
451 | 489 | elif type == "CURRENT_TIME": |
@@ -754,11 +792,26 @@ def __get_value_as_float(self, input): |
754 | 792 | return None |
755 | 793 | return float(input) |
756 | 794 |
|
757 | | - def __get_from_user_agent(self, user, field): |
| 795 | + def __get_from_user_agent(self, user, field, end_result): |
| 796 | + if self._disable_ua_parser: |
| 797 | + logger.warning("UA parser is disabled but was attempted during evaluation") |
| 798 | + end_result.evaluation_details = self._create_evaluation_details(EvaluationReason.none, |
| 799 | + DataSource.UA_NOT_LOADED) |
| 800 | + return None |
758 | 801 | ua = self.__get_from_user(user, "userAgent") |
759 | 802 | if ua is None: |
760 | 803 | return None |
761 | | - parsed = user_agent_parser.Parse(ua) |
| 804 | + |
| 805 | + try: |
| 806 | + if self._ua_parser is None: |
| 807 | + self._ua_parser = load_ua_parser() |
| 808 | + if self._ua_parser is None: |
| 809 | + return None |
| 810 | + parsed = self._ua_parser.Parse(ua) |
| 811 | + except Exception as e: |
| 812 | + logger.warning(f"Error parsing user agent: {e}") |
| 813 | + return None |
| 814 | + |
762 | 815 | field = field.lower() |
763 | 816 | if field in ("osname", "os_name"): |
764 | 817 | return parsed.get("os", {"family": None}).get("family") |
|
0 commit comments