@@ -148,6 +148,45 @@ def format_annotation(annotation,
148
148
role = role , prefix = prefix , full_name = full_name , formatted_args = formatted_args )
149
149
150
150
151
+ # reference: https://github.com/pytorch/pytorch/pull/46548/files
152
+ def normalize_source_lines (sourcelines : str ) -> str :
153
+ """
154
+ This helper function accepts a list of source lines. It finds the
155
+ indentation level of the function definition (`def`), then it indents
156
+ all lines in the function body to a point at or greater than that
157
+ level. This allows for comments and continued string literals that
158
+ are at a lower indentation than the rest of the code.
159
+ Arguments:
160
+ sourcelines: source code
161
+ Returns:
162
+ source lines that have been correctly aligned
163
+ """
164
+ sourcelines = sourcelines .split ("\n " )
165
+
166
+ def remove_prefix (text , prefix ):
167
+ return text [text .startswith (prefix ) and len (prefix ):]
168
+
169
+ # Find the line and line number containing the function definition
170
+ for i , l in enumerate (sourcelines ):
171
+ if l .lstrip ().startswith ("def" ):
172
+ idx = i
173
+ break
174
+ else :
175
+ return "\n " .join (sourcelines )
176
+ fn_def = sourcelines [idx ]
177
+
178
+ # Get a string representing the amount of leading whitespace
179
+ whitespace = fn_def .split ("def" )[0 ]
180
+
181
+ # Add this leading whitespace to all lines before and after the `def`
182
+ aligned_prefix = [whitespace + remove_prefix (s , whitespace ) for s in sourcelines [:idx ]]
183
+ aligned_suffix = [whitespace + remove_prefix (s , whitespace ) for s in sourcelines [idx + 1 :]]
184
+
185
+ # Put it together again
186
+ aligned_prefix .append (fn_def )
187
+ return "\n " .join (aligned_prefix + aligned_suffix )
188
+
189
+
151
190
def process_signature (app , what : str , name : str , obj , options , signature , return_annotation ):
152
191
if not callable (obj ):
153
192
return
@@ -270,7 +309,8 @@ def _one_child(module):
270
309
return children [0 ]
271
310
272
311
try :
273
- obj_ast = ast .parse (textwrap .dedent (inspect .getsource (obj )), ** parse_kwargs )
312
+ obj_ast = ast .parse (textwrap .dedent (
313
+ normalize_source_lines (inspect .getsource (obj ))), ** parse_kwargs )
274
314
except (OSError , TypeError ):
275
315
return {}
276
316
0 commit comments