Skip to content

Commit 037a329

Browse files
committed
Adapt return types of xe and param_get
Simplify the return type of `xe`, remove bool being a implicit cast for `xe`. Adapt some call that were expecting a bool to use `strtobool` explicitely instead. Adapted `_param_get` to new types for xe. Signed-off-by: Damien Thenot <[email protected]>
1 parent af630cf commit 037a329

File tree

9 files changed

+87
-40
lines changed

9 files changed

+87
-40
lines changed

lib/basevm.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from typing import Any, Optional, TYPE_CHECKING, Union
3+
from typing import Any, Literal, Optional, overload, TYPE_CHECKING
44

55
import lib.commands as commands
66
if TYPE_CHECKING:
@@ -13,14 +13,25 @@ class BaseVM:
1313
""" Base class for VM and Snapshot. """
1414

1515
xe_prefix = "vm"
16+
uuid: str
1617

17-
def __init__(self, uuid, host: 'lib.host.Host'):
18+
def __init__(self, uuid: str, host: 'lib.host.Host'):
1819
logging.info("New %s: %s", type(self).__name__, uuid)
1920
self.uuid = uuid
2021
self.host = host
2122

23+
@overload
24+
def param_get(self, param_name: str, key: Optional[str] = ...,
25+
accept_unknown_key: Literal[False] = ...) -> str:
26+
...
27+
28+
@overload
29+
def param_get(self, param_name: str, key: Optional[str] = ...,
30+
accept_unknown_key: Literal[True] = ...) -> Optional[str]:
31+
...
32+
2233
def param_get(self, param_name: str, key: Optional[str] = None,
23-
accept_unknown_key: bool = False) -> Union[str, bool, None]:
34+
accept_unknown_key: bool = False) -> Optional[str]:
2435
return _param_get(self.host, self.xe_prefix, self.uuid,
2536
param_name, key, accept_unknown_key)
2637

lib/common.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import traceback
77
from enum import Enum
8-
from typing import TYPE_CHECKING, Dict, Optional, Union
8+
from typing import Dict, Literal, Optional, overload, TYPE_CHECKING, Union
99
from uuid import UUID
1010

1111
import lib.commands as commands
@@ -63,7 +63,7 @@ def is_uuid(maybe_uuid):
6363
except ValueError:
6464
return False
6565

66-
def to_xapi_bool(b):
66+
def to_xapi_bool(b: bool):
6767
return 'true' if b else 'false'
6868

6969
def parse_xe_dict(xe_dict):
@@ -160,8 +160,23 @@ def strtobool(str):
160160
return False
161161
raise ValueError("invalid truth value '{}'".format(str))
162162

163+
@overload
164+
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
165+
accept_unknown_key: Literal[False] = ...) -> str:
166+
...
167+
168+
@overload
169+
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
170+
accept_unknown_key: Literal[True] = ...) -> Optional[str]:
171+
...
172+
173+
@overload
174+
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
175+
accept_unknown_key: bool = ...) -> Optional[str]:
176+
...
177+
163178
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = None,
164-
accept_unknown_key=False) -> Union[str, bool, None]:
179+
accept_unknown_key: bool = False) -> Optional[str]:
165180
""" Common implementation for param_get. """
166181
args: Dict[str, Union[str, bool]] = {'uuid': uuid, 'param-name': param_name}
167182
if key is not None:

lib/host.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
import uuid
99

1010
from packaging import version
11-
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union, overload
11+
from typing import Dict, List, Literal, Optional, overload, TYPE_CHECKING, Union
1212

1313
import lib.commands as commands
1414
import lib.pif as pif
1515
if TYPE_CHECKING:
1616
import lib.pool
1717

18-
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
18+
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set, strtobool
1919
from lib.common import safe_split, strip_suffix, to_xapi_bool, wait_for, wait_for_not
2020
from lib.common import prefix_object_name
2121
from lib.netutil import wrap_ip
@@ -106,17 +106,17 @@ def scp(self, src, dest, check=True, suppress_fingerprint_warnings=True, local_d
106106
)
107107

108108
@overload
109-
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = True,
110-
simple_output: Literal[True] = True, minimal: bool = False, force: bool = False) -> Union[bool, str]:
109+
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = ...,
110+
simple_output: Literal[True] = ..., minimal: bool = ..., force: bool = ...) -> str:
111111
...
112112

113113
@overload
114-
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = True,
115-
simple_output: Literal[False], minimal: bool = False, force: bool = False) -> commands.SSHResult:
114+
def xe(self, action: str, args: Dict[str, Union[str, bool]] = {}, *, check: bool = ...,
115+
simple_output: Literal[False], minimal: bool = ..., force: bool = ...) -> commands.SSHResult:
116116
...
117117

118118
def xe(self, action, args={}, *, check=True, simple_output=True, minimal=False, force=False) \
119-
-> Union[bool, str, commands.SSHResult]:
119+
-> Union[str, commands.SSHResult]:
120120
maybe_param_minimal = ['--minimal'] if minimal else []
121121
maybe_param_force = ['--force'] if force else []
122122

@@ -139,13 +139,19 @@ def stringify(key, value):
139139
)
140140
assert isinstance(result, (str, commands.SSHResult))
141141

142-
if result == 'true':
143-
return True
144-
if result == 'false':
145-
return False
146142
return result
147143

148-
def param_get(self, param_name, key=None, accept_unknown_key=False):
144+
@overload
145+
def param_get(self, param_name: str, key: Optional[str] = ...,
146+
accept_unknown_key: Literal[False] = ...) -> str:
147+
...
148+
149+
@overload
150+
def param_get(self, param_name: str, key: Optional[str] = ...,
151+
accept_unknown_key: Literal[True] = ...) -> Optional[str]:
152+
...
153+
154+
def param_get(self, param_name: str, key: Optional[str] = None, accept_unknown_key: bool = False) -> Optional[str]:
149155
return _param_get(self, self.xe_prefix, self.uuid,
150156
param_name, key, accept_unknown_key)
151157

@@ -330,7 +336,7 @@ def import_iso(self, uri, sr: SR):
330336

331337
download_path = None
332338
try:
333-
params = {'uuid': vdi_uuid}
339+
params: Dict[str, Union[str, bool]] = {'uuid': vdi_uuid}
334340
if '://' in uri:
335341
logging.info(f"Download ISO {uri}")
336342
download_path = f'/tmp/{vdi_uuid}'
@@ -363,9 +369,9 @@ def restart_toolstack(self, verify=False):
363369
if verify:
364370
wait_for(self.is_enabled, "Wait for host enabled", timeout_secs=30 * 60)
365371

366-
def is_enabled(self):
372+
def is_enabled(self) -> bool:
367373
try:
368-
return self.param_get('enabled')
374+
return strtobool(self.param_get('enabled'))
369375
except commands.SSHCommandFailed:
370376
# If XAPI is not ready yet, or the host is down, this will throw. We return False in that case.
371377
return False
@@ -621,9 +627,9 @@ def join_pool(self, pool):
621627
f"Wait for joining host {self} to appear in joined pool {master}."
622628
)
623629
pool.hosts.append(Host(pool, pool.host_ip(self.uuid)))
624-
# Do not use `self.is_enabled` since it'd ask the XAPi of hostB1 before the join...
630+
# Do not use `self.is_enabled` since it'd ask the XAPI of hostB1 before the join...
625631
wait_for(
626-
lambda: master.xe('host-param-get', {'uuid': self.uuid, 'param-name': 'enabled'}),
632+
lambda: strtobool(master.xe('host-param-get', {'uuid': self.uuid, 'param-name': 'enabled'})),
627633
f"Wait for pool {master} to see joined host {self} as enabled."
628634
)
629635
self.pool = pool

lib/pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, master_hostname_or_ip: str) -> None:
3030
if host_uuid != self.hosts[0].uuid:
3131
host = Host(self, self.host_ip(host_uuid))
3232
self.hosts.append(host)
33-
self.uuid = cast(str, self.master.xe('pool-list', minimal=True))
33+
self.uuid = self.master.xe('pool-list', minimal=True)
3434
self.saved_uefi_certs: Optional[Dict[str, Any]] = None
3535

3636
def param_get(self, param_name, key=None, accept_unknown_key=False):

lib/sr.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import lib.commands as commands
55

6-
from lib.common import prefix_object_name, safe_split, wait_for, wait_for_not
6+
from lib.common import prefix_object_name, safe_split, strtobool, wait_for, wait_for_not
77
from lib.vdi import VDI
88

99
class SR:
@@ -41,8 +41,10 @@ def unplug_pbds(self, force=False):
4141
def all_pbds_attached(self):
4242
all_attached = True
4343
for pbd_uuid in self.pbd_uuids():
44-
all_attached = all_attached and self.pool.master.xe('pbd-param-get', {'uuid': pbd_uuid,
45-
'param-name': 'currently-attached'})
44+
all_attached = all_attached and strtobool(self.pool.master.xe('pbd-param-get',
45+
{'uuid': pbd_uuid,
46+
'param-name': 'currently-attached',
47+
}))
4648
return all_attached
4749

4850
def plug_pbd(self, pbd_uuid):
@@ -147,7 +149,8 @@ def content_type(self):
147149

148150
def is_shared(self):
149151
if self._is_shared is None:
150-
self._is_shared = self.pool.master.xe('sr-param-get', {'uuid': self.uuid, 'param-name': 'shared'})
152+
self._is_shared = strtobool(self.pool.master.xe('sr-param-get',
153+
{'uuid': self.uuid, 'param-name': 'shared'}))
151154
return self._is_shared
152155

153156
def create_vdi(self, name_label, virtual_size=64):

lib/vdi.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
3+
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set, strtobool
44
from typing import Literal, Optional, overload, TYPE_CHECKING
55
if TYPE_CHECKING:
66
from lib.host import Host
@@ -31,7 +31,7 @@ def __init__(self, uuid, *, host=None, sr=None):
3131
else:
3232
self.sr = sr
3333

34-
def name(self):
34+
def name(self) -> str:
3535
return self.param_get('name-label')
3636

3737
def destroy(self):
@@ -42,13 +42,24 @@ def clone(self):
4242
uuid = self.sr.pool.master.xe('vdi-clone', {'uuid': self.uuid})
4343
return VDI(uuid, sr=self.sr)
4444

45-
def readonly(self):
46-
return self.param_get("read-only") == "true"
45+
def readonly(self) -> bool:
46+
return strtobool(self.param_get("read-only"))
4747

4848
def __str__(self):
4949
return f"VDI {self.uuid} on SR {self.sr.uuid}"
5050

51-
def param_get(self, param_name, key=None, accept_unknown_key=False):
51+
@overload
52+
def param_get(self, param_name: str, key: Optional[str] = ...,
53+
accept_unknown_key: Literal[False] = ...) -> str:
54+
...
55+
56+
@overload
57+
def param_get(self, param_name: str, key: Optional[str] = ...,
58+
accept_unknown_key: Literal[True] = ...) -> Optional[str]:
59+
...
60+
61+
def param_get(self, param_name: str, key: Optional[str] = None,
62+
accept_unknown_key: bool = False) -> Optional[str]:
5263
return _param_get(self.sr.pool.master, self.xe_prefix, self.uuid,
5364
param_name, key, accept_unknown_key)
5465

lib/vm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import subprocess
66
import tempfile
7-
from typing import List, Literal, Optional, Union, cast, overload
7+
from typing import Dict, List, Literal, Optional, overload, Union
88

99
import lib.commands as commands
1010
import lib.efi as efi
@@ -22,7 +22,7 @@ def __init__(self, uuid, host):
2222
self.is_windows = self.param_get('platform', 'device_id', accept_unknown_key=True) == '0002'
2323
self.is_uefi = self.param_get('HVM-boot-params', 'firmware', accept_unknown_key=True) == 'uefi'
2424

25-
def power_state(self):
25+
def power_state(self) -> str:
2626
return self.param_get('power-state')
2727

2828
def is_running(self):
@@ -74,7 +74,7 @@ def reboot(self, force=False, verify=False):
7474
return ret
7575

7676
def try_get_and_store_ip(self):
77-
ip = cast(str, self.param_get('networks', '0/ip', accept_unknown_key=True))
77+
ip = self.param_get('networks', '0/ip', accept_unknown_key=True)
7878

7979
# An IP that starts with 169.254. is not a real routable IP.
8080
# VMs may return such an IP before they get an actual one from DHCP.
@@ -709,7 +709,7 @@ def are_windows_drivers_present(self):
709709

710710
def are_windows_tools_working(self):
711711
assert self.is_windows
712-
return self.is_windows_pv_device_installed() and self.param_get("PV-drivers-detected")
712+
return self.is_windows_pv_device_installed() and strtobool(self.param_get("PV-drivers-detected"))
713713

714714
def are_windows_tools_uninstalled(self):
715715
assert self.is_windows

tests/guest_tools/win/other_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from pathlib import PureWindowsPath
33
from typing import Any, Dict
44

5-
from lib.common import wait_for
5+
from lib.common import strtobool, wait_for
66
from lib.vm import VM
77
from . import WINDOWS_SHUTDOWN_COMMAND, enable_testsign, insert_cd_safe, wait_for_vm_running_and_ssh_up_without_tools
88

99

1010
def install_other_drivers(vm: VM, other_tools_iso_name: str, param: Dict[str, Any]):
1111
if param.get("vendor_device"):
12-
assert not vm.param_get("has-vendor-device")
12+
assert not strtobool(vm.param_get("has-vendor-device"))
1313
vm.param_set("has-vendor-device", True)
1414

1515
vm.start()

tests/xapi/tls_verification/test_tls_verification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytest
33

44
from lib.commands import SSHCommandFailed
5+
from lib.common import strtobool
56

67
# Requirements:
78
# From --hosts parameter:
@@ -21,7 +22,7 @@
2122
def host_with_tls_verification_enabled(hostA1):
2223
for h in hostA1.pool.hosts:
2324
logging.info(f"Check that TLS verification is enabled on host {h}")
24-
assert h.param_get("tls-verification-enabled"), f"TLS verification must be enabled on host {h}"
25+
assert strtobool(h.param_get("tls-verification-enabled")), f"TLS verification must be enabled on host {h}"
2526
logging.info(f"Check that the host certificate exists on host {h}")
2627
cert_uuid = hostA1.xe('certificate-list', {'host': h.uuid, 'type': 'host_internal'}, minimal=True)
2728
assert len(cert_uuid) > 0, f"A host_internal certificate must exist on host {h}"

0 commit comments

Comments
 (0)