@@ -288,19 +288,29 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
288288 in_pvals = [pval if pval .is_known () or in_axis is None else
289289 unknown (mapped_aval (params ['axis_size' ], in_axis , pval [0 ]))
290290 for pval , in_axis in zip (in_pvals , params ['in_axes' ])]
291+ out_axes_thunk = params ['out_axes_thunk' ]
292+ @jax_util .as_hashable_function (closure = ('unzip' , out_axes_thunk ))
293+ def new_out_axes_thunk ():
294+ out_axes = out_axes_thunk ()
295+ assert all (out_axis == 0 for out_axis in out_axes )
296+ _ , num_outputs , _ = aux ()
297+ return (0 ,) * num_outputs
298+ new_params = dict (params , out_axes_thunk = new_out_axes_thunk )
299+ else :
300+ new_params = params
291301 pvs , in_consts = jax_util .unzip2 (t .pval for t in tracers )
292302 keys = tuple (t .is_key () for t in tracers )
293303 new_settings = UnzipSettings (settings .tag , call_primitive in block_registry )
294304 fun , aux = unzip_eval (f , self , keys , tuple (pvs ), new_settings )
295- out_flat = call_primitive .bind (fun , * in_consts , ** params )
296- success , results = aux ()
305+ out_flat = call_primitive .bind (fun , * in_consts , ** new_params )
306+ success , _ , results = aux ()
297307 if not success :
298308 out_pvs , out_keys , jaxpr , env = results
299309 out_pv_consts , consts = jax_util .split_list (out_flat , [len (out_pvs )])
300- out_tracers = self ._bound_output_tracers (call_primitive , params , jaxpr ,
301- consts , env , tracers , out_pvs ,
302- out_pv_consts , out_keys , name ,
303- is_map )
310+ out_tracers = self ._bound_output_tracers (call_primitive , new_params ,
311+ jaxpr , consts , env , tracers ,
312+ out_pvs , out_pv_consts ,
313+ out_keys , name , is_map )
304314 return out_tracers
305315 init_name = jax_util .wrap_name (name , 'init' )
306316 apply_name = jax_util .wrap_name (name , 'apply' )
@@ -319,15 +329,16 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
319329 [len (apply_pvs )])
320330
321331 variable_tracers = self ._bound_output_tracers (
322- call_primitive , params , init_jaxpr , init_consts , init_env , key_tracers ,
323- init_pvs , init_pv_consts , [True ] * len (init_pvs ), init_name , is_map )
332+ call_primitive , new_params , init_jaxpr , init_consts , init_env ,
333+ key_tracers , init_pvs , init_pv_consts , [True ] * len (init_pvs ),
334+ init_name , is_map )
324335
325336 unflat_variables = tree_util .tree_unflatten (variable_tree , variable_tracers )
326337 if call_primitive is harvest .nest_p :
327338 variable_dict = harvest .sow (
328339 dict (safe_zip (variable_names , unflat_variables )),
329340 tag = settings .tag ,
330- name = params ['scope' ],
341+ name = new_params ['scope' ],
331342 mode = 'strict' )
332343 unflat_variables = tuple (variable_dict [name ] for name in variable_names )
333344 else :
@@ -342,7 +353,7 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map):
342353 variable_tracers = tree_util .tree_leaves (unflat_variables )
343354
344355 out_tracers = self ._bound_output_tracers (
345- call_primitive , params , apply_jaxpr , apply_consts , apply_env ,
356+ call_primitive , new_params , apply_jaxpr , apply_consts , apply_env ,
346357 variable_tracers + abstract_tracers , apply_pvs , apply_pv_consts ,
347358 apply_keys , apply_name , is_map )
348359 return out_tracers
@@ -365,6 +376,11 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env,
365376 tuple (v for v , t in zip (params ['donated_invars' ], in_tracers )
366377 if not t .pval .is_known ()))
367378 new_params ['donated_invars' ] = new_donated_invars
379+ if is_map :
380+ out_axes = params ['out_axes_thunk' ]()
381+ assert all (out_axis == 0 for out_axis in out_axes )
382+ new_params ['out_axes' ] = (0 ,) * len (out_tracers )
383+ del new_params ['out_axes_thunk' ]
368384 eqn = pe .new_eqn_recipe (
369385 tuple (const_tracers + env_tracers + in_tracers ), out_tracers , primitive ,
370386 new_params , source_info_util .current ()) # pytype: disable=wrong-arg-types
@@ -442,14 +458,16 @@ def unzip_eval_wrapper(pvs, *consts):
442458 out = (
443459 tuple (init_pv_consts ) + tuple (init_consts ) + tuple (apply_pv_consts ) +
444460 tuple (apply_consts ))
445- yield out , (success , ((init_pvs , len (init_consts ), apply_pvs ),
446- (init_jaxpr , apply_jaxpr ), (init_env ,
447- apply_env ), metadata ))
461+ yield out , (success , len (out ),
462+ ((init_pvs , len (init_consts ), apply_pvs ),
463+ (init_jaxpr , apply_jaxpr ),
464+ (init_env , apply_env ),
465+ metadata ))
448466 else :
449467 jaxpr , (out_pvals , out_keys , consts , env ) = result
450468 out_pvs , out_consts = jax_util .unzip2 (out_pvals )
451469 out = tuple (out_consts ) + tuple (consts )
452- yield out , (success , (out_pvs , out_keys , jaxpr , env ))
470+ yield out , (success , len ( out ), (out_pvs , out_keys , jaxpr , env ))
453471
454472
455473@lu .transformation
0 commit comments