1717import  asyncio 
1818import  os 
1919import  socket 
20+ import  ssl 
2021from  threading  import  Thread 
21- from  typing  import  Any , AsyncGenerator ,  Generator 
22+ from  typing  import  Any , AsyncGenerator 
2223
24+ from  aiofiles .tempfile  import  TemporaryDirectory 
2325from  aiohttp  import  web 
26+ from  cryptography .hazmat .primitives  import  serialization 
2427import  pytest   # noqa F401 Needed to run the tests 
28+ from  unit .mocks  import  create_ssl_context   # type: ignore 
2529from  unit .mocks  import  FakeCredentials   # type: ignore 
2630from  unit .mocks  import  FakeCSQLInstance   # type: ignore 
2731
2832from  google .cloud .sql .connector .client  import  CloudSQLClient 
2933from  google .cloud .sql .connector .connection_name  import  ConnectionName 
3034from  google .cloud .sql .connector .instance  import  RefreshAheadCache 
3135from  google .cloud .sql .connector .utils  import  generate_keys 
36+ from  google .cloud .sql .connector .utils  import  write_to_file 
3237
3338SCOPES  =  ["https://www.googleapis.com/auth/sqlservice.admin" ]
3439
@@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials:
7984    return  FakeCredentials ()
8085
8186
82- def  mock_server ( server_sock :  socket . socket ) ->  None :
83-     """Create mock  server listening on specified ip_address and port. """ 
87+ async   def  start_proxy_server ( instance :  FakeCSQLInstance ) ->  None :
88+     """Run local proxy  server capable of performing mTLS """ 
8489    ip_address  =  "127.0.0.1" 
8590    port  =  3307 
86-     server_sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
87-     server_sock .bind ((ip_address , port ))
88-     server_sock .listen (0 )
89-     server_sock .accept ()
91+     # create socket 
92+     with  socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as  sock :
93+         # create SSL/TLS context 
94+         context  =  ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
95+         context .minimum_version  =  ssl .TLSVersion .TLSv1_3 
96+         # tmpdir and its contents are automatically deleted after the CA cert 
97+         # and cert chain are loaded into the SSLcontext. The values 
98+         # need to be written to files in order to be loaded by the SSLContext 
99+         server_key_bytes  =  instance .server_key .private_bytes (
100+             encoding = serialization .Encoding .PEM ,
101+             format = serialization .PrivateFormat .TraditionalOpenSSL ,
102+             encryption_algorithm = serialization .NoEncryption (),
103+         )
104+         async  with  TemporaryDirectory () as  tmpdir :
105+             server_filename , _ , key_filename  =  await  write_to_file (
106+                 tmpdir , instance .server_cert_pem , "" , server_key_bytes 
107+             )
108+             context .load_cert_chain (server_filename , key_filename )
109+         # allow socket to be re-used 
110+         sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
111+         # bind socket to Cloud SQL proxy server port on localhost 
112+         sock .bind ((ip_address , port ))
113+         # listen for incoming connections 
114+         sock .listen (5 )
115+ 
116+         with  context .wrap_socket (sock , server_side = True ) as  ssock :
117+             while  True :
118+                 conn , _  =  ssock .accept ()
119+                 conn .close ()
120+ 
121+ 
122+ @pytest .fixture (scope = "session" ) 
123+ def  proxy_server (fake_instance : FakeCSQLInstance ) ->  None :
124+     """Run local proxy server capable of performing mTLS""" 
125+     thread  =  Thread (
126+         target = asyncio .run ,
127+         args = (
128+             start_proxy_server (
129+                 fake_instance ,
130+             ),
131+         ),
132+         daemon = True ,
133+     )
134+     thread .start ()
135+     thread .join (1.0 )  # add a delay to allow the proxy server to start 
90136
91137
92138@pytest .fixture  
93- def  server () ->  Generator :
94-     """Create thread with server listening on proper port""" 
95-     server_sock  =  socket .socket ()
96-     thread  =  Thread (target = mock_server , args = (server_sock ,), daemon = True )
97-     thread .start ()
98-     yield  thread 
99-     server_sock .close ()
100-     thread .join ()
139+ async  def  context (fake_instance : FakeCSQLInstance ) ->  ssl .SSLContext :
140+     return  await  create_ssl_context (fake_instance )
101141
102142
103143@pytest .fixture  
@@ -107,7 +147,7 @@ def kwargs() -> Any:
107147    return  kwargs 
108148
109149
110- @pytest .fixture  
150+ @pytest .fixture ( scope = "session" )  
111151def  fake_instance () ->  FakeCSQLInstance :
112152    return  FakeCSQLInstance ()
113153
0 commit comments