2121from .sdk_configs import _SDK_Configs
2222from .statsig_context import InitContext
2323from .statsig_error_boundary import _StatsigErrorBoundary
24- from .statsig_options import StatsigOptions , STATSIG_API , STATSIG_CDN , AuthenticationMode
24+ from .statsig_options import ProxyConfig , StatsigOptions , STATSIG_API , STATSIG_CDN , AuthenticationMode
2525from .grpc_websocket_worker import load_credential_from_file
2626
2727REQUEST_TIMEOUT = 20
@@ -51,50 +51,8 @@ def __init__(
5151 self .__diagnostics = diagnostics
5252 self .__request_count = 0
5353 self .__temp_cert_files : List [str ] = []
54- self .__request_session = self .__init_session (options )
55-
56- def __init_session (self , options : StatsigOptions ) -> requests .Session :
57- session = requests .Session ()
58- http_proxy_config = None
59- for _ , config in options .proxy_configs .items ():
60- if config .protocol == NetworkProtocol .HTTP :
61- if config .authentication_mode in [AuthenticationMode .TLS , AuthenticationMode .MTLS ]:
62- http_proxy_config = config
63- break
64- if http_proxy_config is None :
65- return session
66- try :
67- if http_proxy_config .authentication_mode == AuthenticationMode .TLS :
68- ca_cert = load_credential_from_file (http_proxy_config .tls_ca_cert_path , "TLS CA certificate" )
69- if ca_cert :
70- with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as ca_file :
71- ca_file .write (ca_cert )
72- session .verify = ca_file .name
73- self .__temp_cert_files .append (ca_file .name )
74- globals .logger .log_process ("HTTP Worker" , "Connecting using an TLS secure channel for HTTP" )
75- elif http_proxy_config .authentication_mode == AuthenticationMode .MTLS :
76- client_cert = load_credential_from_file (http_proxy_config .tls_client_cert_path , "TLS client certificate" )
77- client_key = load_credential_from_file (http_proxy_config .tls_client_key_path , "TLS client key" )
78- ca_cert = load_credential_from_file (http_proxy_config .tls_ca_cert_path , "TLS CA certificate" )
79- if client_cert and client_key and ca_cert :
80- with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as cert_file :
81- cert_file .write (client_cert )
82- cert_path = cert_file .name
83- self .__temp_cert_files .append (cert_path )
84- with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.key' ) as key_file :
85- key_file .write (client_key )
86- key_path = key_file .name
87- self .__temp_cert_files .append (key_path )
88- with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as ca_file :
89- ca_file .write (ca_cert )
90- ca_path = ca_file .name
91- self .__temp_cert_files .append (ca_path )
92- session .cert = (cert_path , key_path )
93- session .verify = ca_path
94- globals .logger .log_process ("HTTP Worker" , "Connecting using an mTLS secure channel for HTTP" )
95- except Exception as e :
96- self .__error_boundary .log_exception ("http_worker:init_session" , e )
97- return session
54+ self .__statsig_request_session = requests .Session ()
55+ self .__request_session = requests .Session ()
9856
9957 def is_pull_worker (self ) -> bool :
10058 return True
@@ -138,6 +96,7 @@ def get_dcs_fallback(
13896 init_timeout = init_timeout ,
13997 log_on_exception = log_on_exception ,
14098 tag = "download_config_specs" ,
99+ useStatsigClient = True ,
141100 )
142101 self ._context .source_api = STATSIG_CDN
143102 if response is not None and self ._is_success_code (response .status_code ):
@@ -175,6 +134,7 @@ def get_id_lists_fallback(
175134 log_on_exception = log_on_exception ,
176135 init_timeout = init_timeout ,
177136 tag = "get_id_lists" ,
137+ useStatsigClient = True ,
178138 )
179139 if response is not None and self ._is_success_code (response .status_code ):
180140 return on_complete (response .data , None )
@@ -220,6 +180,39 @@ def shutdown(self) -> None:
220180 pass
221181 self .__temp_cert_files .clear ()
222182
183+ def authenticate_request_session (self , http_proxy_config : ProxyConfig ):
184+ try :
185+ if http_proxy_config .authentication_mode == AuthenticationMode .TLS :
186+ ca_cert = load_credential_from_file (http_proxy_config .tls_ca_cert_path , "TLS CA certificate" )
187+ if ca_cert :
188+ with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as ca_file :
189+ ca_file .write (ca_cert )
190+ self .__request_session .verify = ca_file .name
191+ self .__temp_cert_files .append (ca_file .name )
192+ globals .logger .log_process ("HTTP Worker" , "Connecting using an TLS secure channel for HTTP" )
193+ elif http_proxy_config .authentication_mode == AuthenticationMode .MTLS :
194+ client_cert = load_credential_from_file (http_proxy_config .tls_client_cert_path , "TLS client certificate" )
195+ client_key = load_credential_from_file (http_proxy_config .tls_client_key_path , "TLS client key" )
196+ ca_cert = load_credential_from_file (http_proxy_config .tls_ca_cert_path , "TLS CA certificate" )
197+ if client_cert and client_key and ca_cert :
198+ with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as cert_file :
199+ cert_file .write (client_cert )
200+ cert_path = cert_file .name
201+ self .__temp_cert_files .append (cert_path )
202+ with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.key' ) as key_file :
203+ key_file .write (client_key )
204+ key_path = key_file .name
205+ self .__temp_cert_files .append (key_path )
206+ with tempfile .NamedTemporaryFile (mode = 'wb' , delete = False , suffix = '.pem' ) as ca_file :
207+ ca_file .write (ca_cert )
208+ ca_path = ca_file .name
209+ self .__temp_cert_files .append (ca_path )
210+ self .__request_session .cert = (cert_path , key_path )
211+ self .__request_session .verify = ca_path
212+ globals .logger .log_process ("HTTP Worker" , "Connecting using an mTLS secure channel for HTTP" )
213+ except Exception as e :
214+ self .__error_boundary .log_exception ("http_worker:init_session" , e )
215+
223216 def _run_task_for_initialize (
224217 self , task , timeout
225218 ) -> Tuple [Optional [Any ], Optional [Exception ]]:
@@ -239,9 +232,10 @@ def _post_request(
239232 init_timeout = None ,
240233 zipped = None ,
241234 tag = None ,
235+ useStatsigClient = False ,
242236 ):
243237 return self ._request (
244- "POST" , url , headers , payload , log_on_exception , init_timeout , zipped , tag
238+ "POST" , url , headers , payload , log_on_exception , init_timeout , zipped , tag , useStatsigClient
245239 )
246240
247241 def _get_request (
@@ -253,6 +247,7 @@ def _get_request(
253247 zipped = None ,
254248 tag = None ,
255249 get_text_value_only = False ,
250+ useStatsigClient = False ,
256251 ):
257252 return self ._request (
258253 "GET" ,
@@ -264,6 +259,7 @@ def _get_request(
264259 zipped ,
265260 tag ,
266261 get_text_value_only ,
262+ useStatsigClient
267263 )
268264
269265 def _request (
@@ -277,6 +273,7 @@ def _request(
277273 zipped = False ,
278274 tag = None ,
279275 get_text_value_only = False ,
276+ useStatsigClient = False ,
280277 ) -> RequestResult :
281278 if self .__local_mode :
282279 globals .logger .debug ("Using local mode. Dropping network request" )
@@ -312,6 +309,7 @@ def _request(
312309 timeout ,
313310 init_timeout is not None ,
314311 get_text_value_only ,
312+ useStatsigClient
315313 )
316314
317315 if create_marker is not None :
@@ -333,10 +331,12 @@ def _run_request_with_strict_timeout(
333331 timeout ,
334332 for_initialize = False ,
335333 get_text_value_only = False ,
334+ useStatsigClient = False
336335 ) -> RequestResult :
337336 def request_task ():
338337 try :
339- with self .__request_session .request (
338+ request_session = self .__statsig_request_session if useStatsigClient else self .__request_session
339+ with request_session .request (
340340 method ,
341341 url ,
342342 data = payload ,
0 commit comments