17
17
#
18
18
19
19
import enum
20
+ import errno
20
21
import json
21
22
import logging
22
23
import subprocess
23
24
import typing
24
25
26
+ from .opener import Opener , FileOpener
25
27
from sambacc import samba_cmds
26
28
from sambacc .simple_waiter import Waiter
27
29
@@ -62,6 +64,12 @@ def __init__(
62
64
self .password = password
63
65
64
66
67
+ class _JoinSource (typing .NamedTuple ):
68
+ method : JoinBy
69
+ upass : typing .Optional [UserPass ]
70
+ path : str
71
+
72
+
65
73
class Joiner :
66
74
"""Utility class for joining to AD domain.
67
75
@@ -71,9 +79,16 @@ class Joiner:
71
79
72
80
_net_ads_join = samba_cmds .net ["ads" , "join" ]
73
81
74
- def __init__ (self , marker : typing .Optional [str ] = None ) -> None :
75
- self ._sources : list [tuple [JoinBy , typing .Any ]] = []
82
+ def __init__ (
83
+ self ,
84
+ marker : typing .Optional [str ] = None ,
85
+ * ,
86
+ opener : typing .Optional [Opener ] = None ,
87
+ ) -> None :
88
+ self ._source_paths : list [str ] = []
89
+ self ._sources : list [_JoinSource ] = []
76
90
self .marker = marker
91
+ self ._opener = opener or FileOpener ()
77
92
78
93
def add_source (
79
94
self ,
@@ -83,27 +98,43 @@ def add_source(
83
98
if method in {JoinBy .PASSWORD , JoinBy .INTERACTIVE }:
84
99
if not isinstance (value , UserPass ):
85
100
raise ValueError ("expected UserPass value" )
101
+ if method == JoinBy .PASSWORD :
102
+ self .add_pw_source (value )
103
+ else :
104
+ self .add_interactive_source (value )
86
105
elif method in {JoinBy .FILE }:
87
106
if not isinstance (value , str ):
88
107
raise ValueError ("expected str value" )
108
+ self .add_file_source (value )
89
109
else :
90
110
raise ValueError (f"invalid method: { method } " )
91
- self ._sources .append ((method , value ))
111
+
112
+ def add_file_source (self , path_or_uri : str ) -> None :
113
+ self ._sources .append (_JoinSource (JoinBy .FILE , None , path_or_uri ))
114
+
115
+ def add_pw_source (self , value : UserPass ) -> None :
116
+ self ._sources .append (_JoinSource (JoinBy .PASSWORD , value , "" ))
117
+
118
+ def add_interactive_source (self , value : UserPass ) -> None :
119
+ self ._sources .append (_JoinSource (JoinBy .INTERACTIVE , value , "" ))
92
120
93
121
def join (self , dns_updates : bool = False ) -> None :
94
122
if not self ._sources :
95
123
raise JoinError ("no sources for join data" )
96
124
errors = []
97
- for method , value in self ._sources :
125
+ for src in self ._sources :
98
126
try :
99
- if method is JoinBy .PASSWORD :
100
- upass = value
101
- elif method is JoinBy .FILE :
102
- upass = self ._read_from (value )
103
- elif method is JoinBy .INTERACTIVE :
104
- upass = UserPass (value .username , _PROMPT )
127
+ if src .method is JoinBy .PASSWORD :
128
+ assert src .upass
129
+ upass = src .upass
130
+ elif src .method is JoinBy .FILE :
131
+ assert src .path
132
+ upass = self ._read_from (src .path )
133
+ elif src .method is JoinBy .INTERACTIVE :
134
+ assert src .upass
135
+ upass = UserPass (src .upass .username , _PROMPT )
105
136
else :
106
- raise ValueError (f"invalid method: { method } " )
137
+ raise ValueError (f"invalid method: { src . method } " )
107
138
self ._join (upass , dns_updates = dns_updates )
108
139
self ._set_marker ()
109
140
return
@@ -118,10 +149,14 @@ def join(self, dns_updates: bool = False) -> None:
118
149
119
150
def _read_from (self , path : str ) -> UserPass :
120
151
try :
121
- with open (path ) as fh :
152
+ with self . _opener . open (path ) as fh :
122
153
data = json .load (fh )
123
154
except FileNotFoundError :
124
155
raise JoinError (f"source file not found: { path } " )
156
+ except OSError as err :
157
+ if getattr (err , "errno" , 0 ) != errno .ENOENT :
158
+ raise
159
+ raise JoinError (f"resource not found: { path } " )
125
160
upass = UserPass ()
126
161
try :
127
162
upass .username = data ["username" ]
0 commit comments