@@ -340,6 +340,11 @@ def call_function(
340340 elif target == torch .ops .higher_order .cond :
341341 pred , true_fn , false_fn , inputs = args
342342 return self .callback .call_cond (pred , true_fn , false_fn , inputs , meta )
343+ elif target == torch .ops .higher_order .while_loop :
344+ cond , body , carried_inputs , additional_inputs = args
345+ return self .callback .call_while (
346+ cond , body , carried_inputs , additional_inputs , meta
347+ )
343348 elif target == torch .ops .higher_order .map_impl :
344349 f , mapped_args , operands = args # type: ignore[assignment]
345350 return self .callback .call_map (f , mapped_args , operands , meta )
@@ -497,6 +502,31 @@ def call_cond(
497502 meta ,
498503 )
499504
505+ def call_while (
506+ self ,
507+ cond_fn : torch .fx .GraphModule ,
508+ body_fn : torch .fx .GraphModule ,
509+ carried_inputs : List [Argument ],
510+ additional_inputs : List [Argument ],
511+ meta : NodeMetadata ,
512+ ) -> ProxyValue :
513+ cond_fn = self .call_submodule (cond_fn , (* carried_inputs , * additional_inputs ))
514+ body_fn = self .call_submodule (body_fn , (* carried_inputs , * additional_inputs ))
515+ assert cond_fn is not None
516+ assert body_fn is not None
517+ return self ._fx (
518+ "call_function" ,
519+ torch .ops .higher_order .while_loop ,
520+ (
521+ cond_fn .graph_module ,
522+ body_fn .graph_module ,
523+ carried_inputs ,
524+ additional_inputs ,
525+ ),
526+ {},
527+ meta ,
528+ )
529+
500530 def call_map (
501531 self ,
502532 f : torch .fx .GraphModule ,
0 commit comments