[Bug] fail to set JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache with multi engines #796

@aolemila

Checklist

  • I have searched related issues but cannot get the expected help.
  • The bug has not been fixed in the latest version.
  • Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sgl-jax/discussions/new/choose. Otherwise, it will be closed.
  • Please use English, otherwise it will be closed.

Describe the bug

Based on #795, running JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 test/srt/rl/multi_engines_in_one_process.py fails during weight loading with the following error.

                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/scheduler.py", line 280, in __init__
    self.tp_worker = TpWorkerClass(
                     ^^^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py", line 36, in __init__
    self.worker = ModelWorker(server_args, mesh=mesh)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/tp_worker.py", line 105, in __init__
    self.model_runner = ModelRunner(
                        ^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 117, in __init__
    self.initialize()
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 136, in initialize
    self.load_model()
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 280, in load_model
    self.model = self.model_loader.load_model(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_loader/loader.py", line 205, in load_model
    jit_model = self._get_model(model, model_config)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_loader/loader.py", line 262, in _get_model
    model.load_weights(model_config)
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/models/llama.py", line 407, in load_weights
    loader.load_weights_from_safetensors(weight_mappings)
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/utils/weight_utils.py", line 853, in load_weights_from_safetensors
    self._process_and_assign_weight(params, hf_key, lazy_weight, mapping)
  File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/utils/weight_utils.py", line 1182, in _process_and_assign_weight

  File "/home/gcpuser/miniconda3/envs/sgl312/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/gcpuser/miniconda3/envs/sgl312/lib/python3.12/site-packages/jax/_src/array.py", line 635, in _value
    npy_value, did_copy = self._single_device_array_to_np_array_did_copy()
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.JaxRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x11e (from TensorCoreSequencer:1:0x24e): schecklt: Invalid logical z: enhanced-barrier-parent-phase-1 no HLO mapping
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:180

Reproduction

Based on #795, run JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 test/srt/rl/multi_engines_in_one_process.py

pip list | egrep 'jax|flax|libtpu'
flax                   0.12.0
jax                    0.8.1
jaxlib                 0.8.1
jaxtyping              0.3.7
libtpu                 0.0.24
sglang-jax             0.0.2         /home/gcpuser/aolemila/repos/sglang-jax/python
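
To help confirm that the failure depends on the cache variable, the reproduction command can be wrapped in a small A/B helper. This is only a sketch: `build_env`, `build_command`, and `run_case` are hypothetical names that are not part of sglang-jax, and it assumes the script is launched from the repository root with `python3` on the PATH.

```python
import os
import subprocess

# Environment variable and script path taken from the reproduction command.
CACHE_ENV_VAR = "JAX_COMPILATION_CACHE_DIR"
TEST_SCRIPT = "test/srt/rl/multi_engines_in_one_process.py"


def build_env(cache_dir=None, base_env=None):
    """Return a copy of the environment with the cache variable set or removed.

    Hypothetical helper: cache_dir=None removes the variable so that the
    same script can be run without the persistent compilation cache.
    """
    env = dict(os.environ if base_env is None else base_env)
    if cache_dir is None:
        env.pop(CACHE_ENV_VAR, None)
    else:
        env[CACHE_ENV_VAR] = cache_dir
    return env


def build_command(python="python3", script=TEST_SCRIPT):
    """Return the argv list for the reproduction script."""
    return [python, script]


def run_case(cache_dir=None):
    """Run the script once in a subprocess and return its exit code."""
    result = subprocess.run(build_command(), env=build_env(cache_dir))
    return result.returncode
```

Calling `run_case("/tmp/jit_cache")` and then `run_case(None)` and comparing the two exit codes would show whether the crash only occurs when the cache directory is set.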

Environment

tpu-v6e-4.

Labels: bug (Something isn't working)
