|
34 | 34 | from jax import core as jax_core |
35 | 35 | from jax import custom_derivatives as cd |
36 | 36 | from jax import linear_util as lu |
37 | | -from jax import source_info_util |
38 | 37 | from jax import tree_util |
39 | 38 | from jax import util as jax_util |
| 39 | +from jax._src import source_info_util |
40 | 40 | from jax.interpreters import partial_eval as pe |
41 | 41 | import numpy as onp |
42 | 42 |
|
@@ -282,14 +282,13 @@ def handle_call_primitive(self, call_primitive, f, tracers, params, is_map): |
282 | 282 | return current_custom_rules()[call_primitive](self, f, *tracers, **params) |
283 | 283 | if call_primitive in pe.call_partial_eval_rules: |
284 | 284 | raise NotImplementedError |
285 | | - in_pvs, in_consts = jax_util.unzip2(t.pval for t in tracers) |
| 285 | + in_pvals = [t.pval for t in tracers] |
286 | 286 | if is_map: |
287 | | - pvs = [ |
288 | | - None if pv is None else mapped_aval(params['axis_size'], pv) |
289 | | - for pv in in_pvs |
290 | | - ] |
291 | | - else: |
292 | | - pvs = in_pvs |
| 287 | + unknown = pe.PartialVal.unknown |
| 288 | + in_pvals = [pval if pval.is_known() or in_axis is None else |
| 289 | + unknown(mapped_aval(params['axis_size'], in_axis, pval[0])) |
| 290 | + for pval, in_axis in zip(in_pvals, params['in_axes'])] |
| 291 | + pvs, in_consts = jax_util.unzip2(t.pval for t in tracers) |
293 | 292 | keys = tuple(t.is_key() for t in tracers) |
294 | 293 | new_settings = UnzipSettings(settings.tag, call_primitive in block_registry) |
295 | 294 | fun, aux = unzip_eval(f, self, keys, tuple(pvs), new_settings) |
@@ -360,12 +359,6 @@ def _bound_output_tracers(self, primitive, params, jaxpr, consts, env, |
360 | 359 | for pv, const, key in safe_zip(out_pvs, out_consts, out_keys) |
361 | 360 | ] |
362 | 361 | new_params = dict(params, name=name, call_jaxpr=lifted_jaxpr) |
363 | | - if is_map: |
364 | | - new_params = dict( |
365 | | - new_params, |
366 | | - mapped_invars=tuple([True] * len(const_tracers) + |
367 | | - [False] * len(env_tracers) + |
368 | | - [True] * len(in_tracers))) |
369 | 362 | if 'donated_invars' in params: |
370 | 363 | new_donated_invars = ( |
371 | 364 | (False,) * len(const_tracers) + (False,) * len(env_tracers) + |
|
0 commit comments