Skip to content

Commit 5f45470

Browse files
committed
add some type hints
Signed-off-by: Gaëtan Lehmann <[email protected]>
1 parent 5de806e commit 5f45470

File tree

7 files changed

+89
-43
lines changed

7 files changed

+89
-43
lines changed

lib/basevm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22

33
import lib.commands as commands
4+
import lib.host
45

56
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
67
from lib.sr import SR
@@ -10,7 +11,7 @@ class BaseVM:
1011

1112
xe_prefix = "vm"
1213

13-
def __init__(self, uuid, host):
14+
def __init__(self, uuid, host: 'lib.host.Host'):
1415
logging.info("New %s: %s", type(self).__name__, uuid)
1516
self.uuid = uuid
1617
self.host = host
@@ -36,7 +37,9 @@ def param_clear(self, param_name):
3637
param_name)
3738

3839
def name(self):
39-
return self.param_get('name-label')
40+
n = self.param_get('name-label')
41+
assert isinstance(n, str)
42+
return n
4043

4144
def vdi_uuids(self, sr_uuid=None):
4245
output = self._disk_list()

lib/commands.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import logging
33
import shlex
44
import subprocess
5+
from typing import Union
56

6-
import lib.config as config
7+
from _pytest.fixtures import _teardown_yield_fixture
78

9+
import lib.config as config
810
from lib.netutil import wrap_ip
911

12+
1013
class BaseCommandFailed(Exception):
1114
__slots__ = 'returncode', 'stdout', 'cmd'
1215

@@ -61,7 +64,7 @@ def _ellide_log_lines(log):
6164
return "\n{}".format("\n".join(reduced_message))
6265

6366
def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
64-
background, target_os, decode, options):
67+
background, target_os, decode, options) -> Union[SSHResult, SSHCommandFailed, str, bytes, subprocess.Popen]:
6568
opts = list(options)
6669
opts.append('-o "BatchMode yes"')
6770
if suppress_fingerprint_warnings:
@@ -86,6 +89,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
8689

8790
windows_background = background and target_os == "windows"
8891
# Fetch banner and remove it to avoid stdout/stderr pollution.
92+
banner_res = None
8993
if config.ignore_ssh_banner and not windows_background:
9094
banner_res = subprocess.run(
9195
"ssh root@%s %s '%s'" % (hostname_or_ip, ' '.join(opts), '\n'),
@@ -103,9 +107,10 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
103107
)
104108
logging.debug(f"[{hostname_or_ip}] {command}")
105109
if windows_background:
106-
return True, process
110+
return process
107111

108112
stdout = []
113+
assert process.stdout is not None
109114
for line in iter(process.stdout.readline, b''):
110115
readable_line = line.decode(errors='replace').strip()
111116
stdout.append(line)
@@ -118,34 +123,55 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
118123

119124
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
120125
if res.returncode == 255:
121-
return False, SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
126+
return SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
122127

123-
output = res.stdout
124-
if config.ignore_ssh_banner:
128+
output: bytes | str = res.stdout
129+
if banner_res:
125130
if banner_res.returncode == 255:
126-
return False, SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
131+
return SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
127132
output = output[len(banner_res.stdout):]
128133

129134
if decode:
130135
output = output.decode()
131136

132137
if res.returncode and check:
133-
return False, SSHCommandFailed(res.returncode, output_for_errors, command)
138+
return SSHCommandFailed(res.returncode, output_for_errors, command)
134139

135140
if simple_output:
136-
return True, output.strip()
137-
return True, SSHResult(res.returncode, output)
141+
return output.strip()
142+
return SSHResult(res.returncode, output)
138143

139144
# The actual code is in _ssh().
140145
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
141146
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
142-
def ssh(hostname_or_ip, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True,
143-
background=False, target_os='linux', decode=True, options=[]):
144-
success, result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
145-
background, target_os, decode, options)
146-
if not success:
147+
def ssh(hostname_or_ip, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True, background=False,
148+
target_os='linux', decode=True, options=[]) -> Union[SSHResult, str, bytes, subprocess.Popen]:
149+
result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
150+
background, target_os, decode, options)
151+
if isinstance(result_or_exc, SSHCommandFailed):
152+
raise result_or_exc
153+
else:
154+
return result_or_exc
155+
156+
def ssh_str(hostname_or_ip, cmd, check=True, suppress_fingerprint_warnings=True,
157+
background=False, target_os='linux', options=[]) -> str:
158+
result_or_exc = _ssh(hostname_or_ip, cmd, check, True, suppress_fingerprint_warnings,
159+
background, target_os, True, options)
160+
if isinstance(result_or_exc, SSHCommandFailed):
161+
raise result_or_exc
162+
elif isinstance(result_or_exc, str):
163+
return result_or_exc
164+
assert False, "unexpected type"
165+
166+
def ssh_with_result(hostname_or_ip, cmd, suppress_fingerprint_warnings=True,
167+
background=False, target_os='linux', decode=True, options=[]) -> SSHResult:
168+
result_or_exc = _ssh(hostname_or_ip, cmd, False, False, suppress_fingerprint_warnings,
169+
background, target_os, decode, options)
170+
if isinstance(result_or_exc, SSHCommandFailed):
147171
raise result_or_exc
148-
return result_or_exc
172+
elif isinstance(result_or_exc, SSHResult):
173+
return result_or_exc
174+
assert False, "unexpected type"
149175

150176
def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
151177
opts = '-o "BatchMode yes"'
@@ -179,6 +205,7 @@ def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=Tru
179205
return res
180206

181207
def sftp(hostname_or_ip, cmds, check=True, suppress_fingerprint_warnings=True):
208+
opts = ''
182209
if suppress_fingerprint_warnings:
183210
# Suppress warnings and questions related to host key fingerprints
184211
# because on a test network IPs get reused, VMs are reinstalled, etc.

lib/host.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import lib.commands as commands
1111
import lib.pif as pif
12+
import lib.pool
1213

1314
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
1415
from lib.common import safe_split, strip_suffix, to_xapi_bool, wait_for, wait_for_not
@@ -34,7 +35,7 @@ def host_data(hostname_or_ip):
3435
class Host:
3536
xe_prefix = "host"
3637

37-
def __init__(self, pool, hostname_or_ip):
38+
def __init__(self, pool: 'lib.pool.Pool', hostname_or_ip):
3839
self.pool = pool
3940
self.hostname_or_ip = hostname_or_ip
4041
self.inventory = None
@@ -62,9 +63,13 @@ def ssh(self, cmd, check=True, simple_output=True, suppress_fingerprint_warnings
6263
suppress_fingerprint_warnings=suppress_fingerprint_warnings, background=background,
6364
decode=decode)
6465

65-
def ssh_with_result(self, cmd):
66+
def ssh_str(self, cmd, check=True, background=False) -> str:
67+
# raises by default for any nonzero return code
68+
return commands.ssh_str(self.hostname_or_ip, cmd, check=check, background=background)
69+
70+
def ssh_with_result(self, cmd) -> commands.SSHResult:
6671
# doesn't raise if the command's return is nonzero, unless there's a SSH error
67-
return self.ssh(cmd, check=False, simple_output=False)
72+
return commands.ssh_with_result(self.hostname_or_ip, cmd)
6873

6974
def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
7075
return commands.scp(

lib/pool.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import logging
22
import traceback
3+
from typing import Optional
34

45
from packaging import version
56

67
import lib.commands as commands
7-
8-
from lib.common import safe_split, wait_for, wait_for_not, _param_get, _param_set
8+
from lib.common import _param_get, _param_set, safe_split, wait_for, wait_for_not
99
from lib.host import Host
1010
from lib.sr import SR
1111

12+
1213
class Pool:
1314
xe_prefix = "pool"
1415

@@ -20,8 +21,8 @@ def __init__(self, master_hostname_or_ip):
2021

2122
# wait for XAPI startup to be done, or we can get "Connection
2223
# refused (calling connect )" when calling self.hosts_uuids()
23-
wait_for(lambda: commands.ssh(master_hostname_or_ip, ['xapi-wait-init-complete', '60'],
24-
check=False, simple_output=False).returncode == 0,
24+
wait_for(lambda: commands.ssh_with_result(master_hostname_or_ip,
25+
['xapi-wait-init-complete', '60']).returncode == 0,
2526
f"Wait for XAPI init to be complete on {master_hostname_or_ip}",
2627
timeout_secs=30 * 60)
2728

@@ -108,7 +109,7 @@ def first_host_that_isnt(self, host):
108109
return h
109110
return None
110111

111-
def first_shared_sr(self):
112+
def first_shared_sr(self) -> Optional[SR]:
112113
uuids = safe_split(self.master.xe('sr-list', {'shared': True, 'content-type': 'user'}, minimal=True))
113114
if len(uuids) > 0:
114115
return SR(uuids[0], self)
@@ -138,7 +139,7 @@ def save_uefi_certs(self):
138139
assert self.master.xcp_version < version.parse("8.3"), "this function should only be needed on XCP-ng 8.2"
139140
logging.info('Saving pool UEFI certificates')
140141

141-
if int(self.master.ssh(["secureboot-certs", "--version"]).split(".")[0]) < 1:
142+
if int(self.master.ssh_str(["secureboot-certs", "--version"]).split(".")[0]) < 1:
142143
raise RuntimeError("The host must have secureboot-certs version >= 1.0.0")
143144

144145
saved_certs = {

lib/vm.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,15 @@ def ssh(self, cmd, check=True, simple_output=True, background=False, decode=True
8888
return commands.ssh(self.ip, cmd, check=check, simple_output=simple_output, background=background,
8989
target_os=target_os, decode=decode)
9090

91-
def ssh_with_result(self, cmd):
91+
def ssh_str(self, cmd, check=True, background=False) -> str:
92+
# raises by default for any nonzero return code
93+
target_os = "windows" if self.is_windows else "linux"
94+
return commands.ssh_str(self.ip, cmd, check=check, background=background,
95+
target_os=target_os)
96+
97+
def ssh_with_result(self, cmd) -> commands.SSHResult:
9298
# doesn't raise if the command's return is nonzero, unless there's a SSH error
93-
return self.ssh(cmd, check=False, simple_output=False)
99+
return commands.ssh_with_result(self.ip, cmd)
94100

95101
def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
96102
# Stop execution if scp() is used on Windows VMs as some OpenSSH releases for Windows don't
@@ -235,7 +241,7 @@ def snapshot(self, ignore_vdis=None):
235241
args['ignore-vdi-uuids'] = ','.join(ignore_vdis)
236242
return Snapshot(self.host.xe('vm-snapshot', args), self.host)
237243

238-
def checkpoint(self):
244+
def checkpoint(self) -> Snapshot:
239245
logging.info("Checkpoint VM")
240246
return Snapshot(self.host.xe('vm-checkpoint', {'uuid': self.uuid,
241247
'new-name-label': 'Checkpoint of %s' % self.uuid}),
@@ -255,7 +261,7 @@ def get_residence_host(self):
255261
host_uuid = self.param_get('resident-on')
256262
return self.host.pool.get_host_by_uuid(host_uuid)
257263

258-
def start_background_process(self, cmd):
264+
def start_background_process(self, cmd) -> str:
259265
script = "/tmp/bg_process.sh"
260266
pidfile = "/tmp/bg_process.pid"
261267
with tempfile.NamedTemporaryFile('w') as f:
@@ -276,7 +282,7 @@ def start_background_process(self, cmd):
276282
self.ssh(['bash', script], background=True)
277283
wait_for(lambda: self.ssh_with_result(['test', '-f', pidfile]),
278284
"wait for pid file %s to exist" % pidfile)
279-
pid = self.ssh(['cat', pidfile])
285+
pid = self.ssh_str(['cat', pidfile])
280286
self.ssh(['rm', '-f', script])
281287
self.ssh(['rm', '-f', pidfile])
282288
return pid

tests/misc/test_basic_without_ssh.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import pytest
33

44
from lib.common import wait_for
5+
from lib.host import Host
6+
from lib.sr import SR
7+
from lib.vm import VM
58

69
# These tests are basic tests meant to be run to check that a VM performs
710
# well, without obvious issues.
@@ -19,14 +22,14 @@
1922
# Note however that an existing VM will be left on a different SR after the tests.
2023

2124
@pytest.fixture(scope='session')
22-
def existing_shared_sr(host):
25+
def existing_shared_sr(host: Host) -> SR:
2326
sr = host.pool.first_shared_sr()
2427
assert sr is not None, "A shared SR on the pool is required"
2528
return sr
2629

2730
@pytest.mark.multi_vms # run them on a variety of VMs
2831
@pytest.mark.big_vm # and also on a really big VM ideally
29-
def test_vm_start_stop(imported_vm):
32+
def test_vm_start_stop(imported_vm: VM):
3033
vm = imported_vm
3134
# if VM already running, stop it
3235
if (vm.is_running()):
@@ -43,19 +46,19 @@ def test_vm_start_stop(imported_vm):
4346
@pytest.mark.big_vm # and also on a really big VM ideally
4447
@pytest.mark.usefixtures("started_vm")
4548
class TestBasicNoSSH:
46-
def test_pause(self, imported_vm):
49+
def test_pause(self, imported_vm: VM):
4750
vm = imported_vm
4851
vm.pause(verify=True)
4952
vm.unpause()
5053
vm.wait_for_os_booted()
5154

52-
def test_suspend(self, imported_vm):
55+
def test_suspend(self, imported_vm: VM):
5356
vm = imported_vm
5457
vm.suspend(verify=True)
5558
vm.resume()
5659
vm.wait_for_os_booted()
5760

58-
def test_snapshot(self, imported_vm):
61+
def test_snapshot(self, imported_vm: VM):
5962
vm = imported_vm
6063
snapshot = vm.snapshot()
6164
try:
@@ -65,7 +68,7 @@ def test_snapshot(self, imported_vm):
6568
finally:
6669
snapshot.destroy(verify=True)
6770

68-
def test_checkpoint(self, imported_vm):
71+
def test_checkpoint(self, imported_vm: VM):
6972
vm = imported_vm
7073
snapshot = vm.checkpoint()
7174
try:
@@ -79,7 +82,7 @@ def test_checkpoint(self, imported_vm):
7982
# We want to test storage migration (memory+disks) and live migration without storage migration (memory only).
8083
# The order will depend on the initial location of the VM: a local SR or a shared SR.
8184
@pytest.mark.usefixtures("hostA2")
82-
def test_live_migrate(self, imported_vm, existing_shared_sr):
85+
def test_live_migrate(self, imported_vm: VM, existing_shared_sr: SR):
8386
def live_migrate(vm, dest_host, dest_sr, check_vdis=False):
8487
vm.migrate(dest_host, dest_sr)
8588
if check_vdis:

tests/misc/test_vm_basic_operations.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,30 @@
22
import pytest
33

44
from lib.common import wait_for_not
5+
from lib.vm import VM
56

67
@pytest.mark.multi_vms
78
class Test:
8-
def test_pause(self, running_vm):
9+
def test_pause(self, running_vm: VM):
910
vm = running_vm
1011
vm.pause(verify=True)
1112
vm.unpause()
1213
vm.wait_for_vm_running_and_ssh_up()
1314

14-
def test_suspend(self, running_vm):
15+
def test_suspend(self, running_vm: VM):
1516
vm = running_vm
1617
vm.suspend(verify=True)
1718
vm.resume()
1819
vm.wait_for_vm_running_and_ssh_up()
1920

20-
def test_snapshot(self, running_vm):
21+
def test_snapshot(self, running_vm: VM):
2122
vm = running_vm
2223
vm.test_snapshot_on_running_vm()
2324

2425
# When using a windows VM the background ssh process is never terminated
2526
# This results in a ResourceWarning
2627
@pytest.mark.filterwarnings("ignore::ResourceWarning")
27-
def test_checkpoint(self, running_vm):
28+
def test_checkpoint(self, running_vm: VM):
2829
vm = running_vm
2930
logging.info("Start a 'sleep' process on VM through SSH")
3031
pid = vm.start_background_process('sleep 10000')

0 commit comments

Comments
 (0)