Open
Labels
bug: Something isn't working
Description
Checklist
- I have searched related issues but cannot get the expected help.
- The bug has not been fixed in the latest version.
- Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
- If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sgl-jax/discussions/new/choose. Otherwise, it will be closed.
- Please use English, otherwise it will be closed.
Describe the bug
Based on #795, running JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 test/srt/rl/multi_engines_in_one_process.py fails during weight loading with the following error (traceback truncated at the top):
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/scheduler.py", line 280, in __init__
self.tp_worker = TpWorkerClass(
^^^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/tp_worker_overlap_thread.py", line 36, in __init__
self.worker = ModelWorker(server_args, mesh=mesh)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/managers/tp_worker.py", line 105, in __init__
self.model_runner = ModelRunner(
^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 117, in __init__
self.initialize()
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 136, in initialize
self.load_model()
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_executor/model_runner.py", line 280, in load_model
self.model = self.model_loader.load_model(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_loader/loader.py", line 205, in load_model
jit_model = self._get_model(model, model_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/model_loader/loader.py", line 262, in _get_model
model.load_weights(model_config)
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/models/llama.py", line 407, in load_weights
loader.load_weights_from_safetensors(weight_mappings)
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/utils/weight_utils.py", line 853, in load_weights_from_safetensors
self._process_and_assign_weight(params, hf_key, lazy_weight, mapping)
File "/home/gcpuser/aolemila/repos/sglang-jax/python/sgl_jax/srt/utils/weight_utils.py", line 1182, in _process_and_assign_weight
File "/home/gcpuser/miniconda3/envs/sgl312/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/gcpuser/miniconda3/envs/sgl312/lib/python3.12/site-packages/jax/_src/array.py", line 635, in _value
npy_value, did_copy = self._single_device_array_to_np_array_did_copy()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.JaxRuntimeError: INTERNAL: Core halted unexpectedly: INTERNAL: Accelerator device halted prematurely, perhaps due to an on-device check-failure. Node 0 halted unexpectedly at tag:pc TensorCoreSequencer:1:0x11e (from TensorCoreSequencer:1:0x24e): schecklt: Invalid logical z: enhanced-barrier-parent-phase-1 no HLO mapping
=== Source Location Trace: ===
learning/45eac/tpu/runtime/hal/internal/tpu_program_termination_validation.cc:180
Reproduction
Based on #795, run JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 test/srt/rl/multi_engines_in_one_process.py
pip list | egrep 'jax|flax|libtpu'
flax 0.12.0
jax 0.8.1
jaxlib 0.8.1
jaxtyping 0.3.7
libtpu 0.0.24
sglang-jax 0.0.2 /home/gcpuser/aolemila/repos/sglang-jax/python
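To help narrow this down, the reproduction command can be run with and without the compilation cache directory. The sketch below is a hypothetical helper, not part of sgl-jax: `build_env` and `run_repro` are names invented here, and it assumes it is launched from the repository root on the branch from #795. It only checks whether the `JAX_COMPILATION_CACHE_DIR` setting changes the outcome; the report does not confirm that the cache is the cause.

```python
import os
import subprocess
import sys

# Script path and environment variable are taken from the report.
REPRO_SCRIPT = "test/srt/rl/multi_engines_in_one_process.py"
CACHE_ENV_VAR = "JAX_COMPILATION_CACHE_DIR"


def build_env(cache_dir=None, base_env=None):
    """Return a copy of the environment with the cache dir set, or removed if None."""
    env = dict(os.environ if base_env is None else base_env)
    if cache_dir is None:
        env.pop(CACHE_ENV_VAR, None)  # make sure an inherited value does not leak in
    else:
        env[CACHE_ENV_VAR] = cache_dir
    return env


def run_repro(cache_dir=None, python=sys.executable):
    """Run the reproduction script once and return its exit code."""
    result = subprocess.run([python, REPRO_SCRIPT], env=build_env(cache_dir))
    return result.returncode


# Example usage (run from the repository root):
#   run_repro(None)              # cache disabled
#   run_repro("/tmp/jit_cache")  # cache enabled, cold on the first run
#   run_repro("/tmp/jit_cache")  # cache enabled, possibly warm on the second run
```

Running the cached configuration twice may also show whether the failure depends on the cache being cold or already populated.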
Environment
tpu-v6e-4.