diff --git a/sambacc/commands/join.py b/sambacc/commands/join.py index 4b3701e3..c4183b0c 100644 --- a/sambacc/commands/join.py +++ b/sambacc/commands/join.py @@ -40,17 +40,19 @@ def _print_join_error(err: typing.Any) -> None: def _add_join_sources(joiner: joinutil.Joiner, cli: typing.Any) -> None: if cli.insecure or getattr(cli, "insecure_auto_join", False): upass = joinutil.UserPass(cli.username, cli.password) - joiner.add_source(joinutil.JoinBy.PASSWORD, upass) + joiner.add_pw_source(upass) if cli.files: for path in cli.join_files or []: - joiner.add_source(joinutil.JoinBy.FILE, path) + joiner.add_file_source(path) + if cli.odj_files: + for path in cli.odj_files: + joiner.add_odj_file_source(path) if cli.interactive: upass = joinutil.UserPass(cli.username) - joiner.add_source(joinutil.JoinBy.INTERACTIVE, upass) + joiner.add_interactive_source(upass) -def _join_args(parser: Parser) -> None: - parser.set_defaults(insecure=False, files=True, interactive=True) +def _join_args_common(parser: Parser) -> None: toggle_option( parser, arg="--insecure", @@ -63,12 +65,6 @@ def _join_args(parser: Parser) -> None: dest="files", helpfmt="{} reading user/password from JSON files.", ) - toggle_option( - parser, - arg="--interactive", - dest="interactive", - helpfmt="{} interactive password prompt.", - ) parser.add_argument( "--join-file", "-j", @@ -76,6 +72,23 @@ def _join_args(parser: Parser) -> None: action="append", help="Path to file with user/password in JSON format.", ) + parser.add_argument( + "--odj-file", + dest="odj_files", + action="append", + help="Path to an Offline Domain Join (ODJ) provisioning data file", + ) + + +def _join_args(parser: Parser) -> None: + parser.set_defaults(insecure=False, files=True, interactive=True) + _join_args_common(parser) + toggle_option( + parser, + arg="--interactive", + dest="interactive", + helpfmt="{} interactive password prompt.", + ) @commands.command(name="join", arg_func=_join_args) @@ -98,31 +111,13 @@ def join(ctx: Context) -> None: def _must_join_args(parser: Parser) -> None: parser.set_defaults(insecure=False, files=True, wait=True) - toggle_option( - parser, - arg="--insecure", - dest="insecure", - helpfmt="{} taking user/password from CLI or environment.", - ) - toggle_option( - parser, - arg="--files", - dest="files", - helpfmt="{} reading user/password from JSON files.", - ) + _join_args_common(parser) toggle_option( parser, arg="--wait", dest="wait", helpfmt="{} waiting until a join is done.", ) - parser.add_argument( - "--join-file", - "-j", - dest="join_files", - action="append", - help="Path to file with user/password in JSON format.", - ) @commands.command(name="must-join", arg_func=_must_join_args) diff --git a/sambacc/config.py b/sambacc/config.py index 0ec0a487..2be21211 100644 --- a/sambacc/config.py +++ b/sambacc/config.py @@ -120,7 +120,6 @@ def _schema_validate(data: dict[str, typing.Any], version: str) -> None: except AttributeError: _refreserror = _FakeRefResolutionError - global _JSON_SCHEMA if version == "v0" and version not in _JSON_SCHEMA: try: import sambacc.schema.conf_v0_schema diff --git a/sambacc/join.py b/sambacc/join.py index e140b99a..a22a12d6 100644 --- a/sambacc/join.py +++ b/sambacc/join.py @@ -45,6 +45,7 @@ class JoinBy(enum.Enum): PASSWORD = "password" FILE = "file" INTERACTIVE = "interactive" + ODJ_FILE = "odj_file" class UserPass: @@ -78,6 +79,7 @@ class Joiner: """ _net_ads_join = samba_cmds.net["ads", "join"] + _requestodj = samba_cmds.net["offlinejoin", "requestodj"] def __init__( self, @@ -90,34 +92,20 @@ def __init__( self.marker = marker self._opener = opener or FileOpener() - def add_source( - self, - method: JoinBy, - value: typing.Optional[typing.Union[str, UserPass]] = None, - ) -> None: - if method in {JoinBy.PASSWORD, JoinBy.INTERACTIVE}: - if not isinstance(value, UserPass): - raise ValueError("expected UserPass value") - if method == JoinBy.PASSWORD: - self.add_pw_source(value) - else: - self.add_interactive_source(value) - elif method in {JoinBy.FILE}: - if not isinstance(value, str): - raise ValueError("expected str value") - self.add_file_source(value) - else: - raise ValueError(f"invalid method: {method}") - def add_file_source(self, path_or_uri: str) -> None: self._sources.append(_JoinSource(JoinBy.FILE, None, path_or_uri)) def add_pw_source(self, value: UserPass) -> None: + assert isinstance(value, UserPass) self._sources.append(_JoinSource(JoinBy.PASSWORD, value, "")) def add_interactive_source(self, value: UserPass) -> None: + assert isinstance(value, UserPass) self._sources.append(_JoinSource(JoinBy.INTERACTIVE, value, "")) + def add_odj_file_source(self, path_or_uri: str) -> None: + self._sources.append(_JoinSource(JoinBy.ODJ_FILE, None, path_or_uri)) + def join(self, dns_updates: bool = False) -> None: if not self._sources: raise JoinError("no sources for join data") @@ -127,15 +115,19 @@ def join(self, dns_updates: bool = False) -> None: if src.method is JoinBy.PASSWORD: assert src.upass upass = src.upass + self._join(upass, dns_updates=dns_updates) elif src.method is JoinBy.FILE: assert src.path upass = self._read_from(src.path) + self._join(upass, dns_updates=dns_updates) elif src.method is JoinBy.INTERACTIVE: assert src.upass upass = UserPass(src.upass.username, _PROMPT) + self._join(upass, dns_updates=dns_updates) + elif src.method is JoinBy.ODJ_FILE: + self._offline_join(src.path) else: raise ValueError(f"invalid method: {src.method}") - self._join(upass, dns_updates=dns_updates) self._set_marker() return except JoinError as join_err: @@ -194,6 +186,20 @@ def _join(self, upass: UserPass, dns_updates: bool = False) -> None: if ret != 0: raise JoinError("failed to run {}".format(cmd)) + def _offline_join(self, path: str) -> None: + cmd = list(self._requestodj["-i"]) + try: + with self._opener.open(path) as fh: + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE) + assert proc.stdin # mypy appeasment + proc.stdin.write(fh.read()) + proc.stdin.close() + ret = proc.wait() + if ret != 0: + raise JoinError(f"failed running {cmd}") + except FileNotFoundError: + raise JoinError(f"source file not found: {path}") + def _set_marker(self) -> None: if self.marker is not None: with open(self.marker, "w") as fh: diff --git a/tests/test_join.py b/tests/test_join.py index 0fdbba42..3d29ffb9 100644 --- a/tests/test_join.py +++ b/tests/test_join.py @@ -43,6 +43,9 @@ def testjoiner(tmp_path): class TestJoiner(sambacc.join.Joiner): _net_ads_join = samba_cmds.SambaCommand(fake_net_script)["ads", "join"] + _requestodj = samba_cmds.SambaCommand(fake_net_script)[ + "offlinejoin", "requestodj" + ] path = tmp_path logpath = data_path / "log" _nullfh = None @@ -61,17 +64,14 @@ def test_no_sources(testjoiner): def test_invalid_source_vals(testjoiner): - with pytest.raises(ValueError): - testjoiner.add_source("bob", 123) - with pytest.raises(ValueError): - testjoiner.add_source(sambacc.join.JoinBy.PASSWORD, 123) - with pytest.raises(ValueError): - testjoiner.add_source(sambacc.join.JoinBy.FILE, 123) + with pytest.raises(AssertionError): + testjoiner.add_pw_source("abc123") + with pytest.raises(AssertionError): + testjoiner.add_interactive_source("xyzdef") def test_join_password(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) testjoiner.join() @@ -81,19 +81,13 @@ def test_join_file(testjoiner): jpath1 = os.path.join(testjoiner.path, "join1.json") with open(jpath1, "w") as fh: json.dump({"username": "elmer", "password": "hunter2"}, fh) - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) testjoiner.join() def test_join_missing_file(testjoiner): jpath1 = os.path.join(testjoiner.path, "nope.json") - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) with pytest.raises(sambacc.join.JoinError) as err: testjoiner.join() assert "not found" in str(err).lower() @@ -101,10 +95,7 @@ def test_join_missing_file(testjoiner): def test_join_bad_file(testjoiner): jpath1 = os.path.join(testjoiner.path, "join1.json") - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) with open(jpath1, "w") as fh: json.dump({"acme": True}, fh) @@ -123,17 +114,13 @@ def test_join_bad_file(testjoiner): def test_join_multi_source(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) jpath1 = os.path.join(testjoiner.path, "join1.json") with open(jpath1, "w") as fh: json.dump({"username": "elmer", "password": "hunter2"}, fh) - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) testjoiner.join() with open(testjoiner.logpath) as fh: @@ -144,17 +131,13 @@ def test_join_multi_source(testjoiner): def test_join_multi_source_fail_first(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "failme"), ) jpath1 = os.path.join(testjoiner.path, "join1.json") with open(jpath1, "w") as fh: json.dump({"username": "elmer", "password": "hunter2"}, fh) - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) testjoiner.join() with open(testjoiner.logpath) as fh: @@ -165,17 +148,13 @@ def test_join_multi_source_fail_first(testjoiner): def test_join_multi_source_fail_both(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "failme"), ) jpath1 = os.path.join(testjoiner.path, "join1.json") with open(jpath1, "w") as fh: json.dump({"username": "elmer", "password": "failme2"}, fh) - testjoiner.add_source( - sambacc.join.JoinBy.FILE, - jpath1, - ) + testjoiner.add_file_source(jpath1) with pytest.raises(sambacc.join.JoinError) as err: testjoiner.join() assert err.match("2 join attempts") @@ -189,8 +168,7 @@ def test_join_multi_source_fail_both(testjoiner): def test_join_prompt_fake(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.INTERACTIVE, + testjoiner.add_interactive_source( sambacc.join.UserPass("daffy"), ) testjoiner.join() @@ -205,8 +183,7 @@ def test_join_prompt_fake(testjoiner): def test_join_with_marker(testjoiner): testjoiner.marker = os.path.join(testjoiner.path, "marker.json") - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) testjoiner.join() @@ -217,8 +194,7 @@ def test_join_with_marker(testjoiner): def test_join_bad_marker(testjoiner): testjoiner.marker = os.path.join(testjoiner.path, "marker.json") - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) testjoiner.join() @@ -235,8 +211,7 @@ def test_join_bad_marker(testjoiner): def test_join_no_marker(testjoiner): - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) testjoiner.join() @@ -257,7 +232,6 @@ def wait(self): errors = [] def ehandler(err): - nonlocal errors if len(errors) > 5: raise ValueError("xxx") errors.append(err) @@ -273,8 +247,7 @@ def ehandler(err): # success case - performs a password join errors[:] = [] - testjoiner.add_source( - sambacc.join.JoinBy.PASSWORD, + testjoiner.add_pw_source( sambacc.join.UserPass("bugs", "whatsupdoc"), ) testjoiner.marker = os.path.join(testjoiner.path, "marker.json") @@ -293,3 +266,32 @@ def ehandler(err): assert len(errors) == 0 assert waiter.wcount == 6 + + +def test_offline_join(testjoiner): + odj_path = os.path.join(testjoiner.path, "foo.odj") + with open(odj_path, "w") as fh: + fh.write("FAKE!\n") + testjoiner.add_odj_file_source(odj_path) + testjoiner.join() + + with open(testjoiner.logpath) as fh: + lines = fh.readlines() + assert lines[0].startswith("ARGS") + assert lines[1].startswith("FAKE!") + + +def test_offline_join_nofile(testjoiner): + odj_path = os.path.join(testjoiner.path, "foo.odj") + testjoiner.add_odj_file_source(odj_path) + with pytest.raises(sambacc.join.JoinError): + testjoiner.join() + + +def test_offline_join_fail(testjoiner): + odj_path = os.path.join(testjoiner.path, "foo.odj") + with open(odj_path, "w") as fh: + fh.write("failme!\n") + testjoiner.add_odj_file_source(odj_path) + with pytest.raises(sambacc.join.JoinError): + testjoiner.join()