
Commit efe20ab

Do not warn on jax usage when workarounds are available (#9624)
This prevents excessive logging when using xp.Trace or get_op_sharding.
1 parent 2329746 commit efe20ab

File tree

3 files changed: +9 −6 lines changed


torch_xla/_internal/jax_workarounds.py

Lines changed: 6 additions & 4 deletions
@@ -58,7 +58,7 @@ def maybe_get_torchax():
     return None


-def maybe_get_jax():
+def maybe_get_jax(log=True):
   try:
     jax_import_guard()
     with jax_env_context():
@@ -67,6 +67,8 @@ def maybe_get_jax():
       jax.config.update('jax_use_shardy_partitioner', False)
       return jax
   except (ModuleNotFoundError, ImportError):
-    logging.warn('You are trying to use a feature that requires jax/pallas.'
-                 'You can install Jax/Pallas via pip install torch_xla[pallas]')
-    return None
+    if log:
+      logging.warning(
+          'You are trying to use a feature that requires jax/pallas.'
+          'You can install Jax/Pallas via pip install torch_xla[pallas]')
+    return None
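
With the new log flag, a caller that has a non-JAX fallback can probe for jax quietly. A minimal sketch of such a caller (the fallback branch here is illustrative, not part of this commit):

from torch_xla._internal.jax_workarounds import maybe_get_jax

jax = maybe_get_jax(log=False)  # returns None quietly if jax is not installed
if jax is not None:
  ...  # jax-backed path
else:
  ...  # workaround path; no per-call log spam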

torch_xla/debug/profiler.py

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ def __enter__(self):
 
     self._jax_scope = None
     # Also enter the JAX named scope, to support torchax lowering.
-    if jax := maybe_get_jax():
+    if jax := maybe_get_jax(log=False):
       self._jax_scope = jax.named_scope(self.name)
       self._jax_scope.__enter__()
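
Context for this call site: xp.Trace is typically entered on every traced region of a training step, so a warning per trace floods the logs when jax is absent. A rough sketch of that usage pattern (the step function and names are illustrative, not from this commit):

import torch_xla.debug.profiler as xp

def train_step(model, batch):
  # Entered every step; the JAX named scope is a best-effort extra,
  # so a missing jax install should not produce a warning here.
  with xp.Trace('train_step'):
    return model(batch)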

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 2 additions & 1 deletion
@@ -646,7 +646,8 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
       f"Partition spec length ({len(partition_spec)}) should be equal to the input rank ({len(t.shape)})."
 
   tx = maybe_get_torchax()
-  jax = maybe_get_jax()
+  # Do not log jax warnings when workarounds are available.
+  jax = maybe_get_jax(log=False)
   if (jax is not None) and (tx is not None) and isinstance(t, tx.tensor.Tensor):
     from jax.sharding import PartitionSpec as P, NamedSharding
     jmesh = mesh.get_jax_mesh()
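
mark_sharding only needs jax on the torchax tensor path; plain XLA tensors are sharded through the native get_op_sharding workaround, so a missing jax install is expected there. An illustrative call under that assumption (mesh shape and partition spec are made up for the example):

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.randn(16, 4, device=xm.xla_device())
# Shards dim 0 across the 'data' axis; no jax required, so no warning is logged.
xs.mark_sharding(t, mesh, ('data', None))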
