Skip to content

Commit 9fbbc6e

Browse files
committed
more type hints
Signed-off-by: Gaëtan Lehmann <[email protected]>
1 parent 3f42ac1 commit 9fbbc6e

File tree

7 files changed

+157
-33
lines changed

7 files changed

+157
-33
lines changed

lib/commands.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
import logging
33
import shlex
44
import subprocess
5-
from typing import Union
5+
from typing import List, Literal, overload, Union
66

7-
from _pytest.fixtures import _teardown_yield_fixture
87

98
import lib.config as config
109
from lib.netutil import wrap_ip
@@ -145,7 +144,34 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
145144
# The actual code is in _ssh().
146145
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
147146
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
148-
def ssh(hostname_or_ip, cmd, *, check=True, simple_output=True, suppress_fingerprint_warnings=True,
147+
@overload
148+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
149+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
150+
target_os: str = 'linux', decode: Literal[True] = True, options: List[str] = []) -> str:
151+
...
152+
@overload
153+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
154+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
155+
target_os: str = 'linux', decode: Literal[False], options: List[str] = []) -> bytes:
156+
...
157+
@overload
158+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
159+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
160+
target_os: str = 'linux', decode: bool = True, options: List[str] = []) -> SSHResult:
161+
...
162+
@overload
163+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
164+
suppress_fingerprint_warnings: bool = True, background: Literal[True],
165+
target_os: str = 'linux', decode: bool = True, options: List[str] = []) -> subprocess.Popen:
166+
...
167+
@overload
168+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check=True, simple_output: bool = True,
169+
suppress_fingerprint_warnings=True, background: bool = False,
170+
target_os='linux', decode: bool = True, options: List[str] = []) \
171+
-> Union[str, bytes, SSHResult, subprocess.Popen]:
172+
...
173+
def ssh(hostname_or_ip, cmd, *, check=True, simple_output=True,
174+
suppress_fingerprint_warnings=True,
149175
background=False, target_os='linux', decode=True, options=[]):
150176
result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
151177
background, target_os, decode, options)

lib/common.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
import time
66
import traceback
77
from enum import Enum
8+
from typing import TYPE_CHECKING, Dict, Optional, Union
89
from uuid import UUID
910

1011
import lib.commands as commands
12+
if TYPE_CHECKING:
13+
import lib.host
1114

1215
class PackageManagerEnum(Enum):
1316
UNKNOWN = 1
@@ -154,9 +157,10 @@ def strtobool(str):
154157
return False
155158
raise ValueError("invalid truth value '{}'".format(str))
156159

157-
def _param_get(host, xe_prefix, uuid, param_name, key=None, accept_unknown_key=False):
160+
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = None,
161+
accept_unknown_key=False) -> Optional[str]:
158162
""" Common implementation for param_get. """
159-
args = {'uuid': uuid, 'param-name': param_name}
163+
args: Dict[str, Union[str, bool]] = {'uuid': uuid, 'param-name': param_name}
160164
if key is not None:
161165
args['param-key'] = key
162166
try:

lib/host.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import json
21
import logging
32
import os
43
import shlex
4+
import subprocess
55
import tempfile
66
import uuid
77

88
from packaging import version
9-
from typing import TYPE_CHECKING
9+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, overload
1010

1111
import lib.commands as commands
1212
import lib.pif as pif
@@ -40,9 +40,7 @@ class Host:
4040
def __init__(self, pool: 'lib.pool.Pool', hostname_or_ip):
4141
self.pool = pool
4242
self.hostname_or_ip = hostname_or_ip
43-
self.inventory = None
44-
self.uuid = None
45-
self.xo_srv_id = None
43+
self.xo_srv_id: Optional[str] = None
4644

4745
h_data = host_data(self.hostname_or_ip)
4846
self.user = h_data['user']
@@ -59,11 +57,41 @@ def __init__(self, pool: 'lib.pool.Pool', hostname_or_ip):
5957
def __str__(self):
6058
return self.hostname_or_ip
6159

62-
def ssh(self, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True,
63-
background=False, decode=True):
60+
@overload
61+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
62+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
63+
decode: Literal[True] = True) -> str:
64+
...
65+
66+
@overload
67+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
68+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
69+
decode: Literal[False]) -> bytes:
70+
...
71+
72+
@overload
73+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
74+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
75+
decode: bool) -> commands.SSHResult:
76+
...
77+
78+
@overload
79+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: bool = True,
80+
suppress_fingerprint_warnings: bool = True, background: Literal[True],
81+
decode: bool = True) -> subprocess.Popen:
82+
...
83+
84+
@overload
85+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: bool = True,
86+
suppress_fingerprint_warnings: bool = True, background: bool = False, decode: bool = True) \
87+
-> Union[str, bytes, commands.SSHResult, subprocess.Popen]:
88+
...
89+
90+
def ssh(self, cmd, *, check=True, simple_output=True, suppress_fingerprint_warnings=True,
91+
background=False, decode=True) -> Union[str, bytes, commands.SSHResult, subprocess.Popen]:
6492
return commands.ssh(self.hostname_or_ip, cmd, check=check, simple_output=simple_output,
65-
suppress_fingerprint_warnings=suppress_fingerprint_warnings, background=background,
66-
decode=decode)
93+
suppress_fingerprint_warnings=suppress_fingerprint_warnings,
94+
background=background, decode=decode)
6795

6896
def ssh_with_result(self, cmd) -> commands.SSHResult:
6997
# doesn't raise if the command's return is nonzero, unless there's a SSH error
@@ -75,7 +103,18 @@ def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_d
75103
suppress_fingerprint_warnings=suppress_fingerprint_warnings, local_dest=local_dest
76104
)
77105

78-
def xe(self, action, args={}, check=True, simple_output=True, minimal=False, force=False):
106+
@overload
107+
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = True,
108+
simple_output: Literal[True] = True, minimal: bool = False, force: bool = False) -> Union[bool, str]:
109+
...
110+
111+
@overload
112+
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = True,
113+
simple_output: Literal[False], minimal: bool = False, force: bool = False) -> commands.SSHResult:
114+
...
115+
116+
def xe(self, action, args={}, *, check=True, simple_output=True, minimal=False, force=False) \
117+
-> Union[bool, str, commands.SSHResult]:
79118
maybe_param_minimal = ['--minimal'] if minimal else []
80119
maybe_param_force = ['--force'] if force else []
81120

@@ -89,13 +128,14 @@ def stringify(key, value):
89128
return ret.rstrip()
90129
return "{}={}".format(key, shlex.quote(value))
91130

92-
command = ['xe', action] + maybe_param_minimal + maybe_param_force + \
93-
[stringify(key, value) for key, value in args.items()]
131+
command: List[str] = ['xe', action] + maybe_param_minimal + maybe_param_force + \
132+
[stringify(key, value) for key, value in args.items()]
94133
result = self.ssh(
95134
command,
96135
check=check,
97136
simple_output=simple_output
98137
)
138+
assert isinstance(result, (str, commands.SSHResult))
99139

100140
if result == 'true':
101141
return True
@@ -163,7 +203,7 @@ def execute_script(self, script_contents, shebang='sh', simple_output=True):
163203

164204
def _get_xensource_inventory(self):
165205
output = self.ssh(['cat', '/etc/xensource-inventory'])
166-
inventory = {}
206+
inventory: dict[str, str] = {}
167207
for line in output.splitlines():
168208
key, raw_value = line.split('=')
169209
inventory[key] = raw_value.strip('\'')
@@ -202,7 +242,7 @@ def xo_server_add(self, username, password, label=None, unregister_first=True):
202242
'allowUnauthorized': 'true',
203243
'label': label
204244
}
205-
)
245+
).decode()
206246
self.xo_srv_id = xo_srv_id
207247

208248
def xo_server_status(self):
@@ -216,6 +256,7 @@ def xo_server_connected(self):
216256
return self.xo_server_status() == "connected"
217257

218258
def xo_server_reconnect(self):
259+
assert self.xo_srv_id is not None
219260
logging.info("Reconnect XO to host %s" % self)
220261
xo_cli('server.disable', {'id': self.xo_srv_id})
221262
xo_cli('server.enable', {'id': self.xo_srv_id})
@@ -285,6 +326,7 @@ def import_iso(self, uri, sr: SR):
285326
},
286327
)
287328

329+
download_path = None
288330
try:
289331
params = {'uuid': vdi_uuid}
290332
if '://' in uri:
@@ -293,7 +335,6 @@ def import_iso(self, uri, sr: SR):
293335
self.ssh(f"curl -o '{download_path}' '{uri}'")
294336
params['filename'] = download_path
295337
else:
296-
download_path = None
297338
params['filename'] = uri
298339
logging.info(f"Import ISO {uri}: name {random_name}, uuid {vdi_uuid}")
299340

@@ -355,7 +396,7 @@ def get_last_yum_history_tid(self):
355396
"""
356397
try:
357398
history_str = self.ssh(['yum', 'history', 'list', '--noplugins'])
358-
except commands.SSHCommandFailed as e:
399+
except commands.SSHCommandFailed:
359400
# yum history list fails if the list is empty, and it's also not possible to rollback
360401
# to before the first transaction, so "0" would not be appropriate as last transaction.
361402
# To workaround this, create transactions: install and remove a small package.

lib/pool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import traceback
3-
from typing import Any, Dict, Optional
3+
from typing import Any, Dict, Optional, cast
44

55
from packaging import version
66

@@ -30,7 +30,7 @@ def __init__(self, master_hostname_or_ip: str) -> None:
3030
if host_uuid != self.hosts[0].uuid:
3131
host = Host(self, self.host_ip(host_uuid))
3232
self.hosts.append(host)
33-
self.uuid = self.master.xe('pool-list', minimal=True)
33+
self.uuid = cast(str, self.master.xe('pool-list', minimal=True))
3434
self.saved_uefi_certs: Optional[Dict[str, Any]] = None
3535

3636
def param_get(self, param_name, key=None, accept_unknown_key=False):

lib/pxe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from lib.commands import ssh, scp, SSHCommandFailed
1+
from lib.commands import ssh, scp
22

33
PXE_CONFIG_DIR = "/pxe/configs/custom"
44

lib/vm.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import json
21
import logging
32
import os
3+
import subprocess
44
import tempfile
5+
from typing import List, Literal, Optional, Union, overload
56

67
import lib.commands as commands
78
import lib.efi as efi
@@ -14,7 +15,7 @@
1415
class VM(BaseVM):
1516
def __init__(self, uuid, host):
1617
super().__init__(uuid, host)
17-
self.ip = None
18+
self.ip: Optional[str] = None
1819
self.previous_host = None # previous host when migrated or being migrated
1920
self.is_windows = self.param_get('platform', 'device_id', accept_unknown_key=True) == '0002'
2021
self.is_uefi = self.param_get('HVM-boot-params', 'firmware', accept_unknown_key=True) == 'uefi'
@@ -82,9 +83,31 @@ def try_get_and_store_ip(self):
8283
self.ip = ip
8384
return True
8485

85-
def ssh(self, cmd, check=True, simple_output=True, background=False, decode=True):
86+
@overload
87+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
88+
background: Literal[False] = False, decode: Literal[True] = True) -> str:
89+
...
90+
91+
@overload
92+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
93+
background: Literal[False] = False, decode: Literal[False]) -> bytes:
94+
...
95+
96+
@overload
97+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
98+
background: Literal[False] = False, decode: bool = True) -> commands.SSHResult:
99+
...
100+
101+
@overload
102+
def ssh(self, cmd: Union[str, List[str]], *, check: bool = True, simple_output: bool = True,
103+
background: Literal[True], decode: bool = True) -> subprocess.Popen:
104+
...
105+
106+
def ssh(self, cmd: Union[str, List[str]], *, check=True, simple_output=True, background=False, decode=True) \
107+
-> Union[str, bytes, commands.SSHResult, subprocess.Popen]:
86108
# raises by default for any nonzero return code
87-
target_os = "windows" if self.is_windows else "linux"
109+
assert self.ip is not None
110+
target_os: str = "windows" if self.is_windows else "linux"
88111
return commands.ssh(self.ip, cmd, check=check, simple_output=simple_output, background=background,
89112
target_os=target_os, decode=decode)
90113

@@ -178,6 +201,7 @@ def exists(self):
178201
return self.host.pool_has_vm(self.uuid)
179202

180203
def exists_on_previous_pool(self):
204+
assert self.previous_host is not None
181205
return self.previous_host.pool_has_vm(self.uuid)
182206

183207
def migrate(self, target_host, sr=None, network=None):
@@ -256,7 +280,7 @@ def get_residence_host(self):
256280
host_uuid = self.param_get('resident-on')
257281
return self.host.pool.get_host_by_uuid(host_uuid)
258282

259-
def start_background_process(self, cmd) -> str:
283+
def start_background_process(self, cmd: str) -> str:
260284
script = "/tmp/bg_process.sh"
261285
pidfile = "/tmp/bg_process.pid"
262286
with tempfile.NamedTemporaryFile('w') as f:
@@ -420,7 +444,7 @@ def get_efi_var(self, var, guid):
420444
if not self.file_exists(efivarfs):
421445
return b''
422446

423-
data = self.ssh(['cat', efivarfs], simple_output=False, decode=False).stdout
447+
data = self.ssh(['cat', efivarfs], decode=False)
424448

425449
# The efivarfs file starts with the attributes, which are 4 bytes long
426450
return data[4:]
@@ -448,11 +472,10 @@ def get_all_efi_bins(self):
448472
'do', 'echo', '$file', '$(head', '-c', magicsz, '$file);',
449473
'done'
450474
],
451-
simple_output=False,
452-
decode=False).stdout.split(b'\n')
475+
decode=False).split(b'\n')
453476

454477
magic = efi.EFI_HEADER_MAGIC.encode('ascii')
455-
binaries = []
478+
binaries: list[str] = []
456479
for f in files:
457480
if magic in f:
458481
# Avoid decoding an unsplit f, as some headers are not utf8
@@ -590,6 +613,18 @@ def is_cert_present(self, key):
590613
check=False, simple_output=False, decode=False)
591614
return res.returncode == 0
592615

616+
@overload
617+
def execute_powershell_script(self, script_contents: str,
618+
simple_output: Literal[True] = True,
619+
prepend: str = "$ProgressPreference = 'SilentlyContinue';") -> str:
620+
...
621+
622+
@overload
623+
def execute_powershell_script(self, script_contents: str,
624+
simple_output: Literal[False],
625+
prepend: str = "$ProgressPreference = 'SilentlyContinue';") -> commands.SSHResult:
626+
...
627+
593628
def execute_powershell_script(
594629
self,
595630
script_contents: str,

lib/xo.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
import json
22
import subprocess
3+
from typing import Any, Dict, Literal, Union, overload
34

5+
6+
@overload
7+
def xo_cli(action: str, args: Dict[str, str] = {}, *, check: bool = True, simple_output: Literal[True] = True,
8+
use_json: Literal[False] = False) -> bytes:
9+
...
10+
@overload
11+
def xo_cli(action: str, args: Dict[str, str] = {}, *, check: bool = True, simple_output: Literal[True] = True,
12+
use_json: Literal[True]) -> Any:
13+
...
14+
@overload
15+
def xo_cli(action: str, args: Dict[str, str] = {}, *, check: bool = True, simple_output: Literal[False],
16+
use_json: bool = False) -> subprocess.CompletedProcess:
17+
...
18+
@overload
19+
def xo_cli(action: str, args: Dict[str, str] = {}, *, check: bool = True, simple_output: bool = True,
20+
use_json: bool = False) -> Union[subprocess.CompletedProcess, Any, bytes]:
21+
...
422
def xo_cli(action, args={}, check=True, simple_output=True, use_json=False):
523
run_array = ['xo-cli', action]
624
if use_json:

0 commit comments

Comments
 (0)