Skip to content

Commit dfac8a1

Browse files
[refac]refactor http worker to support different session (#467)
Tested with ``` address = "https://127.0.0.1:8443" proxyConfig = ProxyConfig(NetworkProtocol.HTTP, address, authentication_mode="tls", tls_ca_cert_path="x509_test_certs/root/certs/root_ca.crt", tls_client_cert_path="x509_test_certs/intermediate/certs/client.crt", tls_client_key_path="x509_test_certs/intermediate/private/client.key") statsig.initialize(secret_key, StatsigOptions(proxy_configs={NetworkEndpoint.DOWNLOAD_CONFIG_SPECS: proxyConfig})) user = StatsigUser('lrs_grpc_user') i = 0 while i<2 : i = i+1 config_names = await get_config_names() print("Config names to check") print(config_names) for gate in config_names["feature_gates"]: statsig.check_gate(user, gate) await asyncio.sleep(IDLE_INTERVAL / 1000) ``` against sfp enforce tls And verified it can work
1 parent 56de895 commit dfac8a1

File tree

2 files changed

+71
-48
lines changed

2 files changed

+71
-48
lines changed

statsig/http_worker.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .sdk_configs import _SDK_Configs
2222
from .statsig_context import InitContext
2323
from .statsig_error_boundary import _StatsigErrorBoundary
24-
from .statsig_options import StatsigOptions, STATSIG_API, STATSIG_CDN, AuthenticationMode
24+
from .statsig_options import ProxyConfig, StatsigOptions, STATSIG_API, STATSIG_CDN, AuthenticationMode
2525
from .grpc_websocket_worker import load_credential_from_file
2626

2727
REQUEST_TIMEOUT = 20
@@ -51,50 +51,8 @@ def __init__(
5151
self.__diagnostics = diagnostics
5252
self.__request_count = 0
5353
self.__temp_cert_files: List[str] = []
54-
self.__request_session = self.__init_session(options)
55-
56-
def __init_session(self, options: StatsigOptions) -> requests.Session:
57-
session = requests.Session()
58-
http_proxy_config = None
59-
for _, config in options.proxy_configs.items():
60-
if config.protocol == NetworkProtocol.HTTP:
61-
if config.authentication_mode in [AuthenticationMode.TLS, AuthenticationMode.MTLS]:
62-
http_proxy_config = config
63-
break
64-
if http_proxy_config is None:
65-
return session
66-
try:
67-
if http_proxy_config.authentication_mode == AuthenticationMode.TLS:
68-
ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate")
69-
if ca_cert:
70-
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file:
71-
ca_file.write(ca_cert)
72-
session.verify = ca_file.name
73-
self.__temp_cert_files.append(ca_file.name)
74-
globals.logger.log_process("HTTP Worker", "Connecting using an TLS secure channel for HTTP")
75-
elif http_proxy_config.authentication_mode == AuthenticationMode.MTLS:
76-
client_cert = load_credential_from_file(http_proxy_config.tls_client_cert_path, "TLS client certificate")
77-
client_key = load_credential_from_file(http_proxy_config.tls_client_key_path, "TLS client key")
78-
ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate")
79-
if client_cert and client_key and ca_cert:
80-
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as cert_file:
81-
cert_file.write(client_cert)
82-
cert_path = cert_file.name
83-
self.__temp_cert_files.append(cert_path)
84-
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.key') as key_file:
85-
key_file.write(client_key)
86-
key_path = key_file.name
87-
self.__temp_cert_files.append(key_path)
88-
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file:
89-
ca_file.write(ca_cert)
90-
ca_path = ca_file.name
91-
self.__temp_cert_files.append(ca_path)
92-
session.cert = (cert_path, key_path)
93-
session.verify = ca_path
94-
globals.logger.log_process("HTTP Worker", "Connecting using an mTLS secure channel for HTTP")
95-
except Exception as e:
96-
self.__error_boundary.log_exception("http_worker:init_session", e)
97-
return session
54+
self.__statsig_request_session = requests.Session()
55+
self.__request_session = requests.Session()
9856

9957
def is_pull_worker(self) -> bool:
10058
return True
@@ -138,6 +96,7 @@ def get_dcs_fallback(
13896
init_timeout=init_timeout,
13997
log_on_exception=log_on_exception,
14098
tag="download_config_specs",
99+
useStatsigClient = True,
141100
)
142101
self._context.source_api = STATSIG_CDN
143102
if response is not None and self._is_success_code(response.status_code):
@@ -175,6 +134,7 @@ def get_id_lists_fallback(
175134
log_on_exception=log_on_exception,
176135
init_timeout=init_timeout,
177136
tag="get_id_lists",
137+
useStatsigClient = True,
178138
)
179139
if response is not None and self._is_success_code(response.status_code):
180140
return on_complete(response.data, None)
@@ -220,6 +180,39 @@ def shutdown(self) -> None:
220180
pass
221181
self.__temp_cert_files.clear()
222182

183+
def authenticate_request_session(self, http_proxy_config: ProxyConfig):
184+
try:
185+
if http_proxy_config.authentication_mode == AuthenticationMode.TLS:
186+
ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate")
187+
if ca_cert:
188+
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file:
189+
ca_file.write(ca_cert)
190+
self.__request_session.verify = ca_file.name
191+
self.__temp_cert_files.append(ca_file.name)
192+
globals.logger.log_process("HTTP Worker", "Connecting using an TLS secure channel for HTTP")
193+
elif http_proxy_config.authentication_mode == AuthenticationMode.MTLS:
194+
client_cert = load_credential_from_file(http_proxy_config.tls_client_cert_path, "TLS client certificate")
195+
client_key = load_credential_from_file(http_proxy_config.tls_client_key_path, "TLS client key")
196+
ca_cert = load_credential_from_file(http_proxy_config.tls_ca_cert_path, "TLS CA certificate")
197+
if client_cert and client_key and ca_cert:
198+
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as cert_file:
199+
cert_file.write(client_cert)
200+
cert_path = cert_file.name
201+
self.__temp_cert_files.append(cert_path)
202+
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.key') as key_file:
203+
key_file.write(client_key)
204+
key_path = key_file.name
205+
self.__temp_cert_files.append(key_path)
206+
with tempfile.NamedTemporaryFile(mode='wb', delete=False, suffix='.pem') as ca_file:
207+
ca_file.write(ca_cert)
208+
ca_path = ca_file.name
209+
self.__temp_cert_files.append(ca_path)
210+
self.__request_session.cert = (cert_path, key_path)
211+
self.__request_session.verify = ca_path
212+
globals.logger.log_process("HTTP Worker", "Connecting using an mTLS secure channel for HTTP")
213+
except Exception as e:
214+
self.__error_boundary.log_exception("http_worker:init_session", e)
215+
223216
def _run_task_for_initialize(
224217
self, task, timeout
225218
) -> Tuple[Optional[Any], Optional[Exception]]:
@@ -239,9 +232,10 @@ def _post_request(
239232
init_timeout=None,
240233
zipped=None,
241234
tag=None,
235+
useStatsigClient=False,
242236
):
243237
return self._request(
244-
"POST", url, headers, payload, log_on_exception, init_timeout, zipped, tag
238+
"POST", url, headers, payload, log_on_exception, init_timeout, zipped, tag, useStatsigClient
245239
)
246240

247241
def _get_request(
@@ -253,6 +247,7 @@ def _get_request(
253247
zipped=None,
254248
tag=None,
255249
get_text_value_only=False,
250+
useStatsigClient=False,
256251
):
257252
return self._request(
258253
"GET",
@@ -264,6 +259,7 @@ def _get_request(
264259
zipped,
265260
tag,
266261
get_text_value_only,
262+
useStatsigClient
267263
)
268264

269265
def _request(
@@ -277,6 +273,7 @@ def _request(
277273
zipped=False,
278274
tag=None,
279275
get_text_value_only=False,
276+
useStatsigClient = False,
280277
) -> RequestResult:
281278
if self.__local_mode:
282279
globals.logger.debug("Using local mode. Dropping network request")
@@ -312,6 +309,7 @@ def _request(
312309
timeout,
313310
init_timeout is not None,
314311
get_text_value_only,
312+
useStatsigClient
315313
)
316314

317315
if create_marker is not None:
@@ -333,10 +331,12 @@ def _run_request_with_strict_timeout(
333331
timeout,
334332
for_initialize=False,
335333
get_text_value_only=False,
334+
useStatsigClient=False
336335
) -> RequestResult:
337336
def request_task():
338337
try:
339-
with self.__request_session.request(
338+
request_session = self.__statsig_request_session if useStatsigClient else self.__request_session
339+
with request_session.request(
340340
method,
341341
url,
342342
data=payload,

statsig/statsig_network.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import importlib
22
import threading
3-
from typing import Any, Callable, Optional
3+
from typing import Any, Callable, Optional, cast
44

55
from . import globals
66
from .diagnostics import Diagnostics
@@ -17,6 +17,7 @@
1717
from .statsig_error_boundary import _StatsigErrorBoundary
1818
from .statsig_options import (
1919
DEFAULT_RULESET_SYNC_INTERVAL,
20+
AuthenticationMode,
2021
StatsigOptions,
2122
STATSIG_CDN,
2223
ProxyConfig,
@@ -100,11 +101,33 @@ def __init__(
100101
self.load_grpc_worker(endpoint, config)
101102
elif protocol == NetworkProtocol.GRPC_WEBSOCKET:
102103
self.load_grpc_websocket_worker(endpoint, config)
104+
elif protocol == NetworkProtocol.HTTP:
105+
if config.authentication_mode in [AuthenticationMode.TLS, AuthenticationMode.MTLS]:
106+
self.load_authenticated_http_worker(endpoint, config)
103107

104108
self._background_download_configs_from_statsig = None
105109
self._background_download_id_lists_from_statsig = None
106110
self._streaming_fallback: Optional[StreamingFallback] = None
107111

112+
def load_authenticated_http_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig):
113+
if endpoint == NetworkEndpoint.ALL:
114+
http_worker = cast(HttpWorker, self.http_worker)
115+
http_worker.authenticate_request_session(config)
116+
self.log_event_worker = http_worker
117+
self.id_list_worker = http_worker
118+
self.dcs_worker = http_worker
119+
return
120+
121+
worker = HttpWorker(self.sdk_key, self.options, self.statsig_metadata, self.error_boundary, self.diagnostics, self.context)
122+
worker.authenticate_request_session(config)
123+
if endpoint == NetworkEndpoint.DOWNLOAD_CONFIG_SPECS:
124+
self.dcs_worker = worker
125+
elif endpoint == NetworkEndpoint.GET_ID_LISTS:
126+
self.id_list_worker = worker
127+
elif endpoint == NetworkEndpoint.LOG_EVENT:
128+
self.log_event_worker = worker
129+
130+
108131
def load_grpc_websocket_worker(self, endpoint: NetworkEndpoint, config: ProxyConfig):
109132
grpc_worker_module = importlib.import_module("statsig.grpc_websocket_worker")
110133
grpc_webhook_worker_class = getattr(grpc_worker_module, 'GRPCWebsocketWorker')

0 commit comments

Comments
 (0)