@@ -117,9 +117,13 @@ def _check(
117117 """
118118 if cls is not None :
119119 if not cond :
120- raise cls (f"{ msg } \n \n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} " )
120+ smsg = msg if isinstance (msg , str ) else msg ()
121+ raise cls (f"{ smsg } \n \n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} " )
121122 return
122- assert cond , f"{ msg } \n \n --\n { ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
123+ assert cond , (
124+ f"{ msg if isinstance (msg , str ) else msg ()} \n \n --\n "
125+ f"{ ast .unparse (node )} \n --\n { ast .dump (node , indent = 2 )} "
126+ )
123127
124128 def visit_Name (self , node ):
125129 node = self .generic_visit (node )
@@ -362,18 +366,30 @@ def _find_loop_vars(self, node):
362366 assert isinstance (node , ast .For ), f"Unexpected type { type (node )} for node"
363367 finder = ShapeFinder ()
364368 finder .visit (node .iter )
365- scan_vars = finder .found_shape
369+ scan_shape_vars = finder .found_shape
370+ scan_vars = set ()
366371
367372 finder = UsedVarsFinder ()
368373 for stmt in node .body :
369374 finder .visit (stmt )
370375
376+ assigned_in_body = set ()
377+ for stmt in node .body :
378+ if isinstance (stmt , ast .Assign ):
379+ for tgt in stmt .targets :
380+ if isinstance (tgt , ast .Name ) and isinstance (tgt .value .ctx , ast .Store ):
381+ assigned_in_body |= {tgt .value .id }
382+
371383 extra_defined = set ()
372384 for stmt in node .body :
373385 if isinstance (stmt , ast .Assign ):
374386 for tgt in stmt .targets :
375387 if isinstance (tgt , ast .Subscript ):
376- if isinstance (tgt .value , ast .Name ):
388+ # It means the target existed before.
389+ if (
390+ isinstance (tgt .value , ast .Name )
391+ and tgt .value .id not in assigned_in_body
392+ ):
377393 extra_defined .add (tgt .value .id )
378394
379395 loop_vars = set ()
@@ -382,11 +398,21 @@ def _find_loop_vars(self, node):
382398 elif isinstance (node .target , (ast .Tuple , ast .List )):
383399 loop_vars |= {elt .id for elt in node .target .elts if isinstance (elt , ast .Name )}
384400
385- output_vars = finder .defined | extra_defined
386- input_vars = finder .used - finder .defined - loop_vars - scan_vars - output_vars
401+ output_vars = finder .defined | assigned_in_body
402+ input_vars = (
403+ finder .used
404+ - finder .defined
405+ - loop_vars
406+ - scan_shape_vars
407+ - scan_vars
408+ - output_vars
409+ - assigned_in_body
410+ - extra_defined
411+ )
387412 return dict (
388- init = [] ,
413+ init = sorted ( extra_defined ) ,
389414 loop = sorted (loop_vars ),
415+ scan_shape = sorted (scan_shape_vars ),
390416 scan = sorted (scan_vars ),
391417 input = sorted (input_vars ),
392418 output = sorted (output_vars ),
@@ -397,19 +423,30 @@ def visit_For(self, node):
397423 self .generic_visit (node )
398424 # look for variables, loop, inputs and outputs of the body
399425 vars = self ._find_loop_vars (node )
400- init_vars , loop_vars , scan_vars , input_vars , output_vars = [
401- vars [k ] for k in ["init" , "loop" , "scan" , "input" , "output" ]
426+ init_vars , loop_vars , scan_shape_vars , scan_vars , input_vars , output_vars = [
427+ vars [k ] for k in ["init" , "loop" , "scan_shape" , " scan" , "input" , "output" ]
402428 ]
403-
404- # return, one value or a tuple of values
405- return_stmt = ast .Return (
406- value = (
407- ast .Name (id = output_vars [0 ], ctx = ast .Load ())
408- if len (output_vars ) == 1
409- else ast .Tuple (
410- elts = [ast .Name (id = v , ctx = ast .Load ()) for v in output_vars ], ctx = ast .Load ()
411- )
412- )
429+ self ._check (
430+ len (scan_shape_vars ) == len (loop_vars ),
431+ node ,
432+ lambda : (
433+ f"Inconsistencies between loop_vars={ loop_vars } "
434+ f"and scan_shape_vars={ scan_shape_vars } "
435+ ),
436+ )
437+ self ._check (
438+ len (scan_shape_vars ) in {0 , 1 },
439+ node ,
440+ lambda : f"Inconsistencies with scan_shape_vars={ scan_shape_vars } " ,
441+ )
442+ self ._check (
443+ (len (scan_shape_vars ) == 0 or len (scan_vars ) == 0 )
444+ and (scan_shape_vars or scan_vars ),
445+ node ,
446+ lambda : (
447+ f"Inconsistencies between scan_vars={ scan_vars } "
448+ f"and scan_shape_vars={ scan_shape_vars } "
449+ ),
413450 )
414451
415452 # creates the function
@@ -419,12 +456,50 @@ def visit_For(self, node):
419456 name = func_name ,
420457 args = ast .arguments (
421458 posonlyargs = [],
422- args = [ast .arg (arg = v ) for v in [* init_vars , * scan_vars , * input_vars ]],
459+ args = [
460+ ast .arg (arg = v )
461+ for v in [
462+ * init_vars ,
463+ * loop_vars ,
464+ * scan_vars ,
465+ * scan_shape_vars ,
466+ * input_vars ,
467+ ]
468+ ],
423469 kwonlyargs = [],
424470 kw_defaults = [],
425471 defaults = [],
426472 ),
427- body = [* node .body , return_stmt ],
473+ body = [
474+ * [
475+ ast .Assign (
476+ targets = [ast .Name (id = i , ctx = ast .Load ())],
477+ value = [
478+ ast .Call (
479+ func = ast .Attribute (
480+ value = ast .Name (id = i , ctx = ast .Load ()),
481+ attr = "clone" ,
482+ ctx = ast .Load (),
483+ ),
484+ args = [],
485+ keywords = [],
486+ ctx = ast .Load (),
487+ )
488+ ],
489+ )
490+ for i in init_vars
491+ ],
492+ * node .body ,
493+ ast .Return (
494+ value = ast .List (
495+ [
496+ ast .Name (id = v , ctx = ast .Load ())
497+ for v in [* init_vars , * loop_vars , * output_vars ]
498+ ],
499+ ctx = ast .Load (),
500+ )
501+ ),
502+ ],
428503 decorator_list = [],
429504 ctx = ast .Store (),
430505 )
@@ -452,17 +527,56 @@ def visit_For(self, node):
452527 elts = [ast .Name (id = v , ctx = ast .Load ()) for v in init_vars ], ctx = ast .Store ()
453528 ),
454529 ast .List (
455- elts = [ast .Name (id = v , ctx = ast .Load ()) for v in scan_vars ], ctx = ast .Store ()
530+ elts = [
531+ * [
532+ ast .Call (
533+ ast .Attribute (
534+ value = ast .Name (id = "torch" , ctx = ast .Load ()),
535+ attr = "arange" ,
536+ ctx = ast .Load (),
537+ ),
538+ args = [
539+ ast .Subscript (
540+ value = ast .Attribute (
541+ value = ast .Name (id = v , ctx = ast .Load ()),
542+ attr = "shape" ,
543+ ctx = ast .Load (),
544+ ),
545+ slice = ast .Constant (value = 0 , ctx = ast .Load ()),
546+ ctx = ast .Load (),
547+ ),
548+ ],
549+ keywords = [
550+ ast .keyword (
551+ arg = "dtype" ,
552+ value = ast .Attribute (
553+ value = ast .Name (id = "torch" , ctx = ast .Load ()),
554+ attr = "int64" ,
555+ ctx = ast .Load (),
556+ ),
557+ )
558+ ],
559+ ctx = ast .Load (),
560+ )
561+ for v in scan_shape_vars
562+ ],
563+ * [ast .Name (id = v , ctx = ast .Load ()) for v in scan_vars ],
564+ ],
565+ ctx = ast .Store (),
456566 ),
457567 ast .List (
458- elts = [ast .Name (id = v , ctx = ast .Load ()) for v in input_vars ], ctx = ast .Store ()
568+ elts = [
569+ ast .Name (id = v , ctx = ast .Load ()) for v in [* scan_shape_vars , * input_vars ]
570+ ],
571+ ctx = ast .Store (),
459572 ),
460573 ],
461574 keywords = [],
462575 ctx = ast .Load (),
463576 )
464- target = ast .List (
465- [ast .Name (id = v , ctx = ast .Store ()) for v in output_vars ], ctx = ast .Store ()
577+ target = ast .Tuple (
578+ [ast .Name (id = v , ctx = ast .Store ()) for v in [* init_vars , * loop_vars , * output_vars ]],
579+ ctx = ast .Store (),
466580 )
467581 assign = ast .Assign (targets = [target ], value = call )
468582 return [func_def , assign ]
0 commit comments