11import logging
2- from typing import Annotated , Callable , Optional , TypeVar
2+ from typing import (
3+ Annotated ,
4+ Callable ,
5+ Optional ,
6+ TypeVar ,
7+ Union ,
8+ )
9+
10+ import typing_extensions
311
412from .address_helper import validate_address
513from .consumer import Consumer
1018from .qpid .proton ._handlers import MessagingHandler
1119from .qpid .proton ._transport import SSLDomain
1220from .qpid .proton .utils import BlockingConnection
13- from .ssl_configuration import SslConfigurationContext
21+ from .ssl_configuration import (
22+ CurrentUserStore ,
23+ FriendlyName ,
24+ LocalMachineStore ,
25+ PKCS12Store ,
26+ PosixSslConfigurationContext ,
27+ Unambiguous ,
28+ WinSslConfigurationContext ,
29+ )
1430
1531logger = logging .getLogger (__name__ )
1632
@@ -34,7 +50,9 @@ def __init__(
3450 uri : Optional [str ] = None ,
3551 # multi-node mode
3652 uris : Optional [list [str ]] = None ,
37- ssl_context : Optional [SslConfigurationContext ] = None ,
53+ ssl_context : Union [
54+ PosixSslConfigurationContext , WinSslConfigurationContext , None
55+ ] = None ,
3856 on_disconnection_handler : Optional [CB ] = None , # type: ignore
3957 ):
4058 """
@@ -60,7 +78,9 @@ def __init__(
6078 self ._conn : BlockingConnection
6179 self ._management : Management
6280 self ._on_disconnection_handler = on_disconnection_handler
63- self ._conf_ssl_context : Optional [SslConfigurationContext ] = ssl_context
81+ self ._conf_ssl_context : Union [
82+ PosixSslConfigurationContext , WinSslConfigurationContext , None
83+ ] = ssl_context
6484 self ._ssl_domain = None
6585 self ._connections = [] # type: ignore
6686 self ._index : int = - 1
@@ -80,17 +100,47 @@ def dial(self) -> None:
80100 logger .debug ("Enabling SSL" )
81101
82102 self ._ssl_domain = SSLDomain (SSLDomain .MODE_CLIENT )
83- if self ._ssl_domain is not None :
84- self ._ssl_domain .set_trusted_ca_db (self ._conf_ssl_context .ca_cert )
103+ assert self ._ssl_domain
104+
105+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
106+ ca_cert = self ._conf_ssl_context .ca_cert
107+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
108+ ca_cert = self ._win_store_to_cert (self ._conf_ssl_context .ca_store )
109+ else :
110+ typing_extensions .assert_never (self ._conf_ssl_context )
111+ self ._ssl_domain .set_trusted_ca_db (ca_cert )
112+
85113 # for mutual authentication
86114 if self ._conf_ssl_context .client_cert is not None :
87115 logger .debug ("Enabling mutual authentication as well" )
88- if self ._ssl_domain is not None :
89- self ._ssl_domain .set_credentials (
90- self ._conf_ssl_context .client_cert .client_cert ,
91- self ._conf_ssl_context .client_cert .client_key ,
92- self ._conf_ssl_context .client_cert .password ,
116+
117+ if isinstance (self ._conf_ssl_context , PosixSslConfigurationContext ):
118+ client_cert = self ._conf_ssl_context .client_cert .client_cert
119+ client_key = self ._conf_ssl_context .client_cert .client_key
120+ password = self ._conf_ssl_context .client_cert .password
121+ elif isinstance (self ._conf_ssl_context , WinSslConfigurationContext ):
122+ client_cert = self ._win_store_to_cert (
123+ self ._conf_ssl_context .client_cert .store
93124 )
125+ disambiguation_method = (
126+ self ._conf_ssl_context .client_cert .disambiguation_method
127+ )
128+ if isinstance (disambiguation_method , Unambiguous ):
129+ client_key = None
130+ elif isinstance (disambiguation_method , FriendlyName ):
131+ client_key = disambiguation_method .name
132+ else :
133+ typing_extensions .assert_never (disambiguation_method )
134+
135+ password = self ._conf_ssl_context .client_cert .password
136+ else :
137+ typing_extensions .assert_never (self ._conf_ssl_context )
138+
139+ self ._ssl_domain .set_credentials (
140+ client_cert ,
141+ client_key ,
142+ password ,
143+ )
94144 self ._conn = BlockingConnection (
95145 url = self ._addr ,
96146 urls = self ._addrs ,
@@ -100,6 +150,19 @@ def dial(self) -> None:
100150 self ._open ()
101151 logger .debug ("Connection to the server established" )
102152
153+ def _win_store_to_cert (
154+ self , store : Union [LocalMachineStore , CurrentUserStore , PKCS12Store ]
155+ ) -> str :
156+ if isinstance (store , LocalMachineStore ):
157+ ca_cert = f"lmss:{ store .name } "
158+ elif isinstance (store , CurrentUserStore ):
159+ ca_cert = f"ss:{ store .name } "
160+ elif isinstance (store , PKCS12Store ):
161+ ca_cert = store .path
162+ else :
163+ typing_extensions .assert_never (store )
164+ return ca_cert
165+
103166 def _open (self ) -> None :
104167 self ._management = Management (self ._conn )
105168 self ._management .open ()
0 commit comments