Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 25 additions & 30 deletions sambacc/commands/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,19 @@ def _print_join_error(err: typing.Any) -> None:
def _add_join_sources(joiner: joinutil.Joiner, cli: typing.Any) -> None:
if cli.insecure or getattr(cli, "insecure_auto_join", False):
upass = joinutil.UserPass(cli.username, cli.password)
joiner.add_source(joinutil.JoinBy.PASSWORD, upass)
joiner.add_pw_source(upass)
if cli.files:
for path in cli.join_files or []:
joiner.add_source(joinutil.JoinBy.FILE, path)
joiner.add_file_source(path)
if cli.odj_files:
for path in cli.odj_files:
joiner.add_odj_file_source(path)
if cli.interactive:
upass = joinutil.UserPass(cli.username)
joiner.add_source(joinutil.JoinBy.INTERACTIVE, upass)
joiner.add_interactive_source(upass)


def _join_args(parser: Parser) -> None:
parser.set_defaults(insecure=False, files=True, interactive=True)
def _join_args_common(parser: Parser) -> None:
toggle_option(
parser,
arg="--insecure",
Expand All @@ -63,19 +65,30 @@ def _join_args(parser: Parser) -> None:
dest="files",
helpfmt="{} reading user/password from JSON files.",
)
toggle_option(
parser,
arg="--interactive",
dest="interactive",
helpfmt="{} interactive password prompt.",
)
parser.add_argument(
"--join-file",
"-j",
dest="join_files",
action="append",
help="Path to file with user/password in JSON format.",
)
parser.add_argument(
"--odj-file",
dest="odj_files",
action="append",
help="Path to an Offline Domain Join (ODJ) provisioning data file",
)


def _join_args(parser: Parser) -> None:
parser.set_defaults(insecure=False, files=True, interactive=True)
_join_args_common(parser)
toggle_option(
parser,
arg="--interactive",
dest="interactive",
helpfmt="{} interactive password prompt.",
)


@commands.command(name="join", arg_func=_join_args)
Expand All @@ -98,31 +111,13 @@ def join(ctx: Context) -> None:

def _must_join_args(parser: Parser) -> None:
parser.set_defaults(insecure=False, files=True, wait=True)
toggle_option(
parser,
arg="--insecure",
dest="insecure",
helpfmt="{} taking user/password from CLI or environment.",
)
toggle_option(
parser,
arg="--files",
dest="files",
helpfmt="{} reading user/password from JSON files.",
)
_join_args_common(parser)
toggle_option(
parser,
arg="--wait",
dest="wait",
helpfmt="{} waiting until a join is done.",
)
parser.add_argument(
"--join-file",
"-j",
dest="join_files",
action="append",
help="Path to file with user/password in JSON format.",
)


@commands.command(name="must-join", arg_func=_must_join_args)
Expand Down
1 change: 0 additions & 1 deletion sambacc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _schema_validate(data: dict[str, typing.Any], version: str) -> None:
except AttributeError:
_refreserror = _FakeRefResolutionError

global _JSON_SCHEMA
if version == "v0" and version not in _JSON_SCHEMA:
try:
import sambacc.schema.conf_v0_schema
Expand Down
46 changes: 26 additions & 20 deletions sambacc/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class JoinBy(enum.Enum):
PASSWORD = "password"
FILE = "file"
INTERACTIVE = "interactive"
ODJ_FILE = "odj_file"


class UserPass:
Expand Down Expand Up @@ -78,6 +79,7 @@ class Joiner:
"""

_net_ads_join = samba_cmds.net["ads", "join"]
_requestodj = samba_cmds.net["offlinejoin", "requestodj"]

def __init__(
self,
Expand All @@ -90,34 +92,20 @@ def __init__(
self.marker = marker
self._opener = opener or FileOpener()

def add_source(
self,
method: JoinBy,
value: typing.Optional[typing.Union[str, UserPass]] = None,
) -> None:
if method in {JoinBy.PASSWORD, JoinBy.INTERACTIVE}:
if not isinstance(value, UserPass):
raise ValueError("expected UserPass value")
if method == JoinBy.PASSWORD:
self.add_pw_source(value)
else:
self.add_interactive_source(value)
elif method in {JoinBy.FILE}:
if not isinstance(value, str):
raise ValueError("expected str value")
self.add_file_source(value)
else:
raise ValueError(f"invalid method: {method}")

def add_file_source(self, path_or_uri: str) -> None:
self._sources.append(_JoinSource(JoinBy.FILE, None, path_or_uri))

def add_pw_source(self, value: UserPass) -> None:
assert isinstance(value, UserPass)
self._sources.append(_JoinSource(JoinBy.PASSWORD, value, ""))

def add_interactive_source(self, value: UserPass) -> None:
assert isinstance(value, UserPass)
self._sources.append(_JoinSource(JoinBy.INTERACTIVE, value, ""))

def add_odj_file_source(self, path_or_uri: str) -> None:
self._sources.append(_JoinSource(JoinBy.ODJ_FILE, None, path_or_uri))

def join(self, dns_updates: bool = False) -> None:
if not self._sources:
raise JoinError("no sources for join data")
Expand All @@ -127,15 +115,19 @@ def join(self, dns_updates: bool = False) -> None:
if src.method is JoinBy.PASSWORD:
assert src.upass
upass = src.upass
self._join(upass, dns_updates=dns_updates)
elif src.method is JoinBy.FILE:
assert src.path
upass = self._read_from(src.path)
self._join(upass, dns_updates=dns_updates)
elif src.method is JoinBy.INTERACTIVE:
assert src.upass
upass = UserPass(src.upass.username, _PROMPT)
self._join(upass, dns_updates=dns_updates)
elif src.method is JoinBy.ODJ_FILE:
self._offline_join(src.path)
else:
raise ValueError(f"invalid method: {src.method}")
self._join(upass, dns_updates=dns_updates)
self._set_marker()
return
except JoinError as join_err:
Expand Down Expand Up @@ -194,6 +186,20 @@ def _join(self, upass: UserPass, dns_updates: bool = False) -> None:
if ret != 0:
raise JoinError("failed to run {}".format(cmd))

def _offline_join(self, path: str) -> None:
cmd = list(self._requestodj["-i"])
try:
with self._opener.open(path) as fh:
proc = subprocess.Popen(cmd, stdin=subprocess.PIPE)
assert proc.stdin # mypy appeasment
proc.stdin.write(fh.read())
proc.stdin.close()
ret = proc.wait()
if ret != 0:
raise JoinError(f"failed running {cmd}")
except FileNotFoundError:
raise JoinError(f"source file not found: {path}")

def _set_marker(self) -> None:
if self.marker is not None:
with open(self.marker, "w") as fh:
Expand Down
Loading
Loading