Skip to content

Commit 714c930

Browse files
committed
add some type hints
Signed-off-by: Gaëtan Lehmann <[email protected]>
1 parent bd8ba6b commit 714c930

File tree

7 files changed

+86
-53
lines changed

7 files changed

+86
-53
lines changed

lib/basevm.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import logging
22

3+
from typing import Any, Optional, TYPE_CHECKING
4+
35
import lib.commands as commands
6+
if TYPE_CHECKING:
7+
import lib.host
48

59
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
610
from lib.sr import SR
@@ -10,33 +14,36 @@ class BaseVM:
1014

1115
xe_prefix = "vm"
1216

13-
def __init__(self, uuid, host):
17+
def __init__(self, uuid, host: 'lib.host.Host'):
1418
logging.info("New %s: %s", type(self).__name__, uuid)
1519
self.uuid = uuid
1620
self.host = host
1721

18-
def param_get(self, param_name, key=None, accept_unknown_key=False):
22+
def param_get(self, param_name: str, key: Optional[str] = None,
23+
accept_unknown_key: bool = False) -> Optional[str]:
1924
return _param_get(self.host, self.xe_prefix, self.uuid,
2025
param_name, key, accept_unknown_key)
2126

22-
def param_set(self, param_name, value, key=None):
27+
def param_set(self, param_name: str, value: Any, key: Optional[str] = None) -> None:
2328
_param_set(self.host, self.xe_prefix, self.uuid,
2429
param_name, value, key)
2530

26-
def param_remove(self, param_name, key, accept_unknown_key=False):
31+
def param_remove(self, param_name: str, key: str, accept_unknown_key=False) -> None:
2732
_param_remove(self.host, self.xe_prefix, self.uuid,
2833
param_name, key, accept_unknown_key)
2934

30-
def param_add(self, param_name, value, key=None):
35+
def param_add(self, param_name: str, value: str, key=None) -> None:
3136
_param_add(self.host, self.xe_prefix, self.uuid,
3237
param_name, value, key)
3338

34-
def param_clear(self, param_name):
39+
def param_clear(self, param_name: str) -> None:
3540
_param_clear(self.host, self.xe_prefix, self.uuid,
3641
param_name)
3742

38-
def name(self):
39-
return self.param_get('name-label')
43+
def name(self) -> str:
44+
n = self.param_get('name-label')
45+
assert isinstance(n, str)
46+
return n
4047

4148
def vdi_uuids(self, sr_uuid=None):
4249
output = self._disk_list()
@@ -54,7 +61,7 @@ def vdi_uuids(self, sr_uuid=None):
5461
vdis_on_sr.append(vdi)
5562
return vdis_on_sr
5663

57-
def destroy_vdi(self, vdi_uuid):
64+
def destroy_vdi(self, vdi_uuid: str) -> None:
5865
self.host.xe('vdi-destroy', {'uuid': vdi_uuid})
5966

6067
# FIXME: move this method and the above back to class VM if not useful in Snapshot class?
@@ -70,7 +77,7 @@ def all_vdis_on_host(self, host):
7077
return False
7178
return True
7279

73-
def all_vdis_on_sr(self, sr):
80+
def all_vdis_on_sr(self, sr) -> bool:
7481
for vdi_uuid in self.vdi_uuids():
7582
if self.host.pool.get_vdi_sr_uuid(vdi_uuid) != sr.uuid:
7683
return False
@@ -84,7 +91,7 @@ def get_sr(self):
8491
assert sr.attached_to_host(self.host)
8592
return sr
8693

87-
def export(self, filepath, compress='none'):
94+
def export(self, filepath, compress='none') -> None:
8895
logging.info("Export VM %s to %s with compress=%s" % (self.uuid, filepath, compress))
8996
params = {
9097
'uuid': self.uuid,

lib/commands.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import logging
33
import shlex
44
import subprocess
5+
from typing import Union
56

6-
import lib.config as config
7+
from _pytest.fixtures import _teardown_yield_fixture
78

9+
import lib.config as config
810
from lib.netutil import wrap_ip
911

12+
1013
class BaseCommandFailed(Exception):
1114
__slots__ = 'returncode', 'stdout', 'cmd'
1215

@@ -61,7 +64,7 @@ def _ellide_log_lines(log):
6164
return "\n{}".format("\n".join(reduced_message))
6265

6366
def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
64-
background, target_os, decode, options):
67+
background, target_os, decode, options) -> Union[SSHResult, SSHCommandFailed, str, bytes, subprocess.Popen]:
6568
opts = list(options)
6669
opts.append('-o "BatchMode yes"')
6770
if suppress_fingerprint_warnings:
@@ -86,6 +89,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
8689

8790
windows_background = background and target_os == "windows"
8891
# Fetch banner and remove it to avoid stdout/stderr pollution.
92+
banner_res = None
8993
if config.ignore_ssh_banner and not windows_background:
9094
banner_res = subprocess.run(
9195
"ssh root@%s %s '%s'" % (hostname_or_ip, ' '.join(opts), '\n'),
@@ -103,9 +107,10 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
103107
)
104108
logging.debug(f"[{hostname_or_ip}] {command}")
105109
if windows_background:
106-
return True, process
110+
return process
107111

108112
stdout = []
113+
assert process.stdout is not None
109114
for line in iter(process.stdout.readline, b''):
110115
readable_line = line.decode(errors='replace').strip()
111116
stdout.append(line)
@@ -118,34 +123,46 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
118123

119124
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
120125
if res.returncode == 255:
121-
return False, SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
126+
return SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
122127

123-
output = res.stdout
124-
if config.ignore_ssh_banner:
128+
output: Union[bytes, str] = res.stdout
129+
if banner_res:
125130
if banner_res.returncode == 255:
126-
return False, SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
131+
return SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
127132
output = output[len(banner_res.stdout):]
128133

129134
if decode:
135+
assert isinstance(output, bytes)
130136
output = output.decode()
131137

132138
if res.returncode and check:
133-
return False, SSHCommandFailed(res.returncode, output_for_errors, command)
139+
return SSHCommandFailed(res.returncode, output_for_errors, command)
134140

135141
if simple_output:
136-
return True, output.strip()
137-
return True, SSHResult(res.returncode, output)
142+
return output.strip()
143+
return SSHResult(res.returncode, output)
138144

139145
# The actual code is in _ssh().
140146
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
141147
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
142-
def ssh(hostname_or_ip, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True,
143-
background=False, target_os='linux', decode=True, options=[]):
144-
success, result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
145-
background, target_os, decode, options)
146-
if not success:
148+
def ssh(hostname_or_ip, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True, background=False,
149+
target_os='linux', decode=True, options=[]) -> Union[SSHResult, str, bytes, subprocess.Popen]:
150+
result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
151+
background, target_os, decode, options)
152+
if isinstance(result_or_exc, SSHCommandFailed):
153+
raise result_or_exc
154+
else:
155+
return result_or_exc
156+
157+
def ssh_with_result(hostname_or_ip, cmd, suppress_fingerprint_warnings=True,
158+
background=False, target_os='linux', decode=True, options=[]) -> SSHResult:
159+
result_or_exc = _ssh(hostname_or_ip, cmd, False, False, suppress_fingerprint_warnings,
160+
background, target_os, decode, options)
161+
if isinstance(result_or_exc, SSHCommandFailed):
147162
raise result_or_exc
148-
return result_or_exc
163+
elif isinstance(result_or_exc, SSHResult):
164+
return result_or_exc
165+
assert False, "unexpected type"
149166

150167
def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
151168
opts = '-o "BatchMode yes"'
@@ -179,6 +196,7 @@ def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=Tru
179196
return res
180197

181198
def sftp(hostname_or_ip, cmds, check=True, suppress_fingerprint_warnings=True):
199+
opts = ''
182200
if suppress_fingerprint_warnings:
183201
# Suppress warnings and questions related to host key fingerprints
184202
# because on a test network IPs get reused, VMs are reinstalled, etc.

lib/host.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
import uuid
77

88
from packaging import version
9+
from typing import TYPE_CHECKING
910

1011
import lib.commands as commands
1112
import lib.pif as pif
13+
if TYPE_CHECKING:
14+
import lib.pool
1215

1316
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
1417
from lib.common import safe_split, strip_suffix, to_xapi_bool, wait_for, wait_for_not
@@ -34,7 +37,7 @@ def host_data(hostname_or_ip):
3437
class Host:
3538
xe_prefix = "host"
3639

37-
def __init__(self, pool, hostname_or_ip):
40+
def __init__(self, pool: 'lib.pool.Pool', hostname_or_ip):
3841
self.pool = pool
3942
self.hostname_or_ip = hostname_or_ip
4043
self.inventory = None
@@ -62,9 +65,9 @@ def ssh(self, cmd, check=True, simple_output=True, suppress_fingerprint_warnings
6265
suppress_fingerprint_warnings=suppress_fingerprint_warnings, background=background,
6366
decode=decode)
6467

65-
def ssh_with_result(self, cmd):
68+
def ssh_with_result(self, cmd) -> commands.SSHResult:
6669
# doesn't raise if the command's return is nonzero, unless there's a SSH error
67-
return self.ssh(cmd, check=False, simple_output=False)
70+
return commands.ssh_with_result(self.hostname_or_ip, cmd)
6871

6972
def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
7073
return commands.scp(

lib/pool.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
import logging
22
import traceback
3+
from typing import Any, Dict, Optional
34

45
from packaging import version
56

67
import lib.commands as commands
7-
8-
from lib.common import safe_split, wait_for, wait_for_not, _param_get, _param_set
8+
from lib.common import _param_get, _param_set, safe_split, wait_for, wait_for_not
99
from lib.host import Host
1010
from lib.sr import SR
1111

12+
1213
class Pool:
1314
xe_prefix = "pool"
1415

15-
def __init__(self, master_hostname_or_ip):
16+
def __init__(self, master_hostname_or_ip: str) -> None:
1617
master = Host(self, master_hostname_or_ip)
1718
assert master.is_master(), f"Host {master_hostname_or_ip} is not a master host. Aborting."
1819
self.master = master
1920
self.hosts = [master]
2021

2122
# wait for XAPI startup to be done, or we can get "Connection
2223
# refused (calling connect )" when calling self.hosts_uuids()
23-
wait_for(lambda: commands.ssh(master_hostname_or_ip, ['xapi-wait-init-complete', '60'],
24-
check=False, simple_output=False).returncode == 0,
24+
wait_for(lambda: commands.ssh_with_result(master_hostname_or_ip,
25+
['xapi-wait-init-complete', '60']).returncode == 0,
2526
f"Wait for XAPI init to be complete on {master_hostname_or_ip}",
2627
timeout_secs=30 * 60)
2728

@@ -30,7 +31,7 @@ def __init__(self, master_hostname_or_ip):
3031
host = Host(self, self.host_ip(host_uuid))
3132
self.hosts.append(host)
3233
self.uuid = self.master.xe('pool-list', minimal=True)
33-
self.saved_uefi_certs = None
34+
self.saved_uefi_certs: Optional[Dict[str, Any]] = None
3435

3536
def param_get(self, param_name, key=None, accept_unknown_key=False):
3637
return _param_get(self.master, Pool.xe_prefix, self.uuid, param_name, key, accept_unknown_key)
@@ -108,7 +109,7 @@ def first_host_that_isnt(self, host):
108109
return h
109110
return None
110111

111-
def first_shared_sr(self):
112+
def first_shared_sr(self) -> Optional[SR]:
112113
uuids = safe_split(self.master.xe('sr-list', {'shared': True, 'content-type': 'user'}, minimal=True))
113114
if len(uuids) > 0:
114115
return SR(uuids[0], self)
@@ -117,7 +118,7 @@ def first_shared_sr(self):
117118
def get_vdi_sr_uuid(self, vdi_uuid):
118119
return self.master.xe('vdi-param-get', {'uuid': vdi_uuid, 'param-name': 'sr-uuid'})
119120

120-
def save_uefi_certs(self):
121+
def save_uefi_certs(self) -> None:
121122
"""
122123
Save UEFI certificates in order to restore them later. XCP-ng 8.2 only.
123124

lib/vm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ def ssh(self, cmd, check=True, simple_output=True, background=False, decode=True
8888
return commands.ssh(self.ip, cmd, check=check, simple_output=simple_output, background=background,
8989
target_os=target_os, decode=decode)
9090

91-
def ssh_with_result(self, cmd):
91+
def ssh_with_result(self, cmd) -> commands.SSHResult:
9292
# doesn't raise if the command's return is nonzero, unless there's a SSH error
93-
return self.ssh(cmd, check=False, simple_output=False)
93+
return commands.ssh_with_result(self.ip, cmd)
9494

9595
def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
9696
# Stop execution if scp() is used on Windows VMs as some OpenSSH releases for Windows don't
@@ -235,7 +235,7 @@ def snapshot(self, ignore_vdis=None):
235235
args['ignore-vdi-uuids'] = ','.join(ignore_vdis)
236236
return Snapshot(self.host.xe('vm-snapshot', args), self.host)
237237

238-
def checkpoint(self):
238+
def checkpoint(self) -> Snapshot:
239239
logging.info("Checkpoint VM")
240240
return Snapshot(self.host.xe('vm-checkpoint', {'uuid': self.uuid,
241241
'new-name-label': 'Checkpoint of %s' % self.uuid}),
@@ -255,7 +255,7 @@ def get_residence_host(self):
255255
host_uuid = self.param_get('resident-on')
256256
return self.host.pool.get_host_by_uuid(host_uuid)
257257

258-
def start_background_process(self, cmd):
258+
def start_background_process(self, cmd) -> str:
259259
script = "/tmp/bg_process.sh"
260260
pidfile = "/tmp/bg_process.pid"
261261
with tempfile.NamedTemporaryFile('w') as f:

tests/misc/test_basic_without_ssh.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import pytest
33

44
from lib.common import wait_for
5+
from lib.host import Host
6+
from lib.sr import SR
7+
from lib.vm import VM
58

69
# These tests are basic tests meant to be run to check that a VM performs
710
# well, without obvious issues.
@@ -19,14 +22,14 @@
1922
# Note however that an existing VM will be left on a different SR after the tests.
2023

2124
@pytest.fixture(scope='session')
22-
def existing_shared_sr(host):
25+
def existing_shared_sr(host: Host) -> SR:
2326
sr = host.pool.first_shared_sr()
2427
assert sr is not None, "A shared SR on the pool is required"
2528
return sr
2629

2730
@pytest.mark.multi_vms # run them on a variety of VMs
2831
@pytest.mark.big_vm # and also on a really big VM ideally
29-
def test_vm_start_stop(imported_vm):
32+
def test_vm_start_stop(imported_vm: VM):
3033
vm = imported_vm
3134
# if VM already running, stop it
3235
if (vm.is_running()):
@@ -43,19 +46,19 @@ def test_vm_start_stop(imported_vm):
4346
@pytest.mark.big_vm # and also on a really big VM ideally
4447
@pytest.mark.usefixtures("started_vm")
4548
class TestBasicNoSSH:
46-
def test_pause(self, imported_vm):
49+
def test_pause(self, imported_vm: VM):
4750
vm = imported_vm
4851
vm.pause(verify=True)
4952
vm.unpause()
5053
vm.wait_for_os_booted()
5154

52-
def test_suspend(self, imported_vm):
55+
def test_suspend(self, imported_vm: VM):
5356
vm = imported_vm
5457
vm.suspend(verify=True)
5558
vm.resume()
5659
vm.wait_for_os_booted()
5760

58-
def test_snapshot(self, imported_vm):
61+
def test_snapshot(self, imported_vm: VM):
5962
vm = imported_vm
6063
snapshot = vm.snapshot()
6164
try:
@@ -65,7 +68,7 @@ def test_snapshot(self, imported_vm):
6568
finally:
6669
snapshot.destroy(verify=True)
6770

68-
def test_checkpoint(self, imported_vm):
71+
def test_checkpoint(self, imported_vm: VM):
6972
vm = imported_vm
7073
snapshot = vm.checkpoint()
7174
try:
@@ -79,7 +82,7 @@ def test_checkpoint(self, imported_vm):
7982
# We want to test storage migration (memory+disks) and live migration without storage migration (memory only).
8083
# The order will depend on the initial location of the VM: a local SR or a shared SR.
8184
@pytest.mark.usefixtures("hostA2")
82-
def test_live_migrate(self, imported_vm, existing_shared_sr):
85+
def test_live_migrate(self, imported_vm: VM, existing_shared_sr: SR):
8386
def live_migrate(vm, dest_host, dest_sr, check_vdis=False):
8487
vm.migrate(dest_host, dest_sr)
8588
if check_vdis:

0 commit comments

Comments
 (0)