2
2
import logging
3
3
import shlex
4
4
import subprocess
5
+ from typing import List , Literal , overload , Union
5
6
6
- import lib .config as config
7
7
8
+ import lib .config as config
8
9
from lib .netutil import wrap_ip
9
10
11
+
10
12
class BaseCommandFailed (Exception ):
11
13
__slots__ = 'returncode' , 'stdout' , 'cmd'
12
14
@@ -61,7 +63,7 @@ def _ellide_log_lines(log):
61
63
return "\n {}" .format ("\n " .join (reduced_message ))
62
64
63
65
def _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
64
- background , decode , options ):
66
+ background , decode , options ) -> Union [ SSHResult , SSHCommandFailed , str , bytes , None ] :
65
67
opts = list (options )
66
68
opts .append ('-o "BatchMode yes"' )
67
69
if suppress_fingerprint_warnings :
@@ -80,6 +82,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
80
82
ssh_cmd = f"ssh root@{ hostname_or_ip } { ' ' .join (opts )} { shlex .quote (command )} "
81
83
82
84
# Fetch banner and remove it to avoid stdout/stderr pollution.
85
+ banner_res = None
83
86
if config .ignore_ssh_banner :
84
87
banner_res = subprocess .run (
85
88
"ssh root@%s %s '%s'" % (hostname_or_ip , ' ' .join (opts ), '\n ' ),
@@ -97,9 +100,10 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
97
100
stderr = subprocess .STDOUT
98
101
)
99
102
if background :
100
- return True , None
103
+ return None
101
104
102
105
stdout = []
106
+ assert process .stdout is not None
103
107
for line in iter (process .stdout .readline , b'' ):
104
108
readable_line = line .decode (errors = 'replace' ).strip ()
105
109
stdout .append (line )
@@ -112,34 +116,73 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
112
116
113
117
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
114
118
if res .returncode == 255 :
115
- return False , SSHCommandFailed (255 , "SSH Error: %s" % output_for_errors , command )
119
+ return SSHCommandFailed (255 , "SSH Error: %s" % output_for_errors , command )
116
120
117
- output = res .stdout
118
- if config . ignore_ssh_banner :
121
+ output : Union [ bytes , str ] = res .stdout
122
+ if banner_res :
119
123
if banner_res .returncode == 255 :
120
- return False , SSHCommandFailed (255 , "SSH Error: %s" % banner_res .stdout .decode (errors = 'replace' ), command )
124
+ return SSHCommandFailed (255 , "SSH Error: %s" % banner_res .stdout .decode (errors = 'replace' ), command )
121
125
output = output [len (banner_res .stdout ):]
122
126
123
127
if decode :
128
+ assert isinstance (output , bytes )
124
129
output = output .decode ()
125
130
126
131
if res .returncode and check :
127
- return False , SSHCommandFailed (res .returncode , output_for_errors , command )
132
+ return SSHCommandFailed (res .returncode , output_for_errors , command )
128
133
129
134
if simple_output :
130
- return True , output .strip ()
131
- return True , SSHResult (res .returncode , output )
135
+ return output .strip ()
136
+ return SSHResult (res .returncode , output )
132
137
133
138
# The actual code is in _ssh().
134
139
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
135
140
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
136
- def ssh (hostname_or_ip , cmd , check = True , simple_output = True , suppress_fingerprint_warnings = True ,
141
+ @overload
142
+ def ssh (hostname_or_ip : str , cmd : Union [str , List [str ]], * , check : bool = True , simple_output : Literal [True ] = True ,
143
+ suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
144
+ decode : Literal [True ] = True , options : List [str ] = []) -> str :
145
+ ...
146
+ @overload
147
+ def ssh (hostname_or_ip : str , cmd : Union [str , List [str ]], * , check : bool = True , simple_output : Literal [True ] = True ,
148
+ suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
149
+ decode : Literal [False ], options : List [str ] = []) -> bytes :
150
+ ...
151
+ @overload
152
+ def ssh (hostname_or_ip : str , cmd : Union [str , List [str ]], * , check : bool = True , simple_output : Literal [False ],
153
+ suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
154
+ decode : bool = True , options : List [str ] = []) -> SSHResult :
155
+ ...
156
+ @overload
157
+ def ssh (hostname_or_ip : str , cmd : Union [str , List [str ]], * , check : bool = True , simple_output : Literal [False ],
158
+ suppress_fingerprint_warnings : bool = True , background : Literal [True ],
159
+ decode : bool = True , options : List [str ] = []) -> None :
160
+ ...
161
+ @overload
162
+ def ssh (hostname_or_ip : str , cmd : Union [str , List [str ]], * , check = True , simple_output : bool = True ,
163
+ suppress_fingerprint_warnings = True , background : bool = False ,
164
+ decode : bool = True , options : List [str ] = []) \
165
+ -> Union [str , bytes , SSHResult , None ]:
166
+ ...
167
+ def ssh (hostname_or_ip , cmd , * , check = True , simple_output = True ,
168
+ suppress_fingerprint_warnings = True ,
137
169
background = False , decode = True , options = []):
138
- success , result_or_exc = _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
139
- background , decode , options )
140
- if not success :
170
+ result_or_exc = _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
171
+ background , decode , options )
172
+ if isinstance (result_or_exc , SSHCommandFailed ):
173
+ raise result_or_exc
174
+ else :
175
+ return result_or_exc
176
+
177
+ def ssh_with_result (hostname_or_ip , cmd , suppress_fingerprint_warnings = True ,
178
+ background = False , decode = True , options = []) -> SSHResult :
179
+ result_or_exc = _ssh (hostname_or_ip , cmd , False , False , suppress_fingerprint_warnings ,
180
+ background , decode , options )
181
+ if isinstance (result_or_exc , SSHCommandFailed ):
141
182
raise result_or_exc
142
- return result_or_exc
183
+ elif isinstance (result_or_exc , SSHResult ):
184
+ return result_or_exc
185
+ assert False , "unexpected type"
143
186
144
187
def scp (hostname_or_ip , src , dest , check = True , suppress_fingerprint_warnings = True , local_dest = False ):
145
188
opts = '-o "BatchMode yes"'
@@ -173,6 +216,7 @@ def scp(hostname_or_ip, src, dest, check=True, suppress_fingerprint_warnings=Tru
173
216
return res
174
217
175
218
def sftp (hostname_or_ip , cmds , check = True , suppress_fingerprint_warnings = True ):
219
+ opts = ''
176
220
if suppress_fingerprint_warnings :
177
221
# Suppress warnings and questions related to host key fingerprints
178
222
# because on a test network IPs get reused, VMs are reinstalled, etc.
0 commit comments