Skip to content

Commit 2fe8fd3

Browse files
authored
Merge pull request #284 from xcp-ng/add-some-type-hints
add some type hints
2 parents cbfcd2e + 0765423 commit 2fe8fd3

File tree

14 files changed

+291
-98
lines changed

14 files changed

+291
-98
lines changed

.github/workflows/jobs-check.yml

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,21 @@ jobs:
3737
run: cp data.py-dist data.py
3838
- name: Check with mypy
3939
run: mypy lib/ tests/
40+
41+
pyright:
42+
runs-on: ubuntu-latest
43+
steps:
44+
- uses: actions/checkout@v2
45+
- name: Set up Python
46+
uses: actions/setup-python@v4
47+
with:
48+
python-version: 3.8
49+
- name: Install dependencies
50+
run: |
51+
python -m pip install --upgrade pip
52+
pip install -r requirements/base.txt
53+
pip install pyright
54+
- name: Create a dummy data.py
55+
run: cp data.py-dist data.py
56+
- name: Check with pyright
57+
run: pyright lib/ # tests/

data.py-dist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ HOST_DEFAULT_PASSWORD = ""
1212
# that the tests will create or import, such as VMs and SRs.
1313
# Default value: [your login/user]
1414
# OBJECTS_NAME_PREFIX = "[TEST]"
15+
OBJECTS_NAME_PREFIX = None
1516

1617
# Override settings for specific hosts
1718
# skip_xo_config allows to not touch XO's configuration regarding the host

lib/basevm.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import logging
22

3+
from typing import Any, Optional, TYPE_CHECKING, Union
4+
35
import lib.commands as commands
6+
if TYPE_CHECKING:
7+
import lib.host
48

59
from lib.common import _param_add, _param_clear, _param_get, _param_remove, _param_set
610
from lib.sr import SR
@@ -10,33 +14,40 @@ class BaseVM:
1014

1115
xe_prefix = "vm"
1216

13-
def __init__(self, uuid, host):
17+
def __init__(self, uuid, host: 'lib.host.Host'):
1418
logging.info("New %s: %s", type(self).__name__, uuid)
1519
self.uuid = uuid
1620
self.host = host
1721

18-
def param_get(self, param_name, key=None, accept_unknown_key=False):
22+
def param_get(self, param_name: str, key: Optional[str] = None,
23+
accept_unknown_key: bool = False) -> Union[str, bool, None]:
1924
return _param_get(self.host, self.xe_prefix, self.uuid,
2025
param_name, key, accept_unknown_key)
2126

22-
def param_set(self, param_name, value, key=None):
27+
def param_set(self, param_name: str, value: Any, key: Optional[str] = None) -> None:
2328
_param_set(self.host, self.xe_prefix, self.uuid,
2429
param_name, value, key)
2530

26-
def param_remove(self, param_name, key, accept_unknown_key=False):
31+
def param_remove(self, param_name: str, key: str, accept_unknown_key=False) -> None:
2732
_param_remove(self.host, self.xe_prefix, self.uuid,
2833
param_name, key, accept_unknown_key)
2934

30-
def param_add(self, param_name, value, key=None):
35+
def param_add(self, param_name: str, value: str, key=None) -> None:
3136
_param_add(self.host, self.xe_prefix, self.uuid,
3237
param_name, value, key)
3338

34-
def param_clear(self, param_name):
39+
def param_clear(self, param_name: str) -> None:
3540
_param_clear(self.host, self.xe_prefix, self.uuid,
3641
param_name)
3742

38-
def name(self):
39-
return self.param_get('name-label')
43+
def name(self) -> str:
44+
n = self.param_get('name-label')
45+
assert isinstance(n, str)
46+
return n
47+
48+
# @abstractmethod
49+
def _disk_list(self):
50+
raise NotImplementedError()
4051

4152
def vdi_uuids(self, sr_uuid=None):
4253
output = self._disk_list()
@@ -54,23 +65,17 @@ def vdi_uuids(self, sr_uuid=None):
5465
vdis_on_sr.append(vdi)
5566
return vdis_on_sr
5667

57-
def destroy_vdi(self, vdi_uuid):
68+
def destroy_vdi(self, vdi_uuid: str) -> None:
5869
self.host.xe('vdi-destroy', {'uuid': vdi_uuid})
5970

60-
# FIXME: move this method and the above back to class VM if not useful in Snapshot class?
61-
def destroy(self):
62-
for vdi_uuid in self.vdi_uuids():
63-
self.destroy_vdi(vdi_uuid)
64-
self._destroy()
65-
6671
def all_vdis_on_host(self, host):
6772
for vdi_uuid in self.vdi_uuids():
6873
sr = SR(self.host.pool.get_vdi_sr_uuid(vdi_uuid), self.host.pool)
6974
if not sr.attached_to_host(host):
7075
return False
7176
return True
7277

73-
def all_vdis_on_sr(self, sr):
78+
def all_vdis_on_sr(self, sr) -> bool:
7479
for vdi_uuid in self.vdi_uuids():
7580
if self.host.pool.get_vdi_sr_uuid(vdi_uuid) != sr.uuid:
7681
return False
@@ -84,7 +89,7 @@ def get_sr(self):
8489
assert sr.attached_to_host(self.host)
8590
return sr
8691

87-
def export(self, filepath, compress='none'):
92+
def export(self, filepath, compress='none') -> None:
8893
logging.info("Export VM %s to %s with compress=%s" % (self.uuid, filepath, compress))
8994
params = {
9095
'uuid': self.uuid,

lib/commands.py

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import logging
33
import shlex
44
import subprocess
5+
from typing import List, Literal, overload, Union
56

6-
import lib.config as config
77

8+
import lib.config as config
89
from lib.netutil import wrap_ip
910

11+
1012
class BaseCommandFailed(Exception):
1113
__slots__ = 'returncode', 'stdout', 'cmd'
1214

@@ -61,7 +63,7 @@ def _ellide_log_lines(log):
6163
return "\n{}".format("\n".join(reduced_message))
6264

6365
def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
64-
background, decode, options):
66+
background, decode, options) -> Union[SSHResult, SSHCommandFailed, str, bytes, None]:
6567
opts = list(options)
6668
opts.append('-o "BatchMode yes"')
6769
if suppress_fingerprint_warnings:
@@ -80,6 +82,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
8082
ssh_cmd = f"ssh root@{hostname_or_ip} {' '.join(opts)} {shlex.quote(command)}"
8183

8284
# Fetch banner and remove it to avoid stdout/stderr pollution.
85+
banner_res = None
8386
if config.ignore_ssh_banner:
8487
banner_res = subprocess.run(
8588
"ssh root@%s %s '%s'" % (hostname_or_ip, ' '.join(opts), '\n'),
@@ -97,9 +100,10 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
97100
stderr=subprocess.STDOUT
98101
)
99102
if background:
100-
return True, None
103+
return None
101104

102105
stdout = []
106+
assert process.stdout is not None
103107
for line in iter(process.stdout.readline, b''):
104108
readable_line = line.decode(errors='replace').strip()
105109
stdout.append(line)
@@ -112,34 +116,73 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
112116

113117
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
114118
if res.returncode == 255:
115-
return False, SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
119+
return SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
116120

117-
output = res.stdout
118-
if config.ignore_ssh_banner:
121+
output: Union[bytes, str] = res.stdout
122+
if banner_res:
119123
if banner_res.returncode == 255:
120-
return False, SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
124+
return SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
121125
output = output[len(banner_res.stdout):]
122126

123127
if decode:
128+
assert isinstance(output, bytes)
124129
output = output.decode()
125130

126131
if res.returncode and check:
127-
return False, SSHCommandFailed(res.returncode, output_for_errors, command)
132+
return SSHCommandFailed(res.returncode, output_for_errors, command)
128133

129134
if simple_output:
130-
return True, output.strip()
131-
return True, SSHResult(res.returncode, output)
135+
return output.strip()
136+
return SSHResult(res.returncode, output)
132137

133138
# The actual code is in _ssh().
134139
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
135140
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
136-
def ssh(hostname_or_ip, cmd, check=True, simple_output=True, suppress_fingerprint_warnings=True,
141+
@overload
142+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
143+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
144+
decode: Literal[True] = True, options: List[str] = []) -> str:
145+
...
146+
@overload
147+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[True] = True,
148+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
149+
decode: Literal[False], options: List[str] = []) -> bytes:
150+
...
151+
@overload
152+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
153+
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
154+
decode: bool = True, options: List[str] = []) -> SSHResult:
155+
...
156+
@overload
157+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check: bool = True, simple_output: Literal[False],
158+
suppress_fingerprint_warnings: bool = True, background: Literal[True],
159+
decode: bool = True, options: List[str] = []) -> None:
160+
...
161+
@overload
162+
def ssh(hostname_or_ip: str, cmd: Union[str, List[str]], *, check=True, simple_output: bool = True,
163+
suppress_fingerprint_warnings=True, background: bool = False,
164+
decode: bool = True, options: List[str] = []) \
165+
-> Union[str, bytes, SSHResult, None]:
166+
...
167+
def ssh(hostname_or_ip, cmd, *, check=True, simple_output=True,
168+
suppress_fingerprint_warnings=True,
137169
background=False, decode=True, options=[]):
138-
success, result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
139-
background, decode, options)
140-
if not success:
170+
result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
171+
background, decode, options)
172+
if isinstance(result_or_exc, SSHCommandFailed):
173+
raise result_or_exc
174+
else:
175+
return result_or_exc
176+
177+
def ssh_with_result(hostname_or_ip, cmd, suppress_fingerprint_warnings=True,
178+
background=False, decode=True, options=[]) -> SSHResult:
179+
result_or_exc = _ssh(hostname_or_ip, cmd, False, False, suppress_fingerprint_warnings,
180+
background, decode, options)
181+
if isinstance(result_or_exc, SSHCommandFailed):
141182
raise result_or_exc
142-
return result_or_exc
183+
elif isinstance(result_or_exc, SSHResult):
184+
return result_or_exc
185+
assert False, "unexpected type"
143186

144187
def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=True, local_dest=False):
145188
opts = '-o "BatchMode yes"'
@@ -173,6 +216,7 @@ def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=Tru
173216
return res
174217

175218
def sftp(hostname_or_ip, cmds, check=True, suppress_fingerprint_warnings=True):
219+
opts = ''
176220
if suppress_fingerprint_warnings:
177221
# Suppress warnings and questions related to host key fingerprints
178222
# because on a test network IPs get reused, VMs are reinstalled, etc.

lib/common.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
import time
66
import traceback
77
from enum import Enum
8+
from typing import TYPE_CHECKING, Dict, Optional, Union
89
from uuid import UUID
910

1011
import lib.commands as commands
12+
if TYPE_CHECKING:
13+
import lib.host
1114

1215
class PackageManagerEnum(Enum):
1316
UNKNOWN = 1
@@ -23,10 +26,13 @@ def vm_image(vm_key):
2326
return url
2427

2528
def prefix_object_name(label):
29+
name_prefix = None
2630
try:
2731
from data import OBJECTS_NAME_PREFIX
2832
name_prefix = OBJECTS_NAME_PREFIX
2933
except ImportError:
34+
pass
35+
if name_prefix is None:
3036
name_prefix = f"[{getpass.getuser()}]"
3137
return f"{name_prefix} {label}"
3238

@@ -154,9 +160,10 @@ def strtobool(str):
154160
return False
155161
raise ValueError("invalid truth value '{}'".format(str))
156162

157-
def _param_get(host, xe_prefix, uuid, param_name, key=None, accept_unknown_key=False):
163+
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = None,
164+
accept_unknown_key=False) -> Union[str, bool, None]:
158165
""" Common implementation for param_get. """
159-
args = {'uuid': uuid, 'param-name': param_name}
166+
args: Dict[str, Union[str, bool]] = {'uuid': uuid, 'param-name': param_name}
160167
if key is not None:
161168
args['param-key'] = key
162169
try:

lib/efi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def sign(payload, key_file, cert_file):
218218
"""Returns a signed PKCS7 of payload signed by key and cert."""
219219
with open(key_file, 'rb') as f:
220220
priv_key = serialization.load_pem_private_key(f.read(), password=None)
221+
assert isinstance(priv_key, (pkcs7.PKCS7PrivateKeyTypes))
221222

222223
with open(cert_file, 'rb') as f:
223224
cert = x509.load_pem_x509_certificate(f.read())

0 commit comments

Comments
 (0)